首页 文章

在MATLAB中用bsxfun替换repmat

提问于
浏览
3

在以下功能中,我想进行一些更改以使其快速 . 它本身很快但我必须在for循环中多次使用它所以需要很长时间 . 我想如果我用bsxfun替换repmat将使它更快但我不确定 . 我该怎么做这些替换

function out = lagcal(y1,y1k,source)
kn1 = y1(:);
kt1 = y1k(:);

kt1x = repmat(kt1,1,length(kt1));  

eq11 = 1./(prod(kt1x-kt1x'+eye(length(kt1))));
eq1 = eq11'*eq11;

dist = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1);
[fixi,fixj] = find(dist==0); dist(fixi,fixj)=eps;
mult = 1./(dist);

eq2 = prod(dist,2);
eq22 = repmat(eq2,1,length(kt1));
eq222 = eq22 .* mult; 

out = eq1 .* (eq222'*source*eq222);
end

它真的加快了我的功能吗?

1 回答

  • 6

    简介和代码更改

    功能代码中使用的所有 repmat 用法都是将输入扩展为大小,以便稍后可以执行涉及这些输入的数学运算 . 这是bsxfun的量身定制的情况 . 可悲的是,虽然功能代码的真正瓶颈似乎是别的东西 . 在我们讨论代码的所有性能相关方面时继续保持 .

    接下来会显示 repmat 替换为_1874198的代码,并将替换的代码保留为评论用于比较 -

    function out = lagcal(y1,y1k,source)
    
    kn1 = y1(:);
    kt1 = y1k(:);
    
    %//kt1x = repmat(kt1,1,length(kt1));
    %//eq11 = 1./(prod(kt1x-kt1x'+eye(length(kt1)))) %//'
    eq11 = 1./prod(bsxfun(@minus,kt1,kt1.') + eye(numel(kt1))) %//'
    
    eq1 = eq11'*eq11; %//'
    
    %//dist = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1) %//'
    dist = bsxfun(@minus,kn1,kt1.') %//'
    
    [fixi,fixj] = find(dist==0); 
    
    dist(fixi,fixj)=eps;
    mult = 1./(dist);
    
    eq2 = prod(dist,2);
    
    %//eq22 = repmat(eq2,1,length(kt1));
    %//eq222 = eq22 .* mult
    eq222 = bsxfun(@times,eq2,mult)
    
    out = eq1 .* (eq222'*source*eq222); %//'
    
    return; %// Better this way to end a function
    

    这里可以添加一个修改 . 在最后一行,我们可以做如下所示的事情,但时间结果并没有显示出它的巨大好处 -

    out = bsxfun(@times,eq11.',bsxfun(@times,eq11,eq222'*source*eq222))
    

    这样可以避免原始代码中先前完成的 eq1 的计算,因此您可以节省更多的时间 .

    基准测试

    接下来讨论对代码的bsxfun修改部分与原始基于repmat的代码的基准测试 .

    Benchmarking Code

    N_arr = [50 100 200 500 1000 2000 3000]; %// array elements for N (datasize)
    blocks = 3;
    timeall = zeros(2,numel(N_arr),blocks);
    
    for k1 = 1:numel(N_arr)
        N = N_arr(k1);
        y1 = rand(N,1);
        y1k = rand(N,1);
        source = rand(N);
    
        kn1 = y1(:);
        kt1 = y1k(:);
    
        %% Block 1 ----------------
        block = 1;
        f = @() block1_org(kt1);
        timeall(1,k1,block) = timeit(f);
        clear f
    
        f = @() block1_mod(kt1);
        timeall(2,k1,block) = timeit(f);
        eq11 = feval(f);
        clear f
        %% Block 1 ----------------
    
        eq1 = eq11'*eq11; %//'
    
        %% Block 2 ----------------
        block = 2;
        f = @() block2_org(kn1,kt1);
        timeall(1,k1,block) = timeit(f);
        clear f
    
        f = @() block2_mod(kn1,kt1);
        timeall(2,k1,block) = timeit(f);
        dist = feval(f);
        clear f
        %% Block 2 ----------------
    
        [fixi,fixj] = find(dist==0);
    
        dist(fixi,fixj)=eps;
        mult = 1./(dist);
    
        eq2 = prod(dist,2);
    
        %% Block 3 ----------------
        block = 3;
        f = @() block3_org(eq2,mult,length(kt1));
        timeall(1,k1,block) = timeit(f);
        clear f
    
        f = @() block3_mod(eq2,mult);
        timeall(2,k1,block) = timeit(f);
        clear f
        %% Block 3 ----------------
    
    end
    
    %// Display benchmark results
    figure,
    for k2 = 1:blocks
        subplot(blocks,1,k2),
        title(strcat('Block',num2str(k2),' results :'),'fontweight','bold'),hold on
        plot(N_arr,timeall(1,:,k2),'-ro')
        plot(N_arr,timeall(2,:,k2),'-kx')
        legend('REPMAT Method','BSXFUN Method')
        xlabel('Datasize (N) ->'),ylabel('Time(sec) ->')
    end
    

    Associated functions

    function out = block1_org(kt1)
    kt1x = repmat(kt1,1,length(kt1));
    out = 1./(prod(kt1x-kt1x'+eye(length(kt1))));
    return;
    
    function out = block1_mod(kt1)
    out = 1./prod(bsxfun(@minus,kt1,kt1.') + eye(numel(kt1)));
    return;
    
    function out = block2_org(kn1,kt1)
    out = repmat(kn1,1,length(kt1))-repmat(kt1',length(kn1),1);
    return;
    
    function out = block2_mod(kn1,kt1)
    out = bsxfun(@minus,kn1,kt1.');
    return;
    
    function out = block3_org(eq2,mult,length_kt1)
    eq22 = repmat(eq2,1,length_kt1);
    out = eq22 .* mult;
    return;
    
    function out = block3_mod(eq2,mult)
    out = bsxfun(@times,eq2,mult);
    return;
    

    Results

    enter image description here

    结论

    基于 bsxfun 的代码显示了基于repmat的加速,这是令人鼓舞的 . 但是,在不同的数据量上对原始代码进行分析表明,最后一行中的多个矩阵乘法似乎占据了函数代码的大部分运行时间,这在MATLAB中非常有效 . 除非你有办法通过使用其他数学技术来避免这些乘法,否则它们看起来像瓶颈 .

相关问题