首页 文章

使用预先训练的inception_resnet_v2和Tensorflow

提问于
浏览
14

我一直在尝试使用谷歌发布的预训练的inception_resnet_v2模型 . 我正在使用他们的模型定义(https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py)并给出检查点(http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz)来加载tensorflow中的模型,如下所示[下载提取检查点文件并下载示例图像dog.jpg和panda.jpg来测试此代码] -

import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
from inception_resnet_v2 import *
import numpy as np

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['dog.jpg', 'panda.jpg']
#Load the model
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
  logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
for image in sample_images:
  im = Image.open(image).resize((299,299))
  im = np.array(im)
  im = im.reshape(-1,299,299,3)
  predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
  print (np.max(predict_values), np.max(logit_values))
  print (np.argmax(predict_values), np.argmax(logit_values))

但是,此模型代码的结果未给出预期结果(无论输入图像如何,都会预测918类) . 有人能帮我理解我哪里错了吗?

1 回答

  • 14

    Inception网络期望输入图像具有从[-1,1]缩放的颜色通道 . 如图所示here .

    您可以使用现有的预处理,或者在您的示例中自己缩放图像: im = 2*(im/255.0)-1.0 ,然后将它们提供给网络 .

    如果不进行缩放,则输入[0-255]远大于网络预期,并且所有偏差都可以非常强烈地预测类别918(漫画书) .

相关问题