我正在训练GAN模型 . 为了加载数据集,我使用的是TensorFlow的数据集API .
# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length, z_dim],
minval=0, maxval=1, dtype=tf.float32))
train_dataset = tf.data.Dataset.zip((train_dataset, z_train))
创建迭代器:
iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
使用迭代器:
(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)
在 Session 期间培训GAN:
首先训练判别者:
_, disc_loss = sess.run([disc_optim, disc_loss])
然后训练发电机:
_, gen_loss = sess.run([gen_optim, gen_loss])
这是捕获 . 因为,我在鉴别器和生成器图中使用 label 作为条件(CGAN),使用两个sess.run在同一批次运行期间生成两组不同的 label 批 .
for epoch in range(num_of_epochs):
sess.run([tf.global_variables_initializer(), train_init_op.initializer])
for batch in range(num_of_batches):
_, disc_loss = sess.run([disc_optim, disc_loss])
_, gen_loss = sess.run([gen_optim, gen_loss])
因为,我必须在生成器_596157会话运行中提供相同批次的 label ,如何防止数据集API在批处理的同一循环中生成两个不同的批处理?
注意:我使用的是TensorFlow v1.9
提前致谢 .
1 回答
您可以为同一数据集创建2个迭代器 . 如果需要对数据集进行混洗,您甚至可以通过将种子指定为张量来实现 . 见下面的例子 .