首页 文章

Python scikit-learn:导出训练有素的分类器

提问于
浏览
40

我正在使用基于scikit-learn的nolearn的DBN(深度信念网络) .

我已经构建了一个可以很好地对我的数据进行分类的网络,现在我有兴趣导出模型进行部署,但我不知道(我每次想要预测某些东西时都在训练DBN) . 在 matlab 中,我只是导出权重矩阵并将其导入另一台机器 .

有人知道如何导出要导入的模型/权重矩阵而无需再次训练整个模型吗?

3 回答

  • 58

    您可以使用:

    >>> from sklearn.externals import joblib
    >>> joblib.dump(clf, 'my_model.pkl', compress=9)
    

    然后,在预测服务器上:

    >>> from sklearn.externals import joblib
    >>> model_clone = joblib.load('my_model.pkl')
    

    这基本上是一个Python pickle,具有针对大型numpy数组的优化处理 . 它与常规泡菜w.r.t具有相同的局限性 . 代码更改:如果pickle对象的类结构发生更改,则可能无法再使用nolearn或scikit-learn的新版本对该对象进行unpickle .

    如果您想要长期稳健的存储模型参数的方法,您可能需要编写自己的IO层(例如使用二进制格式序列化工具,如协议缓冲区或avro或低效但可移植的text / json / xml表示,如PMML) .

  • 8

    pickling / unpickling的缺点是它只适用于匹配的python版本(主要版本,也可能是次要版本)和sklearn,joblib库版本 .

    机器学习模型还有其他描述性输出格式,例如由Data Mining Group开发的,例如预测模型标记语言(PMML)和可移植分析格式(PFA) . 在这两者中,PMML是much better supported .

    因此,您可以选择将模型从scikit-learn保存到PMML(例如使用sklearn2pmml),然后使用jpmml在java,spark或hive中部署和运行它(当然您有更多选择) .

  • 3

    scikit-learn文档中的3.4. Model persistence部分几乎涵盖了所有内容 .

    除了 sklearn.externals.joblib ogrisel指出,它还显示了如何使用常规的泡菜包:

    >>> from sklearn import svm
    >>> from sklearn import datasets
    >>> clf = svm.SVC()
    >>> iris = datasets.load_iris()
    >>> X, y = iris.data, iris.target
    >>> clf.fit(X, y)  
    SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
      kernel='rbf', max_iter=-1, probability=False, random_state=None,
      shrinking=True, tol=0.001, verbose=False)
    
    >>> import pickle
    >>> s = pickle.dumps(clf)
    >>> clf2 = pickle.loads(s)
    >>> clf2.predict(X[0])
    array([0])
    >>> y[0]
    0
    

    并提供一些警告,例如在一个版本的scikit-learn中保存的模型可能无法在另一个版本中加载 .

相关问题