我想实现一个图像检索模型 . 该模型将使用三重损失函数(与facenet或类似架构相同)进行训练 . 我的想法是使用Keras的预训练分类模型(例如resnet50),并使其成为三重架构 . 这是我在Keras的模特:
resnet_input = Input(shape=(224,224,3))
resnet_model = ResNet50(weights='imagenet', include_top = False, input_tensor=resnet_input)
net = resnet_model.output
net = Flatten(name='flatten')(net)
net = Dense(512, activation='relu', name='embded')(net)
net = Lambda(l2Norm, output_shape=[512])(net)
base_model = Model(resnet_model.input, net, name='resnet_model')
input_shape=(224,224,3)
input_anchor = Input(shape=input_shape, name='input_anchor')
input_positive = Input(shape=input_shape, name='input_pos')
input_negative = Input(shape=input_shape, name='input_neg')
net_anchor = base_model(input_anchor)
net_positive = base_model(input_positive)
net_negative = base_model(input_negative)
positive_dist = Lambda(euclidean_distance, name='pos_dist')([net_anchor, net_positive])
negative_dist = Lambda(euclidean_distance, name='neg_dist')([net_anchor, net_negative])
stacked_dists = Lambda(
lambda vects: K.stack(vects, axis=1),
name='stacked_dists'
)([positive_dist, negative_dist])
model = Model([input_anchor, input_positive, input_negative], stacked_dists, name='triple_siamese')
def triplet_loss(_, y_pred):
margin = K.constant(1)
return K.mean(K.maximum(K.constant(0), K.square(y_pred[0]) - K.square(y_pred[1]) + margin))
def accuracy(_, y_pred):
return K.mean(y_pred[0] < y_pred[1])
def l2Norm(x):
return K.l2_normalize(x, axis=-1)
def euclidean_distance(vects):
x, y = vects
return K.sqrt(K.maximum(K.sum(K.square(x - y), axis=1, keepdims=True), K.epsilon()))
模型应预测每个图像的特征向量 . 如果图像来自同一类,则这些向量之间的距离(在这种情况下为欧几里得)应接近于零,如果不是,则接近于1 .
我已经尝试了不同的学习步骤,批量大小,损失函数中的不同边距,从原始resnet模型中选择不同的输出层,将不同的层添加到resnet的末尾,仅训练新添加的层与训练整个模型 . 我也尝试使用这个没有预训练重量的resnet模型,无论我做了什么 . (输入图像的预处理方式与此模型的预期方式相同 keras.applications.resnet50.preprocess_input
)
我没有进行任何可能导致收敛缓慢的负面挖掘,但在这种情况下(检查函数)的0.5准确度仍然是随机预测 .
所以我开始思考也许我想念一些非常重要的东西(这是一个相当困难的架构) . 所以,如果您在我的实施中发现错误或可疑,我会非常高兴 .
2 回答
如果有人有兴趣,重写
y_pred[0]
和y_pred[1]
至
y_pred[:,0,0]
和y_pred[:,1,0]
固定它 .
现在模型似乎正在训练(损失正在减少,准确性正在增加) .
我没有足够的声誉点来评论所以我这样写 .
我想通过使用您的代码来做类似于您正在做的事情 .
我是CNN的新手,我不确定我的训练数据应该是什么样子 . 你愿意分享剩下的代码吗?我会很感激!
编辑:
要回答我自己的问题,这可能对某人有用,这就是我在假期照片(http://lear.inrialpes.fr/%7Ejegou/data.php)上的表现,它有效: