首页 文章

使用混淆矩阵理解多标签分类器

提问于
浏览
4

我有12个类的多标签分类问题 . 我正在使用 slimTensorflow 使用 ImageNet 上预训练的模型训练模型 . 以下是培训和验证中每个 class 存在的百分比

Training     Validation
  class0      44.4          25
  class1      55.6          50
  class2      50            25
  class3      55.6          50
  class4      44.4          50
  class5      50            75
  class6      50            75
  class7      55.6          50
  class8      88.9          50
  class9     88.9           50
  class10     50            25
  class11     72.2          25

问题是模型没有收敛,并且验证集上的曲线( Az )不足,例如:

Az 
  class0      0.99
  class1      0.44
  class2      0.96  
  class3      0.9
  class4      0.99
  class5      0.01
  class6      0.52
  class7      0.65
  class8      0.97
  class9     0.82
  class10     0.09
  class11     0.5
  Average     0.65

我不知道为什么它适用于某些类,而不适用于其他类 . 我决定深入研究细节,看看神经网络在学习什么 . 我知道混淆矩阵只适用于二元或多类分类 . 因此,为了能够绘制它,我不得不将问题转换为多类分类对 . 尽管模型是使用 sigmoid 训练来为每个类提供预测,但对于下面的混淆矩阵中的每个单个单元,我都显示了概率的平均值(通过对张量流的预测应用 sigmoid 函数得到) . 存在矩阵行中的类且不存在列中的类的图像 . 这适用于验证集图像 . 这样我认为我可以获得有关模型学习内容的更多细节 . 我只是圈出对角线元素用于显示目的 .

enter image description here

我的解释是:

  • 类别0和4在它们存在时被检测到而在它们不存在时不存在 . 这意味着可以很好地检测到这些类 .

  • 始终检测到类别2,6和7不存在 . 这不是我想要的 .

  • 始终检测到类别3,8和9存在 . 这不是我想要的 . 这可以应用于11级 .

  • 当存在不存在时检测到存在类别5并且当它存在时检测为不存在 . 它被反向检测到 .

  • 第3和第10类:我认为我们不能为这两个类提取太多信息 .

我的问题是解释 . 我不确定问题出在哪里,我不确定数据集中是否存在产生此类结果的偏差 . 我也想知道是否有一些指标可以帮助解决多标签分类问题?你可以和我分享你对这种混淆矩阵的解释吗?以及接下来要看什么/在哪里?对其他指标的一些建议会很棒 .

谢谢 .

EDIT:

我将问题转换为多类分类,因此对于每对类(例如0,1)来计算概率(类0,类1),表示为 p(0,1) :我采用工具0的图像工具1的预测存在并且工具1不存在并且我通过应用sigmoid函数将它们转换为概率,然后我显示那些概率的平均值 . 对于 p(1, 0) ,我也是这样做但是现在对于工具0使用工具1存在且工具0不存在的图像 . 对于 p(0, 0) ,我使用工具0所在的所有图像 . 考虑到上图中的 p(0,4) ,N / A表示没有工具0存在且工具4不存在的图像 .

以下是2个子集的图像数量:

  • 169320用于训练的图像

  • 37440图像用于验证

这是在训练集上计算的混淆矩阵(以与前面描述的验证集相同的方式计算),但这次颜色代码是用于计算每个概率的图像数量:
enter image description here

EDITED: 对于数据扩充,我对网络中的每个输入图像进行随机转换,旋转和缩放 . 此外,以下是有关这些工具的一些信息:

class 0 shape is completely different than the other objects.
class 1 resembles strongly to class 4.
class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
class 3 shape is completely different than the other objects.
class 4 resembles strongly to class 1
class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
class 6 resembles strongly to class 7
class 7 resembles strongly to class 6
class 8 shape is completely different than the other objects.
class 9 resembles strongly to class 10
class 10 resembles strongly to class 9
class 11 shape is completely different than the other objects.

EDITED: 以下是针对训练集提出的代码的输出:

Avg. num labels per image =  6.892700212615167
On average, images with label  0  also have  6.365296803652968  other labels.
On average, images with label  1  also have  6.601033718926901  other labels.
On average, images with label  2  also have  6.758548914659531  other labels.
On average, images with label  3  also have  6.131520940484937  other labels.
On average, images with label  4  also have  6.219187208527648  other labels.
On average, images with label  5  also have  6.536933407946279  other labels.
On average, images with label  6  also have  6.533908387864367  other labels.
On average, images with label  7  also have  6.485973817793214  other labels.
On average, images with label  8  also have  6.1241642788920725  other labels.
On average, images with label  9  also have  5.94092288040875  other labels.
On average, images with label  10  also have  6.983303518187239  other labels.
On average, images with label  11  also have  6.1974066621953945  other labels.

