首页 文章

在PyTorch的最后一层中屏蔽特定元素

提问于
浏览
0

我现在正在复制以下模型,该模型输出 action 并使用 filter 来过滤不合适的候选者 . https://arxiv.org/abs/1702.03274

在此模型中,输出在最后一个softmax层之后被过滤 . 我们假设 action_size==3 . 所以密集和asoftmax层之后的输出如下所示 .

output: [0.1, 0.7, 0.2]
filter: [0, 1, 1]
output*filter: [0, 0.7, 0.2]

但在pytorch中, logsoftmax 首选 NLLLoss . 所以我的输出如下 . 这没有意义 .

output: [-5.4, -0.2, -4.9]
filter: [0, 1, 1]
output*filter: [0, -0.2, -4.9]

所以pytoroch不推荐 vanilla Softmax . 我应该如何应用蒙版以消除特定的动作?或者,与香草Softmax有任何分类交叉熵损失函数?

此模块不直接与NLLLoss一起工作,NLLLoss期望在Softmax和它自身之间计算Log . 请改用Logsoftmax(它更快,具有更好的数值属性) . http://pytorch.org/docs/master/nn.html#torch.nn.Softmax

1 回答

  • 0

    LogSoftmax的输出只是Softmax输出的对数 . 这意味着您只需调用 torch.exp(output_from_logsoftmax) 即可获得与Softmax相同的值 .

    因此,如果我正确地阅读您的问题,您将计算LogSoftmax,然后将其提供给NLLLoss,并指数用于过滤 .

相关问题