我有一台带有一块支持 CUDA 的 GPU 和一个 8 核 CPU 的计算机。我正在尝试实现用于强化学习的 A3C 算法,该算法将计算图和训练环境并行化,并将梯度更新同步到全局计算图。
我正在尝试使用分布式 TensorFlow API 来实现这一点:
# Hyperparameter definition
############################################################
epochs = 100000 # Global steps
t_max = 1000 # Thread steps
max_grad_norm = 0.5
alpha = 0.99
gamma = 0.99
#############################################################
ckpt_path = 'ckpt/checkpoints'
log_path = 'logs'
def test(ac, load_path, sess):
ac.load(load_path)
env = gym.make('Breakout-v0')
while True:
obs = env.reset()
state_c, state_h = ac.init_lstm_c, ac.init_lstm_h
done = False
while not done:
action, _, state_c, state_h = ac.step(sess, obs, state_c, state_h)
obs, _, done, _ = env.step(action)
def train(n_epochs, t_max, gamma, ac, sess):
env = gym.make("Breakout-v0")
save_dir = os.path.join(log_path)
writer = tf.summary.FileWriter(save_dir)
for ep in range(n_epochs):
ep_obs, ep_disc_rew, m_rew, ep_act, ep_vals, state_c, state_h = process_episode(sess, ac, env, t_max, gamma)
log = ac.learn(sess, ep_obs, state_c, state_h, ep_disc_rew, m_rew, ep_act, ep_vals)
step = tf.train.get_global_step().eval(session=sess)
writer.add_summary(log, global_step=step)
def process_episode(sess, ac, env, t_max, gamma):
ep_observations, ep_rewards, ep_actions, ep_values = [], [], [], []
done = False
t = 0
observation = env.reset()
state_c, state_h = ac.c_init, ac.h_init
while t < t_max and not done:
action, value, state_c, state_h = ac.step(sess, observation, state_c, state_h)
ep_observations.append(observation)
ep_values.append(value)
ep_actions.append(action)
observation, reward, done, _ = env.step(action)
ep_rewards.append(reward)
ep_disc_rewards = discount_rewards(ep_rewards, gamma)
t_rew = np.sum(ep_rewards)
return ep_observations, ep_disc_rewards, t_rew, ep_actions, ep_values, state_c, state_h
if __name__ == '__main__':
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
server = tf.train.Server.create_local_server()
gs = tf.train.create_global_step(tf.get_default_graph())
env = gym.make("Breakout-v0")
ac = ActorCritic(env.observation_space, env.action_space)
with tf.train.MonitoredTrainingSession(
master=server.target, checkpoint_dir=ckpt_path, save_summaries_steps=None, config=config) as sess:
train(epochs, t_max, gamma, ac, sess)
sess.stop()
我想知道的是:
- MonitoredTrainingSession 块中的代码是否已经由多个 worker 并行执行?
- 我的计算图是否被复制到了每个 worker?
如果没有,我该怎么做才能实现?