我正在使用基于scikit-learn的nolearn的DBN(深度信念网络) .
我已经构建了一个可以很好地对我的数据进行分类的网络,现在我有兴趣导出模型进行部署,但我不知道(我每次想要预测某些东西时都在训练DBN) . 在 matlab 中,我只是导出权重矩阵并将其导入另一台机器 .
matlab
有人知道如何导出要导入的模型/权重矩阵而无需再次训练整个模型吗?
您可以使用:
>>> from sklearn.externals import joblib >>> joblib.dump(clf, 'my_model.pkl', compress=9)
然后,在预测服务器上:
>>> from sklearn.externals import joblib >>> model_clone = joblib.load('my_model.pkl')
这基本上是一个Python pickle,具有针对大型numpy数组的优化处理 . 它与常规泡菜w.r.t具有相同的局限性 . 代码更改:如果pickle对象的类结构发生更改,则可能无法再使用nolearn或scikit-learn的新版本对该对象进行unpickle .
如果您想要长期稳健的存储模型参数的方法,您可能需要编写自己的IO层(例如使用二进制格式序列化工具,如协议缓冲区或avro或低效但可移植的text / json / xml表示,如PMML) .
pickling / unpickling的缺点是它只适用于匹配的python版本(主要版本,也可能是次要版本)和sklearn,joblib库版本 .
机器学习模型还有其他描述性输出格式,例如由Data Mining Group开发的,例如预测模型标记语言(PMML)和可移植分析格式(PFA) . 在这两者中,PMML是much better supported .
因此,您可以选择将模型从scikit-learn保存到PMML(例如使用sklearn2pmml),然后使用jpmml在java,spark或hive中部署和运行它(当然您有更多选择) .
scikit-learn文档中的3.4. Model persistence部分几乎涵盖了所有内容 .
除了 sklearn.externals.joblib ogrisel指出,它还显示了如何使用常规的泡菜包:
sklearn.externals.joblib
>>> from sklearn import svm >>> from sklearn import datasets >>> clf = svm.SVC() >>> iris = datasets.load_iris() >>> X, y = iris.data, iris.target >>> clf.fit(X, y) SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0, kernel='rbf', max_iter=-1, probability=False, random_state=None, shrinking=True, tol=0.001, verbose=False) >>> import pickle >>> s = pickle.dumps(clf) >>> clf2 = pickle.loads(s) >>> clf2.predict(X[0]) array([0]) >>> y[0] 0
并提供一些警告,例如在一个版本的scikit-learn中保存的模型可能无法在另一个版本中加载 .
3 回答
您可以使用:
然后,在预测服务器上:
这基本上是一个Python pickle,具有针对大型numpy数组的优化处理 . 它与常规泡菜w.r.t具有相同的局限性 . 代码更改:如果pickle对象的类结构发生更改,则可能无法再使用nolearn或scikit-learn的新版本对该对象进行unpickle .
如果您想要长期稳健的存储模型参数的方法,您可能需要编写自己的IO层(例如使用二进制格式序列化工具,如协议缓冲区或avro或低效但可移植的text / json / xml表示,如PMML) .
pickling / unpickling的缺点是它只适用于匹配的python版本(主要版本,也可能是次要版本)和sklearn,joblib库版本 .
机器学习模型还有其他描述性输出格式,例如由Data Mining Group开发的,例如预测模型标记语言(PMML)和可移植分析格式(PFA) . 在这两者中,PMML是much better supported .
因此,您可以选择将模型从scikit-learn保存到PMML(例如使用sklearn2pmml),然后使用jpmml在java,spark或hive中部署和运行它(当然您有更多选择) .
scikit-learn文档中的3.4. Model persistence部分几乎涵盖了所有内容 .
除了
sklearn.externals.joblib
ogrisel指出,它还显示了如何使用常规的泡菜包:并提供一些警告,例如在一个版本的scikit-learn中保存的模型可能无法在另一个版本中加载 .