我正在使用MxNet开发一个条件计算框架 . 假设我们的小批量有N个样本 . 我需要在我的计算图中使用伪代码执行这种操作:
x = graph.Variable("x")
y = graph.DoSomeTranformations(x)
# The following operation generates a Nxk sized matrix, k responses for each sample.
z = graph.DoDecision(y)
for i in range(k):
argmax_sample_indices_for_i = graph.ArgMaxIndices(z, i)
y_selected_samples = graph.TakeSelectedSample(y, argmax_sample_indices_for_i )
result = graph.DoSomeTransformations(y_selected_samples)
我想要实现的目标如下:获得y后,我应用决策函数(这可以是D到k完全连接层,其中D是数据维度)并获取我的N尺寸小批量中每个样本的k个激活 . 然后,我想基于每个样本的最大激活的列索引,动态地将我的小批量分成k个不同的部分(k可以是2,3,一个小数字) . 假设z为Nxk大小的矩阵,我的假设“graph.ArgMaxIndices”函数执行该操作,而i,该函数查找样本索引,这些索引沿列i给出最大激活并返回其索引 . (注意,我寻找任何系列或函数组合,它们给“graph.ArgMaxIndices”等效结果,而不是单个函数,具体而言) . 最后,对于每个i,我选择具有最大激活的样本并对它们应用特定的变换 . 目前,据我所知,MxNet在其符号网络中不支持这种条件计算 . 因此,我在每个决策之后构建单独的符号图,并且必须编写单独的簿记 - 每个小批量分割的条件图结构,这产生1)维护和开发的非常复杂和繁琐的代码2)在训练和评估期间降低的运行性能 .
我的问题是,我可以使用Tensorflow的符号运算符来完成上述操作吗?是否允许根据标准选择小批量的子集?是否有一个函数或一系列函数,它们相当于上面伪代码中的“graph.ArgMaxIndices”? (给定Nxk矩阵和列索引i,返回行的索引,其在列k处具有最大激活) .
1 回答
你可以在Tensorflow中做到这一点 .
我看到的最好的方法是使用一个掩码和tf.boolean_mask
k
次,i
-th掩码由tf.equal(i, tf.argmax(z, axis=-1))给出