首页 文章

如何在tensorflow r0.9(skflow)中训练DNNC分类器时打印进度?

提问于
浏览
7

我无法得到DNNClassifier在训练时打印进度,即损失和验证分数 . 据我所知,可以使用继承自BaseEstimator的config参数打印丢失,但是当我传递RunConfig对象时,分类器没有打印任何内容 .

from tensorflow.contrib.learn.python.learn.estimators import run_config

config = run_config.RunConfig(verbose=1)
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10],
                             n_classes=3,
                             config=config)
classifier.fit(X_train, y_train, steps=1000)

我错过了什么吗?我检查了RunConfig如何处理详细参数,它似乎是that it only cares if its greater than 1,它与文档不匹配:

verbose:控制详细程度,可能的值:0:算法和调试信息被静音 . 1:培训师打印进度 . 2:打印日志设备放置 .

至于验证分数我认为使用monitors.ValidationMonitor会很好,但是当尝试它时,分类器不会找到任何监视器 .

2 回答

  • 1

    在fit函数之前添加这些以显示进度:

    import logging
    logging.getLogger().setLevel(logging.INFO)
    

    样品:

    INFO:tensorflow:global_step/sec: 0
    INFO:tensorflow:Training steps [0,1000000)
    INFO:tensorflow:Step 1: loss = 10.5043
    INFO:tensorflow:training step 100, loss = 10.45380 (0.223 sec/batch).
    INFO:tensorflow:Step 101: loss = 10.5623
    INFO:tensorflow:training step 200, loss = 10.46701 (0.220 sec/batch).
    INFO:tensorflow:Step 201: loss = 10.3885
    INFO:tensorflow:training step 300, loss = 10.36501 (0.232 sec/batch).
    INFO:tensorflow:Step 301: loss = 10.3441
    INFO:tensorflow:training step 400, loss = 10.44571 (0.220 sec/batch).
    INFO:tensorflow:Step 401: loss = 10.396
    INFO:tensorflow:global_step/sec: 3.95
    
  • 12

    在训练前添加此行:

    import logging
    tf.logging.set_verbosity(tf.logging.INFO)
    

相关问题