首页 文章

索引Keras Tensor

提问于
浏览
0

我的Keras功能模型的输出层是维度 (None, 1344, 2) 的张量 x . 我希望从 x 的第二维中提取 n < 1344 条目,并创建一个大小为 (None, n, 2) 的新张量 y .

通过简单地访问 x[:, :n,:] 来提取 n 连续条目似乎是直截了当的,但如果 n 索引是非连续的(看似很难) . 在Keras有这样一个干净的方式吗?

到目前为止,这是我的方法 .

Experiment 1 (切片张量,连续索引,有效):

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
print('Slicing first 5 entries, shape is', K.int_shape(x[:, :5, :]))
(None, 5, 2) # as printed in my code, works!

Experiment 2 (在任意索引处索引张量,失败)

print('My tensor shape is', K.int_shape(x)) #my tensor 
(None, 1344, 2) # as printed in my code
foo = np.array([1,2,4,5,8])
print('arbitrary indexing, shape is', K.int_shape(x[:,foo,:]))

Keras返回以下错误:

ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 1 with other shapes. for 'strided_slice_17/stack_1' (op: 
'Pack') with input shapes: [], [5], [].

Experiment 3 (Tensor flow后端函数)我也尝试了 K.backend.gather ,但它的用法不清楚,因为1)Keras文档声明索引应该是整数的张量,如果我的目标是在 x 中提取条目,那么没有Keras相当于 numpy.where 某个条件和2) K.backend.gather 似乎从 axis = 0 中提取条目,而我想从 x 的第二个维度中提取 .

1 回答

  • 1

    您正在寻找tf.gather_nd,它将根据索引数组进行索引:

    # From documentation
    indices = [[0, 0], [1, 1]]
    params = [['a', 'b'], ['c', 'd']]
    output = ['a', 'd']
    

    要在Keras模型中使用它,请确保将其包装在像 Lambda 这样的层中 .

相关问题