首页 文章

在矩阵的每一行中查找1的列索引

提问于
浏览
5

我在Matlab中有以下矩阵:

M = [0 0 1
     1 0 0
     0 1 0
     1 0 0
     0 0 1];

每行只有一个1.我如何(没有循环)确定列向量,以便第一个元素是2,如果第二列中有1,第二个元素是3,对于第三列中的一个等 . ?上面的例子应该变成:

M = [ 3
      1
      2
      1
      3];

3 回答

  • 5

    你可以用简单的矩阵乘法来解决这个问题 .

    result = M * (1:size(M, 2)).';
    
         3
         1
         2
         1
         3
    

    这通过将M x 3矩阵与3 x 1阵列相乘来实现,其中3x1的元素仅为 [1; 2; 3] . 简而言之,对于 M 的每一行,使用3 x 1阵列执行逐元素乘法 . 只有 M 行中的1将在结果中产生任何结果 . 然后将该元素乘法的结果相加 . 因为每行只有一个"1",所以结果将是该1所在的列索引 .

    例如,对于 M 的第一行 .

    element_wise_multiplication = [0 0 1] .* [1 2 3]
    
        [0, 0, 3]
    
    sum(element_wise_multiplication)
    
        3
    

    Update

    基于@reyryeng@Luis提供的解决方案,我决定运行一个比较,看看各种方法的性能如何比较 .

    为了设置测试矩阵( M ),我创建了一个原始问题中指定形式的矩阵,并改变了行数 . 使用 randi([1 nCols], size(M, 1)) 随机选择哪一列具有1 . 使用 timeit 分析执行时间 .

    当使用 M (MATLAB的默认值) M 运行时,您将获得以下执行时间 .

    enter image description here

    如果 Mlogical ,那么矩阵乘法会因为必须在矩阵乘法之前将其转换为数字类型而受到影响,而其他两个则有一点性能提升 .

    enter image description here

    这是我使用的测试代码 .

    sizes = round(linspace(100, 100000, 100));
    times = zeros(numel(sizes), 3);
    
    for k = 1:numel(sizes)
        M = generateM(sizes(k));
        times(k,1) = timeit(@()M * (1:size(M, 2)).');
        M = generateM(sizes(k));
        times(k,2) = timeit(@()max(M, [], 2), 2);
        M = generateM(sizes(k));
        times(k,3) = timeit(@()find(M.'), 2);
    end
    
    figure
    plot(range, times / 1000);
    legend({'Multiplication', 'Max', 'Find'})
    xlabel('Number of rows in M')
    ylabel('Execution Time (ms)')
    
    function M = generateM(nRows)
        M = zeros(nRows, 3);
        col = randi([1 size(M, 2)], 1, size(M, 1));
        M(sub2ind(size(M), 1:numel(col), col)) = 1;
    end
    
  • 2

    您也可以滥用find并观察 M 的转置的 row 位置 . 您必须首先转置矩阵,因为 find 按列主要顺序运行:

    M = [0 0 1
         1 0 0
         0 1 0
         1 0 0
         0 0 1];
    
    [out,~] = find(M.');
    

    不确定这是否比矩阵乘法更快 .

  • 2

    另一种方法:使用max的第二个输出:

    [~, result] = max(M.', [], 1);
    

    或者,如@rayryeng所示,沿第二维使用 max 而不是转置 M

    [~, result] = max(M, [], 2);
    

    对于

    M = [0 0 1
         1 0 0
         0 1 0
         1 0 0
         0 0 1];
    

    这给了

    result =
         3     1     2     1     3
    

    如果 M 在给定行中包含多个 1 ,则这将给出第一个 1 的索引 .

相关问题