对于验证集:

Avg. num labels per image =  6.001282051282051
On average, images with label  0  also have  6.0  other labels.
On average, images with label  1  also have  3.987080103359173  other labels.
On average, images with label  2  also have  6.0  other labels.
On average, images with label  3  also have  5.507731958762887  other labels.
On average, images with label  4  also have  5.506459948320414  other labels.
On average, images with label  5  also have  5.00169779286927  other labels.
On average, images with label  6  also have  5.6729452054794525  other labels.
On average, images with label  7  also have  6.0  other labels.
On average, images with label  8  also have  6.0  other labels.
On average, images with label  9  also have  5.506459948320414  other labels.
On average, images with label  10  also have  3.0  other labels.
On average, images with label  11  also have  4.666095890410959  other labels.

Comments: 我认为它不仅与分布之间的差异有关,因为如果模型能够很好地概括第10类(意味着对象在训练过程中被正确识别,如0级),则验证集的准确性将是够好了 . 我的意思是问题在于训练集本身以及它的构建方式比两个分布之间的差异更大 . 它可以是:类或对象的存在频率强烈相似(如类10中非常类似于类9的情况)或数据集内的偏差或薄对象(表示输入中可能有1%或2%的像素)图像类2) . 我更多了.2604656比两个分布之间的差异 .

