首页 文章

使用张量流中的另一个索引列表访问张量的元素

提问于
浏览
1

我需要使用我拥有的另一个索引列表来访问张量的元素,但目前使用简单的语法似乎是不可能的 . 我不确定它是否是一个bug,所以我在这里发布它以希望修复我的语法 . 我的代码是:

import tensorflow as tf
import numpy as np

sess = tf.Session()
input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx_list = np.array([0,2])
output = input[:, idx_list]

print(sess.run(output))

但我得到错误:

ValueError:形状必须等于等级,但是0和1从形状0与其他形状合并 . 对于带有输入形状的'strided_slice / stack_1'(op:'Pack'):[],[2] .

我安装的tensorflow版本是tensorflow-1.1.0-cp35(pip安装) .

Update:

我通过tf.fn_map执行此操作,但我真的怀疑这是进行索引的正确方法:

output = tf.transpose(tf.map_fn(lambda x: input[:,x], idx_list),perm=[1,0])

Update:

有一个特定的issue registered,在最新的评论中有一个很好的片段,可能会有所帮助 . 同时这个操作并不像numpy那么容易......

1 回答

  • 0

    您可以使用 tf.gathertf.transpose 执行此操作,如下所示:

    import tensorflow as tf
    import numpy as np
    
    sess = tf.Session()
    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    idx_list = np.array([0,2])
    output = tf.transpose(tf.gather(tf.transpose(input),idx_list))
    output.eval(session=sess)
    

    这打印

    array([[1, 3],
           [4, 6],
           [7, 9]])
    

相关问题