I am seeing very low accuracy when running distributed training of CIFAR-10. Even after 1M steps on two p2.8xlarge instances (8 Tesla K80 GPUs per machine), the loss is still around 2.39. I run one ps on each machine and one worker per GPU (16 workers in total), with a batch_size of 8.

After training completes, the accuracy on the validation dataset measured with cifar10_eval is 0.010.
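
For reference, that 0.010 number comes from pointing cifar10_eval.py at the checkpoints in train_logs; the exact command is not shown above, so this is only an assumed invocation based on the tutorial script's flags:

python cifar10_eval.py --data_dir=./cifar10_data --checkpoint_dir=./train_logs \
    --eval_dir=./eval_logs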

I am training CIFAR-10 with the model from the TensorFlow tutorial, adapted to run in distributed mode following the distributed TensorFlow code example. The distributed version of the code is below. What is wrong in this code, and how can I fix it to get better accuracy?

The ps and worker processes are started with commands similar to:

CUDA_VISIBLE_DEVICES='0' python cifar10_multi_machine_train.py --batch_size 8 \
    --data_dir=./cifar10_data --train_dir=./train_logs \
    --ps_hosts=host1:2222,host2:2222 \
    --worker_hosts=host1:2230,host1:2231,host1:2232,host1:2233,host1:2234,host1:2235,host1:2236,host1:2237,host2:2230,host2:2231,host2:2232,host2:2233,host2:2234,host2:2235,host2:2236,host2:2237 \
    --job_name=worker --task_index=0
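
The ps processes are launched the same way, only with --job_name=ps and the matching --task_index (one per machine). The exact ps command is not shown above, so this is an assumed equivalent (CUDA_VISIBLE_DEVICES is left empty on the assumption that the ps task needs no GPU):

CUDA_VISIBLE_DEVICES='' python cifar10_multi_machine_train.py --batch_size 8 \
    --data_dir=./cifar10_data --train_dir=./train_logs \
    --ps_hosts=host1:2222,host2:2222 \
    --worker_hosts=<same 16-entry list as above> \
    --job_name=ps --task_index=0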

Distributed training code:

import tensorflow as tf

import cifar10  # model and input pipeline from the TensorFlow CIFAR-10 tutorial

FLAGS = tf.app.flags.FLAGS
# Flag definitions (--ps_hosts, --worker_hosts, --job_name, --task_index,
# --batch_size, --data_dir, --train_dir, --num_steps) are omitted here.

# Cluster and server setup as in the distributed TensorFlow example
# (not shown in the original snippet).
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build inference Graph.
        logits = cifar10.inference(images)

        # Build the portion of the Graph calculating the losses. Note that we will
        # assemble the total_loss using a custom function below.
        loss = cifar10.loss(logits, labels)

        train_op = cifar10.train(loss, global_step)

    # The StopAtStepHook handles stopping after running the given number of steps.
    # _LoggerHook is a custom hook (defined elsewhere, as in the tutorial's
    # cifar10_train.py) that logs loss and step timing.
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.num_steps), _LoggerHook()]

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(FLAGS.task_index == 0),
                                            checkpoint_dir=FLAGS.train_dir,
                                            save_checkpoint_secs=60,
                                            hooks=hooks) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step asynchronously.
            # See `tf.train.SyncReplicasOptimizer` for additional details on how to
            # perform *synchronous* training.
            # mon_sess.run handles AbortedError in case of preempted PS.
            mon_sess.run(train_op)
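
The comment above mentions tf.train.SyncReplicasOptimizer for synchronous training. A minimal sketch of how that would be wired in is shown below; it assumes the optimizer is built in this script rather than inside cifar10.train(), and the learning rate of 0.1 is only a placeholder:

# Sketch only: aggregate gradients from all 16 workers before applying them.
opt = tf.train.GradientDescentOptimizer(0.1)
opt = tf.train.SyncReplicasOptimizer(opt,
                                     replicas_to_aggregate=16,
                                     total_num_replicas=16)
train_op = opt.minimize(loss, global_step=global_step)

# The optimizer's hook manages the aggregation queues and must be passed
# to the MonitoredTrainingSession together with the other hooks.
sync_hook = opt.make_session_run_hook(is_chief=(FLAGS.task_index == 0))
hooks = [sync_hook,
         tf.train.StopAtStepHook(num_steps=FLAGS.num_steps),
         _LoggerHook()]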