首页 文章

关于scikit-learn Class Instance的Monkey-Patching Magic方法

提问于
浏览
3

我正在尝试构建一个名为 SafeModel 的工厂类,其 generate 方法接受scikit-learn类的实例,更改其某些属性,并返回相同的实例 . 具体来说,对于此示例,我想访问返回模型的 coef_ 属性,在案例1中,如果基础scikit-learn类包含 coef_ ,则返回该类' coef_ ,并在案例2中)如果底层scikit -learn类包含 feature_importances_ ,返回该类' feature_importances .

我对Python类实例的猴子修补魔术方法的成功较少 . 我的案例的警告是:属性 coef_feature_importances 永远不会在scikit-learn类实例化时定义;相反,它们仅在对各自的类调用 fit 方法后定义 . 出于这个原因,我无法覆盖属性定义本身 .

from types import MethodType


class SafeModel:

    FALLBACK_ATTRIBUTES = {
        'coef_': ['feature_importances_'],
    }

    @classmethod
    def generate(cls, model):
        safe_model = cls._secure_attributes(model)
        return safe_model

    @classmethod
    def _secure_attributes(cls, model):
        def __secure_getattr__(self, name):
            for fallback_attribute in cls.FALLBACK_ATTRIBUTES[name]:
                try:
                    return getattr(self, fallback_attribute)
                except:
                    continue
        model.__getattr__ = MethodType(__secure_getattr__, model)
        return model


    from sklearn.ensemble import RandomForestClassifier

    model = SafeModel.generate(RandomForestClassifier())
    model.coef_  # AttributeError: 'RandomForestClassifier' object has no attribute 'coef_'

1 回答

  • 0

    我无法确定您的代码有什么问题 . 我找到了一个可能适用于您的用例的workaroung .
    我'm using a different strategy as I'只是使用 SafeModel.__getattr__ 作为模型的 getattr 方法的包装而不是猴子修补 .

    from sklearn.utils.validation import NotFittedError
    from sklearn.ensemble import RandomForestClassifier
    
    class SafeModel(object):
    
        def __init__(self, model):
            self.FALLBACK_ATTRIBUTES = {
            'coef_': ['feature_importances_'],
        }
            self.model = model
    
        def __getattr__(self, name):
            try:
                return getattr(self.model, name)
            except AttributeError:
                pass
            for fallback_attribute in self.FALLBACK_ATTRIBUTES[name]:
                try:
                    return getattr(self.model, fallback_attribute)
                except NotFittedError as e:
                    # NotFittedError inherits AttributeError.
                    raise e
                except AttributeError:
                    continue
            else:
                raise AttributeError(
                    "{} object has no attribute {}.".format(
                        self.__class__.__name__, name) + 
                    " Could not retrieve any fallback attribute.")                    
    
    
    model = SafeModel(RandomForestClassifier())
    model.coef_
    

    输出:

    NotFittedError: Estimator not fitted, call `fit` before `feature_importances_`.
    

    请注意,这是正常行为,正如您所提到的,在适合随机林之前,您无法访问 feature_importances_ .

    不可否认,异常捕获在这里相当脆弱(你需要添加一堆可能会被引发的异常),但是如果你不想在尝试访问应该没问题的属性时提出正确的异常 .

    如果这对您有用,请告诉我 . 如果你发现你发布的代码发生了什么,我也会对解释感兴趣!

相关问题