首页 文章

如何在Keras损失函数中使用预先训练的TensorFlow网络

提问于
浏览
0

我有一个预先训练好的网,我想用它来评估我的Keras网中的损失 . 使用TensorFlow训练预训练的网络,我只想将其用作损失计算的一部分 .

我的自定义丢失功能的代码目前是:

def custom_loss_func(y_true, y_pred):
   # Get saliency of both true and pred
   sal_true = deep_gaze.get_saliency_map(y_true)
   sal_pred = deep_gaze.get_saliency_map(y_pred)

   return K.mean(K.square(sal_true-sal_pred))

deep_gaze是一个对象,用于管理我正在使用的外部预训练网络的访问权限 .

它是这样定义的:

class DeepGaze(object):
  CHECK_POINT = os.path.join(os.path.dirname(__file__), 'DeepGazeII.ckpt')  # DeepGaze II

def __init__(self):
    print('Loading Deep Gaze II...')

    with tf.Graph().as_default() as deep_gaze_graph:
        saver = tf.train.import_meta_graph('{}.meta'.format(self.CHECK_POINT))

        self.input_tensor = tf.get_collection('input_tensor')[0]
        self.log_density_wo_centerbias = tf.get_collection('log_density_wo_centerbias')[0]

    self.tf_session = tf.Session(graph=deep_gaze_graph)
    saver.restore(self.tf_session, self.CHECK_POINT)

    print('Deep Gaze II Loaded')

'''
Returns the saliency map of the input data. 
input format is a 4d array [batch_num, height, width, channel]
'''
def get_saliency_map(self, input_data):
    log_density_prediction = self.tf_session.run(self.log_density_wo_centerbias,
                                                 {self.input_tensor: input_data})

    return log_density_prediction

当我运行这个时,我得到错误:

TypeError:Feed的值不能是tf.Tensor对象 . 可接受的Feed值包括Python标量,字符串,列表,numpy ndarrays或TensorHandles .

我究竟做错了什么?有没有办法评估TensorFlow对象的网络来自不同的网络(由Keras使用TensorFlow后端制作) .

提前致谢 .

1 回答

  • 0

    有两个主要问题:

    • 当您使用 input_data=y_true 调用 get_saliency_map 时,您正在向另一个张量 self.input_tensor 提供张量 input_data ,这是无效的 . 而且,这些张量在图形创建时不具有值,而是定义最终将产生值的计算 .

    • 即使您可以从 get_saliency_map 获得输出,您的代码仍然无法正常工作,因为此函数会断开您的TensorFlow图形(它不会返回张量),并且所有逻辑都必须位于图形中 . 每个张量必须基于图中的其他可用张量来计算 .

    此问题的解决方案是在定义损失函数的图形中定义生成 self.log_density_wo_centerbias 的模型,直接使用张量 y_truey_pred 作为输入而不断开图形 .

相关问题