假设有一个张量 state
包含一个rnn状态列表,另一个张量 prob
包含每个状态的概率 .
state = tf.placeholder(tf.float32, [None, 49, 32])
print state.get_shape() # (?, 49, 32) (batch_size, candidate_size, state_size)
prob = tf.placeholder(tf.float32, [None, 49])
print prob.get_shape() # (?, 49) (batch_size, candidate_size)
# Now I want to fetch 7 states with top probabilities
_, indices = tf.nn.top_k(prob, 7)
print indices.get_shape() # (?, 7)
如何用 indices
切片 state
?
编辑:
使用 tf.gather(state, indices)
的问题是它只会沿 first dimension 切片 state
,这是批量维度 . 在这里,我们希望沿着第二维(长度为49)对其进行切片 .