我试图在我的Android应用程序上运行Tensorflow模型,但是与在桌面上运行Python时相比,相同的训练模型会产生不同的结果(错误的推断) .
该模型是一个简单的连续CNN来识别字符,就像this number plate recognition network一样,减去窗口,因为我的模型已经将字符裁剪到位 .
我有:
-
模型保存在protobuf(.pb)文件中 - 在Keras上在Python / Linux GPU上建模和训练
-
推理在纯Tensorflow上的另一台计算机上进行了测试,以确保Keras不是罪魁祸首 . 在这里,结果如预期 .
-
Tensorflow 1.3.0正在Python和Android上使用 . 在PIP上安装Python和jcenter在Android上安装 .
-
Android上的结果与预期结果不符 .
-
输入为129 * 45 RGB图像,因此为129 * 45 * 3阵列,输出为4 * 36阵列(表示0-9和a-z中的4个字符) .
我使用this code将Keras模型保存为.pb文件 .
Python code ,这按预期工作:
test_image = [ndimage.imread("test_image.png", mode="RGB").astype(float)/255]
imTensor = np.asarray(test_image)
def load_graph(model_file):
graph = tf.Graph()
graph_def = tf.GraphDef()
with open(model_file, "rb") as f:
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def)
return graph
graph=load_graph("model.pb")
with tf.Session(graph=graph) as sess:
input_operation = graph.get_operation_by_name("import/conv2d_1_input")
output_operation = graph.get_operation_by_name("import/output_node0")
results = sess.run(output_operation.outputs[0],
{input_operation.outputs[0]: imTensor})
Android code ,基于this example;这给出了看似随机的结果:
Bitmap bitmap;
try {
InputStream stream = getAssets().open("test_image.png");
bitmap = BitmapFactory.decodeStream(stream);
} catch (IOException e) {
e.printStackTrace();
}
inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), "model.pb");
int[] intValues = new int[129*45];
float[] floatValues = new float[129*45*3];
String outputName = "output_node0";
String[] outputNodes = new String[]{outputName};
float[] outputs = new float[4*36];
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
for (int i = 0; i < intValues.length; ++i) {
final int val = intValues[i];
floatValues[i * 3 + 0] = ((val >> 16) & 0xFF) / 255;
floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255;
floatValues[i * 3 + 2] = (val & 0xFF) / 255;
}
inferenceInterface.feed("conv2d_1_input", floatValues, 1, 45, 129, 3);
inferenceInterface.run(outputNodes, false);
inferenceInterface.fetch(outputName, outputs);
任何帮助是极大的赞赏!
1 回答
一个问题在于:
其中RGB值除以整数,从而产生整数结果(即每次为0) .
此外,即使用
255.0
执行产生0到1.0之间的浮点也可能会产生一个问题,因为这些值不会像在Natura中一样分布在投影空间(0..1)中 . 为了解释这一点:传感器域中的值255(例如,R值)意味着测量信号的自然值落在"255"桶的某处,这是整个能量/强度范围/等 . 将此值映射到1.0很可能会减少其范围的一半,因为后续计算可能会在1.0的最大乘数处饱和,这实际上只是-1 / 256桶的中点 . 因此,转换可能更准确地映射到0..1范围的256桶分区的中点:但这只是我身边的猜测 .