首页 文章

如何在TensorFlow import_graph_def期间更改输入的维度

提问于
浏览
4

我的情景:

  • 定义RNN模型结构并使用具有固定批次大小和序列长度的输入对其进行训练 .

  • 冻结模型(即将所有可训练变量转换为常量),产生 GraphDef ,其中包含在测试时需要使用模型的所有内容(通过 tf.graph_util.convert_variables_to_constants ) .

  • 通过 tf.import_graph_def 导入 GraphDef 并使用 input_map 参数替换输入 . 新输入需要具有任意批量大小和序列长度 .

问题:上述所有工作都有效,直到我将输入传递给使用批量大小或序列长度不同于训练时使用的原始大小的测试时间图 . 那时我得到这样的错误:

InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,5] vs. shape[1] = [2,7]
     [[Node: import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](import/rnn/while/TensorArrayReadV3, import/rnn/while/Identity_2, import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat/axis)]]

为了说明和重现该问题,请考虑以下最小示例 .

  • v1 :创建具有任意批次大小和序列长度的图表 . 这很好但不幸的是我必须在训练时使用固定的批量大小和序列长度,并且必须在测试时使用任意批量大小和序列长度,所以我不能使用这种简单的方法 .

  • v2a :我们模拟创建具有固定批量大小(2)和序列长度(3)的训练时间图并冻结图形 .

  • v2ba :我们证明加载冻结模型不变仍然会产生相同的结果 .

  • v2bb :我们证明使用仍然使用固定批量大小和序列长度的替换输入加载冻结模型仍会产生相同的结果 .

  • v2bc :我们证明,使用任意批量大小和序列长度的替换输入加载冻结模型仍会产生相同的结果,只要输入根据原始批量大小和序列长度成形 . 它适用于 data ,但是失败了 data2 - 唯一的区别是前者的批量大小为2,后者的批量大小为1 .

Is it possible to change an RNN graph via the input_map argument to tf.import_graph_def such that the input no longer has a fixed batch size and sequence length?

以下代码适用于TensorFlow 1.1 RC2,可以与TensorFlow 1.0一起使用 .

import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn


def v1(data):
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            print session.run(s, feed_dict={x: data})


def v2a():
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=(2, 3, 5), name="x")
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            return tf_graph_util.convert_variables_to_constants(
                session, session.graph_def, [s.op.name]), s.name


def v2ba((graph_def, s_name), data):
    with tf.Graph().as_default():
        x, s = tf.import_graph_def(graph_def,
                                   return_elements=["x:0", s_name])

        with tf.Session() as session:
            print '2ba', session.run(s, feed_dict={x: data})


def v2bb((graph_def, s_name), data):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(2, 3, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])

        with tf.Session() as session:
            print '2bb', session.run(s, feed_dict={x: data})


def v2bc((graph_def, s_name), data):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])

        with tf.Session() as session:
            print '2bc', session.run(s, feed_dict={x: data})


def main():
    data1 = numpy.random.random_sample((2, 3, 5))
    data2 = numpy.random.random_sample((1, 3, 5))
    v1(data1)
    model = v2a()
    v2ba(model, data1)
    v2bb(model, data1)
    v2bc(model, data1)
    v2bc(model, data2)


if __name__ == "__main__":
    main()

1 回答

  • 0

    这是一个持续一段时间的张量流中的错误:您无法可靠地将具有已定义形状的占位符替换为具有(部分)未定义形状的另一个占位符 .

    你会发现一个相关的问题here,显然没有引起太多关注 .

相关问题