使用Tensorflow Java API进行文本分类
我想用Java API来预测Python Tensorflow API中我训练过的模型,但我有问题要用Java来预测句子 .
我已经训练,保存,导出和加载我的模型在python adn一切正常 .
但是,当我尝试使用Java API将句子插入到我训练过的模型时,我收到了以下错误:
2018-09-26 16:13:27.094248:I tensorflow / cc / saved_model / loader.cc:161]恢复SavedModel包 . 2018-09-26 16:13:27.103857:I tensorflow / cc / saved_model / loader.cc:196]在SavedModel包上运行LegacyInitOp . 2018-09-26 16:13:27.107908:I tensorflow / cc / saved_model / loader.cc:291]标签的SavedModel加载;现状:成功 . 花了34162微秒 . java.lang.IllegalArgumentException:无法创建java.lang.String类型的Tensors
我在Java中的代码非常简单:
SavedModelBundle smb = SavedModelBundle.load("/textcnn/export/", "serve");
Session s = smb.session();
String text = "hello world";
Tensor input = Tensor.create(text);
List<Tensor<?>> result = s.runner()
.feed("input_x", input)
.feed("keep_prob", prob)
.fetch("score/pred_scores")
.run();
我读到了它,但我不清楚如何在Java API上插入一个字符串 .
Hint(from my .pb file):
signature_def {
key: "predict_images"
value {
inputs {
key: "images"
value {
name: "input_x:0"
dtype: DT_STRING
tensor_shape {
unknown_rank: true
}
}
}
inputs {
key: "prob"
value {
name: "keep_prob:0"
dtype: DT_FLOAT
tensor_shape {
unknown_rank: true
}
}
}
outputs {
key: "scores"
value {
name: "score/pred_scores:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 10
}
}
}
}
method_name: "tensorflow/serving/predict"
}
}
感谢您的任何帮助!