首页 文章

TensorFlow REST前端但不是TensorFlow服务[关闭]

提问于
浏览
22

我想部署一个简单的TensorFlow模型,并在像Flask这样的REST服务中运行它 . 到目前为止还没有找到github或这里的好例子 .

我还没有准备好使用其他帖子中建议的TF服务,它是谷歌的完美解决方案,但它对我的任务有点过分用gRPC,bazel,C编码,protobuf ......

3 回答

  • 3

    有不同的方法来做到这一点 . 纯粹地说,使用张量流不是很灵活,但相对简单 . 此方法的缺点是您必须重建图形并初始化恢复模型的代码中的变量 . tensorflow skflow/contrib learn中显示的方式更优雅,但目前似乎没有功能,文档已过时 .

    我在github here上放了一个简短的例子,它显示了如何将GET或POST参数命名为一个REST部署的tensorflow模型 .

    然后主代码在一个函数中,该函数根据POST / GET数据获取字典:

    @app.route('/model', methods=['GET', 'POST'])
    @parse_postget
    def apply_model(d):
        tf.reset_default_graph()
        with tf.Session() as session:
            n = 1
            x = tf.placeholder(tf.float32, [n], name='x')
            y = tf.placeholder(tf.float32, [n], name='y')
            m = tf.Variable([1.0], name='m')
            b = tf.Variable([1.0], name='b')
            y = tf.add(tf.mul(m, x), b) # fit y_i = m * x_i + b
            y_act = tf.placeholder(tf.float32, [n], name='y_')
            error = tf.sqrt((y - y_act) * (y - y_act))
            train_step = tf.train.AdamOptimizer(0.05).minimize(error)
    
            feed_dict = {x: np.array([float(d['x_in'])]), y_act: np.array([float(d['y_star'])])}
            saver = tf.train.Saver()
            saver.restore(session, 'linear.chk')
            y_i, _, _ = session.run([y, m, b], feed_dict)
        return jsonify(output=float(y_i))
    
  • 7

    这个github project显示了恢复模型检查点和使用Flask的工作示例 .

    @app.route('/api/mnist', methods=['POST'])
    def mnist():
        input = ((255 - np.array(request.json, dtype=np.uint8)) / 255.0).reshape(1, 784)
        output1 = simple(input)
        output2 = convolutional(input)
        return jsonify(results=[output1, output2])
    

    在线demo似乎很快 .

  • 3

    我不喜欢在flask restful文件中添加大量数据/模型处理代码 . 我通常有tf模型类等等 . 即它可能是这样的:

    # model init, loading data
    cifar10_recognizer = Cifar10_Recognizer()
    cifar10_recognizer.load('data/c10_model.ckpt')
    
    @app.route('/tf/api/v1/SomePath', methods=['GET', 'POST'])
    def upload():
        X = []
        if request.method == 'POST':
            if 'photo' in request.files:
                # place for uploading process workaround, obtaining input for tf
                X = generate_X_c10(f)
    
            if len(X) != 0:
                # designing desired result here
                answer = np.squeeze(cifar10_recognizer.predict(X))
                top3 = (-answer).argsort()[:3]
                res = ([cifar10_labels[i] for i in top3], [answer[i] for i in top3])
    
                # you can simply print this to console
                # return 'Prediction answer: {}'.format(res)
    
                # or generate some html with result
                return fk.render_template('demos/c10_show_result.html',
                                          name=file,
                                          result=res)
    
        if request.method == 'GET':
            # in html I have simple form to upload img file
            return fk.render_template('demos/c10_classifier.html')
    

    cifar10_recognizer.predict(X)是简单的func,它在tf会话中运行预测操作:

    def predict(self, image):
            logits = self.sess.run(self.model, feed_dict={self.input: image})
            return logits
    

    附:从文件中保存/恢复模型是一个非常漫长的过程,在提供post / get请求时尽量避免这种情况

相关问题