我通过卷积神经网络进行文本分类 . 我为我的项目使用了 Health 文档(ICD-9-CM代码),并使用了与dennybritz相同的模型,但我的数据有36个标签 . 我使用one_hot编码来编码我的标签 .
这是我的问题,当我运行每个文档都有一个标签的数据时,我的代码精确度从0.8到1是完美的 . 如果我运行的数据有多个标签,则精度会大大降低 .
例如:文档的单个标签为 "782.0"
: [0 0 1 0 ... 0]
,
文档的多个标签为 "782.0 V13.09 593.5"
: [1 0 1 0 ... 1]
.
谁能提出为什么会发生这种情况以及如何改进呢?
1 回答
标签编码似乎是正确的 . 如果您有多个正确的标签,
[1 0 1 0 ... 1]
看起来完全正常 . Denny的post中使用的损失函数是tf.nn.softmax_cross_entropy_with_logits
,这是多类问题的损失函数 .在多标签问题中,您应该使用
tf.nn.sigmoid_cross_entropy_with_logits
:损失函数的输入将是logits(
WX
)和目标(标签) .修复精度测量
为了正确测量多标签问题的准确性,需要更改以下代码 .
如果您有多个正确的标签,则上述
correct_predictions
的逻辑不正确 . 例如,说num_classes=4
,标签0和2是正确的 . 因此,input_y=[1, 0, 1, 0].
correct_predictions
需要打破索引0和索引2之间的联系 . 我不确定tf.argmax
如何打破平局,但如果通过选择较小的索引打破平局,则标签2的预测总是被认为是错误的,这肯定会伤害你的准确度 .实际上,在多标签问题中,precision and recall是比准确性更好的指标 . 您还可以考虑使用precision @ k(
tf.nn.in_top_k
)来报告分类器性能 .