首页 文章

当在Keras中通过ResNet50进行转移学习时,损失总是变为纳米

提问于
浏览
0

我正在使用转移学习通过 Keras 中的 ResNet50 模型训练图像分类器并加载预训练的权重,但 loss 最初立即转到 nan 并且 acc 保持随机级别 .

实际上,我不知道出了什么问题,因为我已经使用这个模型成功地训练了一个分类器,虽然它并没有很高 acc 但它运作良好 . 这次失败了 .

我调整了 lr 但没有发生任何事情 . 有人说数据可能有问题,所以我改变了数据并且只发现不同的图像,相同的模型将显示不同的结果(也就是说,一些数据/图像效果很好,另一个数据/图像会立即产生结果) . 怎么会这样?我真的很困惑,无法弄清楚我的图像有什么问题 .

数据集:8个类,每个类包含大约300个图像 .

这是所有的代码:

import keras
import h5py
import numpy as np
import matplotlib.pyplot as plt

from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator


data_generator = ImageDataGenerator(preprocessing_function= preprocess_input, 
                        rescale = 1./255)

train_generator = data_generator.flow_from_directory("image/train", 
                        target_size = (100, 100), 
                        batch_size = 32, 
                        class_mode = "categorical")
dev_generator = data_generator.flow_from_directory("image/dev", 
                        target_size = (100, 100), 
                        batch_size = 32, 
                        class_mode = "categorical")

num_classes = 8
model = Sequential()
model.add(ResNet50(include_top = False, pooling = "avg", weights= "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"))
model.add(Dense(num_classes, activation = "softmax"))
model.layers[0].trainable = False

model.compile(optimizer= "adam", loss= "categorical_crossentropy", metrics=["accuracy"])

model.fit_generator(train_generator, steps_per_epoch= 1, epochs = 1)

并且运行输出是:

Epoch 1/1
1/1 [==============================] - 6s 6s/step - loss: nan - acc: 0.0938

1 回答

  • 0

    第一次纠正 “image/dev”"image/dev"

    我认为你的错误在于这一行:

    data_generator = ImageDataGenerator(preprocessing_function= preprocess_input, rescale = 1./255)
    

    当您同时使用 preprocess_input 函数和 rescale = 1./255 时,可以对数据进行双倍缩放 . 尝试删除重新缩放...

    data_generator = ImageDataGenerator(preprocessing_function= preprocess_input)
    

相关问题