首页 文章

多标签分类:如何学习阈值?

提问于
浏览
0

我有一个很深的CNN,适用于多类分类 . 我想“升级”挑战,并根据多标签分类问题进行培训 .

为此,我用sigmoid替换了我的softmax并尝试训练我的网络以最小化:

tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=y_pred)

但我最终得到了奇怪的预测:

Prediction for Im1 : [ 0.59275776  0.08751075  0.37567005  0.1636796   0.42361438  0.08701646

0.38991812 0.54468459 0.34593087 0.82790571]

Prediction for Im1 : [ 0.52609032  0.07885984  0.45780018  0.04995904  0.32828355  0.07349177

0.35400775 0.36479294 0.30002621 0.84438241]

Prediction for Im1 : [ 0.58714485  0.03258472  0.3349618   0.03199361  0.54665488  0.02271551

0.43719986 0.54638696 0.20344526 0.88144571]

所以我想我尝试让我的网络学习每个类的阈值,以确定样本是否属于te类 .

所以我把它添加到我的代码中:

initial = tf.truncated_normal([numberOfClasses], stddev=0.1)
W_thresh = tf.Variable(initial)

y_predict_thresh = int(y_predict > W_thresh)

但是我有一个错误:

TypeError: int() argument must be a string or a number, not 'Tensor'.

任何人都有任何想法帮助我前进(如何避免这个错误?,我的数据集是否真的不 balancer 的事实导致这种“常量”预测?对多标签分类的其他建议?,...)?

谢谢

编辑:我刚刚意识到做阈值处理对于反向传播可能不是很酷:/

1 回答

  • 2

    不知道你是否仍然需要它,但你可以使用tensorflow转换函数 tf.to_int32, tf.to_int64 . 在评估之前,表达式是python的一个对象,因此它不能简单地将它转换为 int() .

    这可以满足您的需求:

    with tf.Session() as sess:
        check = sess.run([tf.to_int64(W1 > W2)])
    

相关问题