我现在正在使用tensorflow(python)来训练我的模型,并希望在线使用tensorflow(java)来推断结果 .
计算图有一个返回 shape[1,16] 结果的操作,张量中的每个元素都是一个字符串 . 现在我想将结果转换为整个字符串 .
我创建一个ByteBuffer,并调用Tensor.writeTo在缓冲区中写入数据 . 但是当我解码最后的缓冲区时,它在 Headers 中有一些意想不到的字符,我猜最后的字节可能包含一些张量元信息 .
Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");
结果是这样的:第一部分是一些不正确的代码,最后一部分是正确的 .
& * ? * C J M X & * ? * C J M X hello,world!
我想也许带有形状的Tensor(1,16)在字节中有元信息,但我不知道如何获取我需要的部分 .
任何人都可以知道如何在java tensorflow接口中将多维张量转换为java字符串?
2 回答
我为此找到了一个workaroud!训练模型时,我在张量上调用 tf.reduce_join 形状(1,16)得到一个标量 . 当在java语言中进行推理时,我只需获取该标量节点,并调用 tensor.byteValue() 来获取张量字节 . 它将返回没有 Headers 代码的纯结果 .
如果操作的结果具有形状
[1, 16]
,则表示它生成16个不同的字符串,而不是一个字符串 .最近才添加了对Java中多维字符串张量的支持(github commit),并且不包含在TensorFlow 1.3及更早版本的预构建二进制文件中 . 您将不得不从源代码构建或等待TensorFlow 1.4版本 .
有了这个功能,你应该可以用这样的东西解码
(1, 16)
形状的张量:如果你真的需要一个字符串,那么你可以使用
tf.reduce_join
让模型将16个字符串合并为一个,然后提取标量 .