Joblib error: TypeError: can't pickle _thread.lock objects

I can't get joblib to run my function, which takes a numpy array, a list of trained Keras models, and a list of strings as arguments.

I tried packing the arguments into a namedtuple, and even into a class with immutable attributes. Any ideas?

Params = collections.namedtuple('Params',['inputs','y_list','trained_models'])
p = Params(inputs, y_list, trained_models)

or

class Params:
    def __init__(self, inputs, y_list, trained_models):
        super(Params , self).__setattr__("inputs", inputs)
        super(Params , self).__setattr__("y_list", y_list)
        super(Params , self).__setattr__("trained_models", trained_models)

The function I would like to run in parallel:

def predict(params):
    inputs = params.inputs
    y_list = params.y_list
    trained_models = params.trained_models

    # process and vectorize inputs
    X = new_X(inputs)
    X_vect = vect.transform(X)
    predictions = dict()

    for y in y_list:
        y_field = trained_models[y].predict(X_vect)
        # evaluate model
        if y_field[0] > 0.05:
            return None, None

        predictions[y] = y_field[0]

    return X, predictions

Calling the function in parallel:

r = Parallel(n_jobs=4, verbose=5)(
    delayed(predict)(p)
    for c in range(100))

The error:

TypeError                                 Traceback (most recent call last)
<timed exec> in <module>()

~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
    787                 # consumption.
    788                 self._iterating = False
--> 789             self.retrieve()
    790             # Make sure that we get a last message telling us we are done
    791             elapsed_time = time.time() - self._start_time

~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
    697             try:
    698                 if getattr(self._backend, 'supports_timeout', False):
--> 699                     self._output.extend(job.get(timeout=self.timeout))
    700                 else:
    701                     self._output.extend(job.get())

~/.conda/envs/mlgpu/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    642             return self._value
    643         else:
--> 644             raise self._value
    645 
    646     def _set(self, i, obj):

~/.conda/envs/mlgpu/lib/python3.6/multiprocessing/pool.py in _handle_tasks(taskqueue, put, outqueue, pool, cache)
    422                         break
    423                     try:
--> 424                         put(task)
    425                     except Exception as e:
    426                         job, idx = task[:2]

~/.conda/envs/mlgpu/lib/python3.6/site-packages/joblib/pool.py in send(obj)
    369             def send(obj):
    370                 buffer = BytesIO()
--> 371                 CustomizablePickler(buffer, self._reducers).dump(obj)
    372                 self._writer.send_bytes(buffer.getvalue())
    373             self._send = send

TypeError: can't pickle _thread.lock objects

1 Answer


    You should create your own class, because you can't be sure whether what collections.namedtuple produces contains unpicklable parts.
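
A quick way to find out which argument is the problem is to try pickling each field on its own. The sketch below is only an illustration: `unpicklable_fields` is a hypothetical helper, and a `threading.Lock` stands in for whatever the trained models might hold internally, which is an assumption and not something confirmed in this thread.

```python
import pickle
import threading
from collections import namedtuple

# Stand-in for the question's Params container.
Params = namedtuple("Params", ["inputs", "y_list", "trained_models"])


def unpicklable_fields(params):
    """Return {field name: error text} for every field that pickle rejects."""
    failures = {}
    for name, value in params._asdict().items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # e.g. TypeError or pickle.PicklingError
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


# Assumption: a lock plays the role of the unpicklable part of a model.
p = Params(
    inputs=[1.0, 2.0],
    y_list=["a", "b"],
    trained_models={"a": threading.Lock()},
)
print(unpicklable_fields(p))
```

With these placeholder values only `trained_models` should be reported, which would suggest that the container type itself is not what pickle is rejecting.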

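    I ran into a similar problem a few months ago: I was adding a lambda function to a class so that I could pass it as an argument. Since lambda functions can't be pickled (by the pickle package), that raised an error.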
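
The lambda case can be reproduced with the standard library alone. This is a minimal sketch; `double` and `can_pickle` are hypothetical names chosen for illustration, and it only exercises the plain `pickle` module, not joblib itself.

```python
import pickle


def double(x):
    """Module-level function: pickle stores a reference to it by name."""
    return 2 * x


def can_pickle(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except Exception:
        return False


print(can_pickle(double))            # named, importable function
print(can_pickle(lambda x: 2 * x))   # anonymous function
```

Run as a script, the first call should report `True` and the second `False`, because pickle looks functions up by name and a lambda has no importable name. Replacing the lambda with a named module-level function is one way around that.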
