首页 文章

在sklearn估算器上调用set_params()时,super()出错

提问于
浏览
0

我正在尝试根据配置文件加载和配置scikit-learn估算器 . 该文件具有估算器类路径和名称以及参数的字典 . 我的计划是使用pydoc.locate()加载带有默认参数的估计器,然后使用参数的dict在估算器上调用set_params() . 但是我收到以下错误:

import pydoc
sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')
print('{} {}'.format(type(sgd), sgd))
p_sgd = {'alpha':.1234}
sgd.set_params(p_sgd)
<class 'abc.ABCMeta'> <class 'sklearn.linear_model.stochastic_gradient.SGDClassifier'>
Traceback (most recent call last):
  File "<input>", line 5, in <module>
  File "/Users/doug/.pyenv/versions/learning-3.4.3/lib/python3.4/site-packages/sklearn/linear_model/stochastic_gradient.py", line 83, in set_params
    super(BaseSGD, self).set_params(*args, **kwargs)
TypeError: super(type, obj): obj must be an instance or subtype of type

我尝试两次使用相同的“加载和设置”方法 . 第一次,我按名称加载文本矢量化器并设置其参数 . 文本向量化程序是我基于HashingVectorizer创建的子类 . 它不会产生此错误,但似乎也没有通过调用set_params()来更改(即参数值保持默认值) . 第二次是具有我描述的行为的分类器 .

我之前在提供给GridSearchCV的Pipeline中运行它们时,我使用了pydoc.locate()来加载估算器 . 这工作得很好 . 在这种情况下,我使用默认的估计器构造函数构造管道,然后GridSearchCV让Pipeline在遍历参数网格时对每个估算器调用set_params() . 通过Pipeline和GridSearchCV源看,他们使用set_params()被称为set_params(** param_dict) . 如果我尝试,我会得到一个不同的错误 .

import pydoc
sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')
p_sgd = {'alpha':.1234}
sgd.set_params(**p_sgd)
Traceback (most recent call last):
  File "<input>", line 4, in <module>
TypeError: set_params() missing 1 required positional argument: 'self'

最后一点,我已经读过原始错误(TypeError:super(type,obj)...)已被追踪到多次加载模块的问题 . 事实上我在这些尝试调用之前使用了pydoc.locate()(为了追踪他们的父母,并找出谁是矢量化器与分类器) . 我可能能够解决这个问题,但是之前仍然会尝试加载这些模块,因为我在循环中运行以根据配置文件训练多个模型 .

我正在使用Python 3.4

1 回答

  • 0

    正如user2357112指出的那样,我错误地只加载了类,而不是构造它 . 我更改了代码以在没有参数的情况下调用返回类的构造函数,然后使用我期望的参数语法调用set_params( p_sgd) .

    import pydoc
    sgd = pydoc.locate('sklearn.linear_model.SGDClassifier')()
    p_sgd = {'alpha':.1234}
    sgd.set_params(**p_sgd)
    sgd
    SGDClassifier(alpha=0.1234, average=False, class_weight=None, epsilon=0.1, eta0=0.0, fit_intercept=True, l1_ratio=0.15, learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5, random_state=None, shuffle=True, verbose=0, warm_start=False)
    

相关问题