首页 文章

TensorFlow,批量索引(第一维)和排序

提问于
浏览
4

我有一个形状为 (?,368,5) 的参数张量,以及一个形状为 (?,368) 的查询张量 . 查询张量存储用于对第一张量进行排序的索引 .

所需的输出具有形状: (?,368,5) . 由于我需要它用于神经网络中的损失函数,因此使用的操作应该保持可微 . 此外,在运行时,第一个轴 ? 的大小对应于batchsize .

到目前为止,我尝试了 tf.gathertf.gather_nd ,但 tf.gather(params,query) 导致形状为 (?,368,368,5) 的张量 .

通过执行以下操作来实现查询张量:

query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

总的来说,我试图通过第三轴上的第一个元素对params张量进行排序(对于倒角距离的种类) . 最后要提到的是,我使用 Keras 框架 .

1 回答

  • 2

    您需要将第一个维度的索引添加到 query ,以便将其与 tf.gather_nd 一起使用 . 这是一种方法:

    import tensorflow as tf
    import numpy as np
    
    np.random.seed(100)
    
    with tf.Graph().as_default(), tf.Session() as sess:
        params = tf.placeholder(tf.float32, [None, 368, 5])
        query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
        n = tf.shape(params)[0]
        # Make tensor of indices for the first dimension
        ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
        # Stack indices
        idx = tf.stack([ii, query], axis=-1)
        # Gather reordered tensor
        result = tf.gather_nd(params, idx)
        # Test
        out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
        # Check the order is correct
        print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
        # True
    

相关问题