首页 文章

在张量内矢量化矩阵乘法

提问于
浏览
1

我在渲染部分代码时遇到了一些麻烦 . 我有一个(n,n,m)张量,我想将每个切片乘以m乘以秒(n乘n)矩阵(非元素) .

这是for循环的样子:

Tensor=zeros(2,2,3);
Matrix = [1,2; 3,4];

for j=1:n
    Matrices_Multiplied = Tensor(:,:,j)*Matrix;
    Recursive_Matrix=Recursive_Matrix + Tensor(:,:,j)/trace(Matrices_Multiplied);
end

如何以矢量化方式对张量内的各个矩阵执行矩阵乘法?是否有像tensor-dot这样的内置函数可以处理这个还是更聪明?

1 回答

  • 1

    Bsxfunning并使用efficient matrix-multiplication,我们可以 -

    % Calculate trace values using matrix-multiplication
    T = reshape(Matrix.',1,[])*reshape(Tensor,[],size(Tensor,3));
    
    % Use broadcasting to perform elementwise division across all slices
    out = sum(bsxfun(@rdivide,Tensor,reshape(T,1,1,[])),3);
    

    同样,可以用一个矩阵乘法替换最后一步,以进一步提高性能 . 因此,所有矩阵乘法专用解决方案将是 -

    [m,n,r] = size(Tensor);
    out = reshape(reshape(Tensor,[],size(Tensor,3))*(1./T.'),m,n)
    

    Runtime test

    基准代码 -

    % Input arrays
    n = 100; m = 100;
    Tensor=rand(n,n,m);
    Matrix=rand(n,n);
    num_iter = 100; % Number of iterations to be run for
    
    tic
    disp('------------ Loopy woopy doops : ')
    for iter = 1:num_iter
        Recursive_Matrix = zeros(n,n);
        for j=1:n
            Matrices_Multiplied = Tensor(:,:,j)*Matrix;
            Recursive_Matrix=Recursive_Matrix+Tensor(:,:,j)/trace(Matrices_Multiplied);
        end
    end
    toc, clear iter  Recursive_Matrix  Matrices_Multiplied
    
    tic
    disp('------------- Bsxfun matrix-mul not so dull : ')
    for iter = 1:num_iter
        T = reshape(Matrix.',1,[])*reshape(Tensor,[],size(Tensor,3));
        out = sum(bsxfun(@rdivide,Tensor,reshape(T,1,1,[])),3);
    end
    toc, clear T out
    
    tic
    disp('-------------- All matrix-mul having a ball : ')
    for iter = 1:num_iter
        T = reshape(Matrix.',1,[])*reshape(Tensor,[],size(Tensor,3));
        [m,n,r] = size(Tensor);
        out = reshape(reshape(Tensor,[],size(Tensor,3))*(1./T.'),m,n);
    end
    toc
    

    计时 -

    ------------ Loopy woopy doops : 
    Elapsed time is 3.339464 seconds.
    ------------- Bsxfun matrix-mul not so dull : 
    Elapsed time is 1.354137 seconds.
    -------------- All matrix-mul having a ball : 
    Elapsed time is 0.373712 seconds.
    

相关问题