首页 文章

Tensorflow批量稀疏乘法

提问于
浏览
1

我想用一个密集张量乘以一个稀疏张量但是在一批中这样做 .

例如,我有一个稀疏张量,其对应的密集形状为(20,65536,65536),其中20是批量大小 . 我想将批处理中的每个(65536,65536)与具有密集表示的张量形状(20,65536)的相应(65536x1)相乘 . tf.sparse_tensor_dense_matmul 只接受2级稀疏张量 . 有没有办法在批次中执行此操作?

由于内存限制,我希望尽可能避免将稀疏矩阵转换为密集矩阵 .

1 回答

  • 0

    答案很简单 - 首先重构稀疏张量,然后将其乘以密集矩阵 . 这样的东西会起作用:

    sparse_tensor_rank2 = tf.sparse_reshape(sparse_tensor, [-1, 65536])
    

相关问题