我有一个形状为 (?,368,5)
的参数张量,以及一个形状为 (?,368)
的查询张量 . 查询张量存储用于对第一张量进行排序的索引 .
所需的输出具有形状: (?,368,5)
. 由于我需要它用于神经网络中的损失函数,因此使用的操作应该保持可微 . 此外,在运行时,第一个轴 ?
的大小对应于batchsize .
到目前为止,我尝试了 tf.gather
和 tf.gather_nd
,但 tf.gather(params,query)
导致形状为 (?,368,368,5)
的张量 .
通过执行以下操作来实现查询张量:
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
总的来说,我试图通过第三轴上的第一个元素对params张量进行排序(对于倒角距离的种类) . 最后要提到的是,我使用 Keras
框架 .
1 回答
您需要将第一个维度的索引添加到
query
,以便将其与tf.gather_nd
一起使用 . 这是一种方法: