假设有一个张量 state 包含一个rnn状态列表,另一个张量 prob 包含每个状态的概率 .

state = tf.placeholder(tf.float32, [None, 49, 32])
print state.get_shape()  # (?, 49, 32) (batch_size, candidate_size, state_size)

prob = tf.placeholder(tf.float32, [None, 49])
print prob.get_shape() # (?, 49) (batch_size, candidate_size)

# Now I want to fetch 7 states with top probabilities
_, indices = tf.nn.top_k(prob, 7)
print indices.get_shape() # (?, 7)

如何用 indices 切片 state


编辑:

使用 tf.gather(state, indices) 的问题是它只会沿 first dimension 切片 state ,这是批量维度 . 在这里,我们希望沿着第二维(长度为49)对其进行切片 .