I am seeing very low accuracy when running distributed training on CIFAR-10. Even after 1M steps on two p2.8xlarge instances (8 Tesla K80 GPUs each), the loss is still 2.39. I run one ps per machine and one worker per GPU (16 workers in total), with a batch_size of 8.
After training completes, the accuracy on the validation dataset using cifar10_eval is 0.010.
I run the distributed version of CIFAR-10 training using the model from the TensorFlow tutorial, adapted to distributed mode following the distributed TensorFlow code example. The distributed version of the code is below. What is wrong in this code, and how can I fix it to get better accuracy?
The command used to run the ps and workers looks like:
CUDA_VISIBLE_DEVICES='0' python cifar10_multi_machine_train.py --batch_size 8
--data_dir=./cifar10_data --train_dir=./train_logs --ps_hosts=host1:2222,host2:2222
--worker_hosts=host1:2230,host1:2231,host1:2232,host1:2233,host1:2234,host1:2235,
host1:2236,host1:2237,host2:2230,host2:2231,host2:2232,host2:2233,
host2:2234,host2:2235,host2:2236,host2:2237
--job_name=worker --task_index=0
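The ps processes would presumably be launched with the same script (the code below branches on `FLAGS.job_name`), e.g. a hypothetical command mirroring the worker one:

```shell
# Hypothetical ps launch on host1; same flags, but job_name=ps.
# CUDA_VISIBLE_DEVICES='' keeps the ps off the GPUs.
CUDA_VISIBLE_DEVICES='' python cifar10_multi_machine_train.py --batch_size 8 \
  --data_dir=./cifar10_data --train_dir=./train_logs \
  --ps_hosts=host1:2222,host2:2222 \
  --worker_hosts=host1:2230,host1:2231,host1:2232,host1:2233,host1:2234,host1:2235,host1:2236,host1:2237,host2:2230,host2:2231,host2:2232,host2:2233,host2:2234,host2:2235,host2:2236,host2:2237 \
  --job_name=ps --task_index=0
```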
The distributed training code:
if FLAGS.job_name == "ps":
  server.join()
elif FLAGS.job_name == "worker":
  # Assigns ops to the local worker by default.
  with tf.device(tf.train.replica_device_setter(
      worker_device="/job:worker/task:%d" % FLAGS.task_index,
      cluster=cluster)):
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()

    # Build inference Graph.
    logits = cifar10.inference(images)

    # Build the portion of the Graph calculating the losses.
    loss = cifar10.loss(logits, labels)

    train_op = cifar10.train(loss, global_step)

  # The StopAtStepHook handles stopping after running given steps.
  hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.num_steps), _LoggerHook()]

  # The MonitoredTrainingSession takes care of session initialization,
  # restoring from a checkpoint, saving to a checkpoint, and closing when done
  # or an error occurs.
  with tf.train.MonitoredTrainingSession(master=server.target,
                                         is_chief=(FLAGS.task_index == 0),
                                         checkpoint_dir=FLAGS.train_dir,
                                         save_checkpoint_secs=60,
                                         hooks=hooks) as mon_sess:
    while not mon_sess.should_stop():
      # Run a training step asynchronously.
      # See `tf.train.SyncReplicasOptimizer` for additional details on how to
      # perform *synchronous* training.
      # mon_sess.run handles AbortedError in case of preempted PS.
      mon_sess.run(train_op)
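The snippet above assumes that `cluster` and `server` were already built from the `--ps_hosts`/`--worker_hosts` flags before this code runs. A minimal sketch of that setup, with flag values mirroring the command shown earlier (the actual `tf.train.ClusterSpec`/`tf.train.Server` calls are left as comments so the sketch stays self-contained):

```python
# Build the cluster description that tf.train.ClusterSpec expects:
# a dict mapping job names to lists of "host:port" addresses.
def build_cluster_def(ps_hosts, worker_hosts):
    """Split the comma-separated flag values into the cluster dict."""
    return {
        "ps": ps_hosts.split(","),
        "worker": worker_hosts.split(","),
    }

# Values taken from the launch command above: 2 ps tasks,
# 8 worker ports (2230-2237) on each of the 2 machines.
ps_hosts = "host1:2222,host2:2222"
worker_hosts = ",".join(
    "host%d:%d" % (machine, port)
    for machine in (1, 2)
    for port in range(2230, 2238)
)

cluster_def = build_cluster_def(ps_hosts, worker_hosts)

# In the real script this dict feeds the TensorFlow cluster objects:
#   cluster = tf.train.ClusterSpec(cluster_def)
#   server = tf.train.Server(cluster, job_name=FLAGS.job_name,
#                            task_index=FLAGS.task_index)

print(len(cluster_def["ps"]), len(cluster_def["worker"]))  # 2 16
```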