首页 文章

修正张量流检验点中张量的形状

提问于
浏览
3

我有一个张量流检查点,我可以在使用常规例程 tf.train.Saver()saver.restore(session, 'my_checkpoint.ckpt') 重新定义与之对应的图形后加载 .

但是,现在,我想修改网络的第一层以接受形状输入 [200, 200, 1] 而不是 [200, 200, 10] .

为此,我想通过在第三维上求和来修改对应于第一层的张量的形状,从 [3, 3, 10, 32] (3x3内核,10个输入通道,32个输出通道)到 [3, 3, 1, 32] .

我怎么能这样做?

1 回答

  • 1

    我找到了一种方法,但不是那么简单 . 给定一个检查点,我们可以将它转换为序列化的numpy数组(或者我们可能认为适合保存numpy数组字典的任何其他格式),如下所示:

    checkpoint = {}
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, 'my_checkpoint.ckpt')
        for x in tf.global_variables():
            checkpoint[x.name] = x.eval()
        np.save('checkpoint.npy', checkpoint)
    

    可能有一些例外要处理,但让我们保持代码简单 .

    然后,我们可以在numpy数组上执行我们喜欢的任何操作:

    checkpoint = np.load('checkpoint.npy')
    checkpoint = ...
    np.save('checkpoint.npy', checkpoint)
    

    最后,我们可以在构建图形后手动加载权重:

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        checkpoint = np.load('checkpoint.npy').item()
        for key, data in checkpoint.iteritems():
            var_scope = ... # to be extracted from key
            var_name = ...  # 
            with tf.variable_scope(var_scope, reuse=True):
                var = tf.get_variable(var_name)
                sess.run(var.assign(data))
    

    如果有一个更简单的方法,我全都耳朵!

相关问题