首页 文章

如何从Java中的Example对象创建Tensor?

提问于
浏览
2

我的用例:我正在尝试使用libtensorflow_jni在我们现有的JVM服务中为python训练的模型提供服务 .

现在我可以使用 SavedModelBundle.load() 加载模型 . 但我发现很难将请求提供给模型 . 由于我的用户请求不仅仅是标量矩阵,而是功能映射,例如:

{'gender':1, 'age': 20, 'country': 100, other features ...}

通过搜索张量流教程,我看到Example协议缓冲区可能适合这里,因为它基本上包含一系列功能 . 但我不知道如何将其转换为Java Tensor对象 .

如果我直接使用序列化的Example对象创建Tensor,TensorFlow运行时似乎对数据类型不满意 . 例如,我做了以下,

Tensor inputTensor = Tensor.create(example.toByteArray());
s.runner().feed(inputTensorName, inputTensor).fetch(outputTensorName).run().get(0);

我将得到一个IllegalArgumentException:

java.lang.IllegalArgumentException: Expected serialized to be a vector, got shape: []

如果你碰巧知道或有相同的使用案例,你们能否解释一下如何从这里前进?

谢谢!

1 回答

  • 1

    看看你的错误信息,看来问题是你的模型期望一个字符串张量向量(很可能对应于一批序列化的 Example 协议缓冲区消息,可能来自tf.parse_example),但是你正在给它一个标量字符串张量 .

    不幸的是,直到issue #8531被解决,Java API没有办法创建除标量之外的字符串 Tensor . 一旦问题得到解决,事情会变得更容易 .

    同时,你可以通过构造一个TensorFlow“模型”来将你的标量字符串转换为大小为1的向量来解决这个问题 . 这可以通过以下方式完成:

    // A TensorFlow "model" that reshapes a string scalar into a vector.
    // Should be much prettier once https://github.com/tensorflow/tensorflow/issues/7149
    // is resolved.
    private static class Reshaper implements AutoCloseable {
      Reshaper() {
        this.graph = new Graph();
        this.session = new Session(graph);
        this.in =
            this.graph.opBuilder("Placeholder", "in")
                .setAttr("dtype", DataType.STRING)
                .build()
                .output(0);
        try (Tensor shape = Tensor.create(new int[] {1})) {
          Output vectorShape =
              this.graph.opBuilder("Const", "vector_shape")
                  .setAttr("dtype", shape.dataType())
                  .setAttr("value", shape)
                  .build()
                  .output(0);
          this.out =
              this.graph.opBuilder("Reshape", "out").addInput(in).addInput(vectorShape).build().output(0);
        }
      }
    
      @Override
      public void close() {
        this.session.close();
        this.graph.close();
      }
    
      public Tensor vector(Tensor input) {
        return this.session.runner().feed(this.in, input).fetch(this.out).run().get(0);
      }
    
      private final Graph graph;
      private final Session session;
      private final Output in;
      private final Output out;
    }
    

    通过上面的内容,您可以将示例原型张量转换为矢量并将其输入到您感兴趣的模型中,如下所示:

    Tensor inputTensor = null;
    try (Tensor scalar = Tensor.create(example.toByteArray())) {
      inputTensor = reshaper.vector(scalar);
    }
    s.runner().feed(inputTensorName, inputTensor).fetch(outputTensorName).run().get(0);
    

    有关详细信息,请see this example on github

    希望有所帮助!

相关问题