首页 文章

矩阵乘法不起作用 - Tensorflow

提问于
浏览
0

我很想使用tensorflow并且正在使用学校项目 . 在这里,我试图 Build 一个房屋标识符,我在excel表上做了一些数据,把它变成了一个csv文件,我正在测试是否可以读取数据 . 读取数据但是当我进行矩阵乘法时会产生多个错误并且说....“ValueError:Shape必须是等级2,但对于'MatMul'(op:'MatMul'),输入形状为[0] ,[1,1] . “非常感谢!

import tensorflow as tf
import os
dir_path = os.path.dirname(os.path.realpath(__file__))
filename = dir_path+ "\House Price Data .csv"
w1=tf.Variable(tf.zeros([1,1]))
w2=tf.Variable(tf.zeros([1,1])) #Feature 1's weight
w3=tf.Variable(tf.zeros([1,1])) #Feature 1's weight
b=tf.Variable(tf.zeros([1])) #bias for various features
x1= tf.placeholder(tf.float32,[None, 1])
x2= tf.placeholder(tf.float32,[None, 1])
x3= tf.placeholder(tf.float32,[None, 1])
Y= tf.placeholder(tf.float32,[None, 1])
y_=tf.placeholder(tf.float32,[None,1])
with tf.Session() as sess:
    sess.run( tf.global_variables_initializer())
    with open(filename) as inf:
        # Skip header
        next(inf)
        for line in inf:
            # Read data, using python, into our features
            housenumber, x1, x2, x3, y_ = line.strip().split(",")
            x1 = float(x1)
            product = tf.matmul(x1, w1)
            y = product + b

1 回答

  • 0

    @Aaron是对的,当你从csv文件加载数据时,你正在覆盖变量 .

    您需要将加载的值保存到单独的变量中,例如 _x1 而不是 x1 ,然后使用feed_dict将值提供给占位符 . 并且因为 x1 的形状是 [None,1] ,您需要将字符串标量 _x1 转换为具有相同形状的浮点数,在这种情况下为 [1,1] .

    import tensorflow as tf
    import os
    dir_path = os.path.dirname(os.path.realpath(__file__))
    filename = dir_path+ "\House Price Data .csv"
    w1=tf.Variable(tf.zeros([1,1]))
    b=tf.Variable(tf.zeros([1])) #bias for various features
    x1= tf.placeholder(tf.float32,[None, 1])
    
    y_pred = tf.matmul(x1, w1) + b
    
    with tf.Session() as sess:
        sess.run( tf.global_variables_initializer())
        with open(filename) as inf:
            # Skip header
            next(inf)
            for line in inf:
                # Read data, using python, into our features
                housenumber, _x1, _x2, _x3, _y_ = line.strip().split(",")
                sess.run(y_pred, feed_dict={x1:[[float(_x1)]]})
    

相关问题