使用Tensorflow Java API进行文本分类

我想用Java API来预测Python Tensorflow API中我训练过的模型,但我有问题要用Java来预测句子 .

我已经训练,保存,导出和加载我的模型在python adn一切正常 .

但是,当我尝试使用Java API将句子插入到我训练过的模型时,我收到了以下错误:

2018-09-26 16:13:27.094248:I tensorflow / cc / saved_model / loader.cc:161]恢复SavedModel包 . 2018-09-26 16:13:27.103857:I tensorflow / cc / saved_model / loader.cc:196]在SavedModel包上运行LegacyInitOp . 2018-09-26 16:13:27.107908:I tensorflow / cc / saved_model / loader.cc:291]标签的SavedModel加载;现状:成功 . 花了34162微秒 . java.lang.IllegalArgumentException:无法创建java.lang.String类型的Tensors

我在Java中的代码非常简单:

SavedModelBundle smb = SavedModelBundle.load("/textcnn/export/", "serve");
Session s = smb.session();
String text = "hello world";
Tensor input =  Tensor.create(text);
List<Tensor<?>> result = s.runner()
.feed("input_x", input)
.feed("keep_prob", prob)
.fetch("score/pred_scores")
.run();

我读到了它,但我不清楚如何在Java API上插入一个字符串 .

Hint(from my .pb file):

signature_def {
    key: "predict_images"
    value {
      inputs {
        key: "images"
        value {
          name: "input_x:0"
          dtype: DT_STRING
          tensor_shape {
            unknown_rank: true
          }
        }
      }
      inputs {
        key: "prob"
        value {
          name: "keep_prob:0"
          dtype: DT_FLOAT
          tensor_shape {
            unknown_rank: true
          }
        }
      }
      outputs {
        key: "scores"
        value {
          name: "score/pred_scores:0"
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: -1
            }
            dim {
              size: 10
            }
          }
        }
      }
      method_name: "tensorflow/serving/predict"
    }
  }

感谢您的任何帮助!