首页 文章

张量流中的触发器:矩阵乘法

提问于
浏览
0

受此启发question我试图测量矩阵矩阵乘法的tensorflow所需的FLOPS .

对于分别具有大小(m×p)和(p×n)的两个矩阵A和B,得到的矩阵C =具有大小(m×n)的AB具有mn个条目 . 对于每个条目,需要p乘法和(p-1)求和 . 因此,操作总数为 mn(2p-1) .

使用链接问题/答案的代码,tensorflow输出 m*n*2p ,请参阅下面的代码 .

为什么返回这个近似值而不是理论值?在最坏的情况下,p = 1,该近似值比正确值大2倍 .

import numpy as np
import tensorflow as tf
g = tf.Graph()
run_meta = tf.RunMetadata()
with g.as_default():
    A=tf.convert_to_tensor(np.random.rand(13,9))
    B=tf.convert_to_tensor(np.random.rand(9,7))
    C = tf.matmul(A,B) # shape=[13,7]

    opts = tf.profiler.ProfileOptionBuilder.float_operation()    
    flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options
=opts)
    if flops is not None:
        print('Flops should be ', 13*7*(2*9-1))
        print('Approximation 2*13*7*9=',2*13*7*9) 
        print('TF stats gives',flops.total_float_ops)

#Output: 
#Flops should be  1547
#Approximation 2*13*7*9= 1638
#TF stats gives 1638

2 回答

  • 2

    我认为这是因为在实践中,求和通常编码如下(下面的伪代码):

    total = 0
    for i in 0...p
      total += x[i] * y[i]
    

    也就是说,第一个元素 x[0] * y[0] 总和为 total (然后是0),它产生 p 个求和而不是 p-1 .

    你可以尝试聪明,避免这种额外的总和:

    total = x[0] * y[0]
    for i in 1...p
      total += x[i] * y[i]
    

    ......但是如果 p==0 会发生什么?哎哟我们需要添加一个额外的比较:

    if p > 0
      total = x[0] * y[0]
      for i in 1...p
        total += x[i] * y[i]
    else
      total = 0
    

    问题是,这种比较不是翻牌,也不会出现在你的翻牌数中 - 但实际上,与简单的添加相比,它的代价更高,甚至更高 .

    底线:

    • 如果实现不是初始总和,则翻转计算可能是正确的

    • 这"optimization"实际上可能无法加速您的代码

    • 采取一粒盐的拖把措施,不要太担心消失组件 .

  • 0

    我不确定为什么,但我认为这是"coded"理论值:

    ...
    
    @ops.RegisterStatistics("MatMul", "flops")
    def _calc_mat_mul_flops(graph, node):
      """Calculates the compute resources needed for MatMul."""
      transpose_a = node.attr["transpose_a"].b
      a_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
      a_shape.assert_is_fully_defined()
      if transpose_a:
        k = int(a_shape[0])
      else:
        k = int(a_shape[1])
      output_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
      output_shape.assert_is_fully_defined()
      output_count = np.prod(output_shape.as_list())
      return ops.OpStats("flops", (k * output_count * 2))
    
    ...
    

相关问题