对不起新手问题 . 我试过寻找例子,但大多数时候它是伪代码/片段我无法运行,很难说是什么,特别是对于所有不同的TF工作流程 .
我试图将预先训练的TensorFlow保存模型ResNet-50 v2(fp32)转换为量化的TensorFlow Lite文件,并有两个问题:
-
批量大小似乎固定为64.即输入张量为64x224x224x3 . 我希望它是1x224x224x3
-
加载/转换保存模型时出现错误消息,无论是使用tflite_convert / toco还是使用resnet_v2_50()
即使有错误消息,模型一目了然也是正确的,所以我最关心的是批量修改 .
保存的模型我尝试转换:
我尝试使用的Checkpoint数据和.py模型生成一个新的保存模型:
我如何在bash中将其转换为.tflite:
tflite_convert --output_file resnet_imagenet_v2_uint8_20181001.tflite --saved_model_dir . --post_training_quantize
这导致64x224x224x3的合理模型 . 尽管有错误,这可能适用于Android / iOS(尚未尝试过),但我正在尝试在自定义平台上使用它进行试验 .
我尝试使用目标输入形状生成保存模型的脚本:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.slim.nets import resnet_v2
def main():
# Directory containing resnet_v2_50.ckpt
ckpt_dir = "/media/resnet_v2_50_2017_04_14/"
with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope()):
input_tensor = tf.placeholder(tf.float32, shape=[1,224,224,3], name="input_tensor")
output_tensor = tf.placeholder(tf.float32, shape=[1,1000])
# Create model
# Generates errors for all Conv2D nodes like:
# 2018-10-19 16:41:41.393976: E tensorflow/core/framework/node_def_util.cc:110] Error in the node: {{node resnet_v2_50/conv1/Conv2D}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="VALID", strides=[1, 2, 2, 1], use_cudnn_on_gpu=true](resnet_v2_50/Pad, resnet_v2_50/conv1/weights/read)
net, end_points = resnet_v2.resnet_v2_50(input_tensor, 1000)
# Load checkpoint data
sv = tf.train.Supervisor(logdir=ckpt_dir)
with sv.managed_session() as sess:
# Allows saving, but unexpected results:
# sess.graph._unsafe_unfinalize()
# Below call fails with:
# RuntimeError: Graph is finalized and cannot be modified.
tf.saved_model.simple_save(
sess,
"./export",
inputs={"input_tensor": input_tensor},
outputs={"resnet_v2_50/predictions/Softmax": output_tensor}
)
main()
我想我已经在网上找到了所以我有点迷失了 . 使用 sess.graph._unsafe_unfinalize()
允许simple_save()运行,在export /下创建一个.pb和variables /目录,但是当我在Netron中查看它时,我发现一个模型查看器,节点比提供的保存模型(批量大小为64)多得多我从tensorflow模型repo下载 . 3395 vs 1930 .
尝试转换此模型无论如何都会导致此错误:
$ tflite_convert --output_file export.tflite --saved_model_dir . --post_training_quantize
.
.
.
Traceback (most recent call last):
File "/home/tfuser/venv/bin/tflite_convert", line 11, in <module>
sys.exit(main())
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 412, in main
app.run(main=run_main, argv=sys.argv[:1])
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 125, in run
_sys.exit(main(argv))
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 408, in run_main
_convert_model(tflite_flags)
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/tflite_convert.py", line 162, in _convert_model
output_data = converter.convert()
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/lite.py", line 453, in convert
**converter_kwargs)
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py", line 370, in toco_convert_impl
input_data.SerializeToString())
File "/home/tfuser/venv/lib/python3.6/site-packages/tensorflow/contrib/lite/python/convert.py", line 149, in toco_convert_protos
"TOCO failed see console for info.\n%s\n%s\n" % (stdout, stderr))
RuntimeError: TOCO failed see console for info.
b'2018-10-19 17:00:58.673690: F tensorflow/contrib/lite/toco/tooling_util.cc:886] Check failed: GetOpWithInput(model, input_array.name()) Specified input array "input_tensor" is not consumed by any op in this graph. Is it a typo? To silence this message, pass this flag: allow_nonexistent_arrays\n'
None
所以看起来我对input_tensor的使用是错误的并且正在创建许多额外的节点?
我的设置:
-
CentOS 7
-
Python 3.6.6
-
TensorFlow Nightly 1.13.0.dev20181018(CPU)
我寻求帮助的一些页面:
Any ideas on how to achieve this? A better strategy or a fix to my script would be very appreciated and welcome.
次要问题:
-
simple_save()的文档说输入和输出应该是张量到张量的映射 .
-
张量是否意味着tf.placeholder?还是一个numpy阵列?还是tf.Variable?
-
对于字符串,它是否期望输入和输出张量的名称?所以"input_tensor"和"resnet_v2_50/predictions/Softmax"是正确的吗?
-
我需要在某个时候拨打
sess.run(tf.global_variables_initializer())
吗?如果我尝试在上面的脚本中调用它,它表示图形已完成 .