1 回答

  • 2

    输出校准

    我认为最重要的一件事是,神经网络的输出可能很难校准 . 我的意思是,它给不同实例的输出可能导致良好的排名(标签L的图像倾向于比没有标签L的图像具有更高的分数分数),但这些分数不能总是可靠地解释为概率(它可以给没有标签的实例提供非常高的分数,例如 0.9 ,并且只给出具有标签的实例更高的分数,如 0.99 ) . 我想这是否可能发生,取决于你选择的损失函数 .

    有关详细信息,请参阅示例:https://arxiv.org/abs/1706.04599


    逐一浏览所有课程

    Class 0: AUC(曲线下面积)= 0.99 . 这是一个非常好的分数 . 你的混淆矩阵中的第0列看起来也很好,所以这里没有错 .

    Class 1: AUC = 0.44 . 这是非常可怕的,低于0.5,如果我故意做出与你的网络为这个标签预测的相反的做法更好 .

    查看混淆矩阵中的第1列,它在各处都有相同的分数 . 对我来说,这表明网络没有设法了解这个类,并且几乎只是根据训练集中包含该标签的图像的百分比“猜测”(55.6%) . 由于这个百分比在验证集中下降到50%,这个策略确实意味着它会比随机稍差 . 尽管如此,第1行仍然具有此列中所有行的最大数量,因此它似乎至少学到了一点点,但并不多 .

    Class 2: AUC = 0.96 . 这是非常好的 .

    您对此课程的解释是,根据整个专栏的浅色阴影,它总是被预测为不存在 . 我不认为这种解释是正确的 . 看看它如何在对角线上得分> 0,在列中的其他地方只得0 . 该行中的分数可能相对较低,但可以轻松地与同一列中的其他行分开 . 您可能只需设置阈值来选择标签是否存在相对较低 . 我怀疑这是由于上面提到的校准事项 .

    这也是AUC实际上非常好的原因;可以选择一个阈值,使得分数高于阈值的大多数实例正确地具有标签,并且其下面的大多数实例都不正确 . 该阈值可能不是0.5,这是您假设良好校准时可能期望的阈值 . 绘制此特定标签的ROC曲线可帮助您确定阈值的确切位置 .

    Class 3: AUC = 0.9,相当不错 .

    您将其解释为始终被检测为存在,并且混淆矩阵确实在列中具有大量高数,但是AUC是好的并且对角线上的单元确实具有足够高的值以使其可以容易地与其他 . 我怀疑这是类似于第2类的情况(只是翻转,到处都是高预测,因此正确决策需要很高的阈值) .

    如果您希望能够确定一个精心选择的阈值是否确实可以从大多数“否定”(没有第3类的实例)中正确地分割大多数“肯定”(具有第3类的实例),那么您将需要对所有实例进行排序根据标签3的预测得分,然后浏览整个列表,并在每对连续条目之间计算如果您决定将阈值放在那里,那么您将获得的验证集合的准确度,并选择最佳阈值 .

    Class 4: 与0级相同 .

    Class 5: AUC = 0.01,显然很糟糕 . 同意你对混淆矩阵的解释 . 它在这里表现得如此糟糕 . 也许这是一个很难识别的对象?可能还有一些过度拟合(训练数据中的0个假阳性从第二个矩阵中的列判断,尽管还有其他类发生这种情况) .

    标签5图像的比例从训练到验证数据的增加可能也没有帮助 . 这意味着在培训期间,网络在此标签上的表现不如验证期间那么重要 .

    Class 6: AUC = 0.52,仅比随机略好 .

    从第一个矩阵中的第6列来看,这实际上可以类似于第2类的情况 . 如果我们也考虑AUC,它看起来也不会学习很好地排列实例 . 与5级相似,也不差 . 此外,再次,培训和验证分布完全不同 .

    Class 7: AUC = 0.65,相当平均 . 例如,显然不如第2类好,但也没有你从矩阵中解释的那么糟糕 .

    Class 8: AUC = 0.97,非常好,类似于3级 .

    Class 9: AUC = 0.82,不太好,但仍然很好 . 矩阵中的列有很多暗单元格,数字非常接近,我认为AUC非常好 . 它几乎存在于训练数据中的每一张图像中,因此预测它经常出现并不奇怪 . 也许这些非常暗的细胞中的一些只基于绝对数量较少的图像?弄清楚这很有趣 .

    Class 10: AUC = 0.09,太可怕了 . 对角线上的0很关注(你的数据标记是否正确?) . 根据第一个矩阵的第10行,它似乎经常对第3和第9类感到困惑(棉花和primary_incision_knives看起来很像secondary_incision_knives吗?) . 也许还有一些过度拟合训练数据 .

    Class 11: AUC = 0.5,不比随机好 . 性能差(并且矩阵中得分过高)很可能是因为这个标签出现在大多数训练图像中,但只有少数验证图像 .


    还有什么可以绘制/测量?

    为了更深入地了解您的数据,我首先绘制每个类共同出现频率的热图(一个用于培训,一个用于验证数据) . 单元格(i,j)将根据包含标签i和j的图像的比率着色 . 这将是一个对称图,在对角线单元格上根据您问题中的第一个数字列表着色 . 比较两个热图,看看它们的不同之处,看看它是否有助于解释模型的性能 .

    另外,知道(对于两个数据集)每个图像平均具有多少不同标签可能是有用的,并且对于每个单独的标签,平均共享图像的其他标签有多少 . 例如,我怀疑带有标签10的图像在训练数据中具有相对较少的其他标签 . 如果标签10识别出其他东西,则这可以阻止网络预测标签10,并且如果标签10突然在验证数据中更频繁地与其他对象共享图像,则可能导致性能不佳 . 由于伪代码可能比单词更容易得到点,因此打印如下内容可能会很有趣:

    # Do all of the following once for training data, AND once for validation data    
    tot_num_labels = 0
    for image in images:
        tot_num_labels += len(image.get_all_labels())
    avg_labels_per_image = tot_num_labels / float(num_images)
    print("Avg. num labels per image = ", avg_labels_per_image)
    
    for label in range(num_labels):
        tot_shared_labels = 0
        for image in images_with_label(label):
            tot_shared_labels += (len(image.get_all_labels()) - 1)
        avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
        print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")
    

    对于单个数据集,这不会提供太多有用的信息,但如果您为训练集和验证集执行此操作,则可以判断如果数字非常不同,则它们的分布完全不同

    最后,我有点担心你的第一个矩阵中的某些列如何在许多不同的行上出现完全相同的平均预测 . 我不太确定是什么原因引起的,但这可能对调查有用 .


    如何改善?

    如果你没有't already, I'建议研究你的训练数据的数据增加 . 由于您正在处理图像,因此可以尝试将现有图像的旋转版本添加到数据中 .

    对于具体的多标签案例,其目标是检测不同类型的对象,尝试简单地将一堆不同的图像(例如,两个或四个图像)连接在一起也可能是有趣的 . 然后,您可以将它们缩小到原始图像大小,并且标签指定原始标签集的并集 . 你会在合并图像的边缘出现有趣的不连续性,我不知道这是否有害 . 也许它不适合您的多物体检测,我认为值得一试 .

相关问题