首页 文章

tensorflow java模型推理将fetched tensor转换为string?

提问于
浏览
1

我现在正在使用tensorflow(python)来训练我的模型,并希望在线使用tensorflow(java)来推断结果 .

计算图有一个返回 shape[1,16] 结果的操作,张量中的每个元素都是一个字符串 . 现在我想将结果转换为整个字符串 .

我创建一个ByteBuffer,并调用Tensor.writeTo在缓冲区中写入数据 . 但是当我解码最后的缓冲区时,它在 Headers 中有一些意想不到的字符,我猜最后的字节可能包含一些张量元信息 .

Tensor predictedTensor = result.get(0);
ByteBuffer bb = ByteBuffer.allocate(predictedTensor.numBytes());
predictedTensor.writeTo(bb);
String predictedTokens = null;
byte[] bbArray = bb.array();
predictedTokens = new String(bbArray, "UTF-8");

结果是这样的:第一部分是一些不正确的代码,最后一部分是正确的 .

& *  ? *  C J M X & *  ? *  C J M X hello,world!

我想也许带有形状的Tensor(1,16)在字节中有元信息,但我不知道如何获取我需要的部分 .

任何人都可以知道如何在java tensorflow接口中将多维张量转换为java字符串?

2 回答

  • 0

    我为此找到了一个workaroud!训练模型时,我在张量上调用 tf.reduce_join 形状(1,16)得到一个标量 . 当在java语言中进行推理时,我只需获取该标量节点,并调用 tensor.byteValue() 来获取张量字节 . 它将返回没有 Headers 代码的纯结果 .

  • 0

    如果操作的结果具有形状 [1, 16] ,则表示它生成16个不同的字符串,而不是一个字符串 .

    最近才添加了对Java中多维字符串张量的支持(github commit),并且不包含在TensorFlow 1.3及更早版本的预构建二进制文件中 . 您将不得不从源代码构建或等待TensorFlow 1.4版本 .

    有了这个功能,你应该可以用这样的东西解码 (1, 16) 形状的张量:

    Tensor predictedTensor = result.get(0);
    byte[][][] predictedTokenBytes = predictedTensor.copyTo(new byte[1][16][]);
    String[] predictedTokens = new String[16];
    for (int i = 0; i < 16; ++i) {
      // This works under the assumption that the model is actually
      // producing UTF-8 strings    
      predictedTokens[i] = new String(predictedTokenBytes[0][i], "UTF-8");
    }
    

    如果你真的需要一个字符串,那么你可以使用 tf.reduce_join 让模型将16个字符串合并为一个,然后提取标量 .

相关问题