首页 文章

Tensorflow-Lite预训练模型在Android演示中不起作用

提问于
浏览
8

Tensorflow-Lite Android演示版可与其提供的原始模型配合使用:mobilenet_quant_v1_224.tflite . 见:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite

他们还在这里提供其他预训练的精简模型:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md

但是,我从上面的链接中下载了一些较小的模型,例如mobilenet_v1_0.25_224.tflite,只需更改 ImageClassifier.java 中的 MODEL_PATH = "mobilenet_v1_0.25_224.tflite"; ,就可以在演示应用中用此模型替换原始模型 . 该应用程序崩溃:

12-11 12:52:34.222 17713-17729 /? E / AndroidRuntime:FATAL EXCEPTION:CameraBackground进程:android.example.com.tflitecamerademo,PID:17713 java.lang.IllegalArgumentException:无法获取输入维度 . 第0个输入应该有602112个字节,但是找到150528个字节 . 在org.tensorflow.lite.NativeInterpreterWrapper.getInputDims(本机方法)在org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:82)在org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:112)在组织.tensorflow.lite.Interpreter.run(Interpreter.java:93)在com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108)在com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java :663)在com.example.android.tflitecamerademo.Camera2BasicFragment.access $ 900(Camera2BasicFragment.java:69)在com.example.android.tflitecamerademo.Camera2BasicFragment $ 5.run(Camera2BasicFragment.java:558)在android.os.Handler . handleCallback(Handler.java:751)位于android.os.HandlerThread.run(HandlerThread.java)的android.os.Handler.dispatchMessage(Handler.java:95)android.os.Looper.loop(Looper.java:154) :61)

原因似乎是模型所需的输入尺寸是图像尺寸的四倍 . 所以我将 DIM_BATCH_SIZE = 1 修改为 DIM_BATCH_SIZE = 4 . 现在的错误是:

致命异常:CameraBackground工艺:android.example.com.tflitecamerademo,PID:18241 java.lang.IllegalArgumentException异常:不能与类型FLOAT32一个TensorFlowLite张量转换成类型的Java对象[[B(其是与TensorFlowLite类型UINT8兼容)org.tensorflow.lite.Tensor.copyTo(Tensor.java:36)org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:122)org.tensorflow.lite.Interpreter.run(Interpreter.java: 93)在com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108)在com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663)在com.example.android.tflitecamerademo.Camera2BasicFragment . 访问$ 900(Camera2BasicFragment.java:69)在com.example.android.tflitecamerademo.Camera2BasicFragment $ 5.run在android.os(Camera2BasicFragment.java:558)在android.os.Handler.handleCallback(Handler.java:751) . android.os.Lo中的Handler.dispatchMessage(Handler.java:95) oper.loop(Looper.java:154)在android.os.HandlerThread.run(HandlerThread.java:61)

My question is how to get a reduced-MobileNet tflite model to work with the TF-lite Android Demo.

(我实际上尝试了其他的东西,比如使用提供的工具将TF冻结图转换为TF-lite模型,甚至使用与https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md完全相同的示例代码,但转换后的tflite模型仍无法在Android Demo中使用 . )

2 回答

  • 1

    Tensorflow-Lite Android演示程序中包含的ImageClassifier.java需要一个 quantized 模型 . 截至目前,只有一种Mobilenets模型以量化形式提供:Mobilenet 1.0 224 Quant .

    要使用其他浮点模型,请从Tensorflow for Poets TF-Lite演示源交换ImageClassifier.java . 这是为 float 型号编写的 . https://github.com/googlecodelabs/tensorflow-for-poets-2/blob/master/android/tflite/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java

    做一个差异,你会发现在实现中有几个重要的区别 .

    另一个需要考虑的选择是使用TOCO将浮点模型转换为量化:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md

  • 4

    我也得到了与幼苗相同的错误 . 我为Mobilenet Float模型创建了一个新的Image分类器包装器 . 现在工作正常 . 您可以直接在图像分类器演示中添加此类,并使用它在Camera2BasicFragment中创建分类器

    classifier = new ImageClassifierFloatMobileNet(getActivity());
    

    下面是Mobilenet Float模型的Image分类器类包装器

    /**
     * This classifier works with the Float MobileNet model.
     */
    public class ImageClassifierFloatMobileNet extends ImageClassifier {
    
      /**
       * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
       * This isn't part of the super class, because we need a primitive array here.
       */
      private float[][] labelProbArray = null;
    
      private static final int IMAGE_MEAN = 128;
      private static final float IMAGE_STD = 128.0f;
    
      /**
       * Initializes an {@code ImageClassifier}.
       *
       * @param activity
       */
      public ImageClassifierFloatMobileNet(Activity activity) throws IOException {
        super(activity);
        labelProbArray = new float[1][getNumLabels()];
      }
    
      @Override
      protected String getModelPath() {
        // you can download this file from
        // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
    //    return "mobilenet_quant_v1_224.tflite";
        return "retrained.tflite";
      }
    
      @Override
      protected String getLabelPath() {
    //    return "labels_mobilenet_quant_v1_224.txt";
        return "retrained_labels.txt";
      }
    
      @Override
      public int getImageSizeX() {
        return 224;
      }
    
      @Override
      public int getImageSizeY() {
        return 224;
      }
    
      @Override
      protected int getNumBytesPerChannel() {
        // the Float model uses a 4 bytes
        return 4;
      }
    
      @Override
      protected void addPixelValue(int val) {
        imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
        imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
        imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
      }
    
      @Override
      protected float getProbability(int labelIndex) {
        return labelProbArray[0][labelIndex];
      }
    
      @Override
      protected void setProbability(int labelIndex, Number value) {
        labelProbArray[0][labelIndex] = value.byteValue();
      }
    
      @Override
      protected float getNormalizedProbability(int labelIndex) {
        return labelProbArray[0][labelIndex];
      }
    
      @Override
      protected void runInference() {
        tflite.run(imgData, labelProbArray);
      }
    }
    

相关问题