首页 文章

为稀疏密集乘法提供张量流稀疏矩阵 . 获得以下错误:TypeError:输入必须是SparseTensor

提问于
浏览
3

我正在尝试使用 tensorflow tf.sparse_tensor_dense_matmul(X, W1) . X定义为tf.placeholder:

X = tf.placeholder("float", [None, size]) .

W是tf.Variable

在进入dict时,我传递的是张量流稀疏矩阵 . 但是我收到了错误:

TypeError:输入必须是SparseTensor .

如何让sparse_tensor_dense_matmul模块知道我将以稀疏张量传递?

1 回答

  • 2

    要通过占位符传递SparseTensor,可以使用sparse_placeholder

    sparse_place = tf.sparse_placeholder(tf.float64)
    mul_result = tf.sparse_tensor_dense_matmul(sparse_place, some_dense_tensor)
    

    您可以按如下方式使用它:

    dense_version = tf.sparse_tensor_to_dense(sparse_place)
    sess.run(dense_version, feed_dict={sparse_place: tf.SparseTensorValue([[0,1], [2,2]], [1.5, 2.9], [3, 3])})
    Out: array([[ 0. ,  1.5,  0. ],
                [ 0. ,  0. ,  0. ],
                [ 0. ,  0. ,  2.9]])
    

    或者,您可以为值,形状和索引创建三个单独的占位符,例如:

    indices = tf.placeholder(tf.int64)
    shape = tf.placeholder(tf.int64)
    values = tf.placeholder(tf.float64)
    sparse_tensor = tf.SparseTensor(indices, shape, values)
    sess.run(sparse_tensor, feed_dict={shape: [3, 3], indices: [[0,1], [2,2]], values: [1.5, 2.9]})
    

相关问题