首页 文章

在Keras和Tensorflow中复制模型以进行多线程设置

提问于
浏览
3

我正在尝试在Keras和TensorFlow中实现actor-critic的异步版本 . 我正在使用Keras作为构建我的网络层的前端(我正在使用tensorflow直接更新参数) . 我有一个 global_model 和一个主要的tensorflow会话 . 但是在每个线程中我创建了一个 local_model ,它从 global_model 复制参数 . 我的代码看起来像这样

def main(args):
    config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)
    sess = tf.Session(config=config)
    K.set_session(sess) # K is keras backend
    global_model = ConvNetA3C(84,84,4,num_actions=3)

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model)) for i in range(NUM_THREADS)]

    for t in threads:
        t.start()

def a3c_thread(i, sess, global_model):
    K.set_session(sess) # registering a session for each thread (don't know if it matters)
    local_model = ConvNetA3C(84,84,4,num_actions=3)
    sync = local_model.get_from(global_model) # I get the error here

    #in the get_from function I do tf.assign(dest.params[i], src.params[i])

我收到了来自Keras的用户警告

UserWarning:默认的TensorFlow图形不是与当前向Keras注册的TensorFlow会话关联的图形,因此Keras无法自动初始化变量 . 您应该考虑通过K.set_session(sess)向Keras注册正确的会话

然后是 tf.assign 操作的张量流错误,说操作必须在同一个图上 .

ValueError:Tensor(“conv1_W:0”,shape =(8,8,4,16),dtype = float32_ref,device = / device:CPU:0)必须与Tensor(“conv1_W:0”)在同一图表中,shape =(8,8,4,16),dtype = float32_ref)

我不确定出了什么问题 .

谢谢

1 回答

  • 5

    该错误来自Keras,因为 tf.get_default_graph() is sess.graph 正在返回 False . 从TF文档中,我看到 tf.get_default_graph() 正在返回当前线程的默认图形 . 在我开始一个新线程并创建一个图形的那一刻,它被构建为一个特定于该线程的单独图形 . 我可以通过以下方式解决这个问题,

    with sess.graph.as_default():
       local_model = ConvNetA3C(84,84,4,3)
    

相关问题