首页 文章

Tensorflow:feed_dict的形状错误{}

提问于
浏览
1

第一次遇到这样的问题 .

错误大约是 feed_dict={tfkids: kids, tfkids_fit: kids_fit} ,似乎需要重塑 kids_fit .

任何人都可以帮我解决这个问题吗?

import tensorflow as tf
from tensorflow.contrib.distributions import Normal
import numpy as np
import matplotlib.pyplot as plt

DNA_SIZE = 1
POP_SIZE = 10
LR = 0.1
N_GENERATION = 50

def F(x):
    return x**2

def get_fitness(value):
    return -value

mean = tf.Variable(tf.constant(13.), dtype=tf.float32)
sigma = tf.Variable(tf.constant(5.), dtype=tf.float32)
N_dist = Normal(loc=mean, scale=sigma)
make_kids = N_dist.sample([POP_SIZE])

tfkids = tf.placeholder(tf.float32, [POP_SIZE, DNA_SIZE])
tfkids_fit = tf.placeholder(tf.float32, [POP_SIZE])
loss = -tf.reduce_mean(N_dist.log_prob(tfkids) * tfkids_fit)
train_op = tf.train.GradientDescentOptimizer(LR).minimize(loss)

x = np.linspace(-20, 20, 100)
plt.plot(x, F(x))

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

plt.ion()
for g in range(N_GENERATION):
    kids = sess.run(make_kids)
    kids_fit = get_fitness(F(kids))
    sess.run(train_op, feed_dict={tfkids: kids, tfkids_fit: kids_fit})

    if "plot_points" in globals():
        plot_points.remove()

    plot_points = plt.scatter(kids, F(kids), s=30)
    plt.pause(0.05)

plt.ioff()
plt.show()

这在尝试测试代码时出错 .

ValueError:无法为Tensor'占位符:0'提供形状值(10,),其形状为'(10,1)'

3 回答

  • 1

    你的 Placeholder:0tfkids = tf.placeholder(tf.float32, [POP_SIZE, DNA_SIZE]) .

    如您所见, tfkids 形状为 [POP_SIZE, DNA_SIZE] = (10, 1) .

    相反,您的 kids 变量的形状为= (10) .

    虽然两个形状都包含10个值,但第一个有2个维度,第二个是1个 .

    因此,您必须扩展 kids 变量的维度,以便以这种方式与 tfkids 兼容:

    sess.run(train_op, feed_dict={tfkids: np.expand_dims(kids, axis=1), tfkids_fit: kids_fit})
    

    np.expand_dims 允许您为 kids 形状添加一维尺寸

  • 1

    你可以重塑孩子的张量 .

    kids = sess.run(make_kids)
    kids = tf.reshape(kids,(None,1))
    kids_fit = get_fitness(F(kids))
    sess.run(train_op, feed_dict={tfkids: kids, tfkids_fit: kids_fit})
    
  • 1

    Problem: 当您声明 tfkids 变量时,将其形状指定为 [POP_SIZE, DNA_SIZE] ,即(10,1) . 但是,当您在训练期间将实际数据输入占位符时,您将传递(10)形状的数据 .

    Solution: 因此,您必须将训练数据重新整形为(10,1),以便将其提供给变量 . 您可以通过多种方式重塑数据 . 您可以使用numpy库的重塑功能 . 在为您提供训练数据之前,请执

    kids = np.reshape(kids, [-1, 1])
    

    希望这可以帮助!

相关问题