我已经开始学习tensorflow并且难以理解占位符/变量问题 .
我正在尝试编写矩阵乘法函数 . 它在使用tf.constant时有效,但我很难理解如何使用变量
这是我的代码
import tensorflow as tf
import numpy as np
mat_1 = np.array([[0,1,1,0], [1,0,1,0], [1,0,0,1], [0,1,1,0]]).astype('int32')
mat_2 = np.array([[0,1,1,0], [1,0,1,0], [1,0,0,1], [0,1,1,0]]).astype('int32')
def my_matmult1(mat_1, mat_2):
#define session
x_sess = tf.Session()
with x_sess:
xmat_1 = tf.constant(mat_1)
xmat_2 = tf.constant(mat_2)
r1 = tf.matmul(xmat_1, xmat_2)
qq1 = x_sess.run(r1)
return qq1
def my_matmult2(mat_1, mat_2):
#define session
x_sess1 = tf.Session()
with x_sess1:
#initialize placeholders
xmat_1_plh = tf.placeholder(dtype=mat_1.dtype, shape=mat_1.shape)
xmat_2_plh = tf.placeholder(dtype=mat_2.dtype, shape=mat_2.shape)
#create variables
x_mat_1 = tf.Variable(xmat_1_plh, trainable = False)
x_mat_2 = tf.Variable(xmat_2_plh, trainable = False)
x_sess1.run(tf.initialize_all_variables())
#
r1 = tf.matmul(xmat_1, xmat_2)
qq1 = x_sess1.run(r1, feed_dic={mat_1, mat_2})
return qq1
这按预期工作:
my_matmult1(mat_1, mat_1)
但是,以下失败:
my_matmult2(mat_1, mat_1)
出现以下错误
InvalidArgumentError您必须使用dtype int32和shape [4,4]为占位符张量“占位符”提供值
即使改变了最后一行
qq1 = x_sess1.run(r1, feed_dic={tf.convert_to_tensor(mat_1), tf.convert_to_tensor(mat_2)})
我究竟做错了什么?
3 回答
为了有意义地回答这个问题,我必须回到tensorflow如何设计工作
Graphs
Tensorflow中的图形只是计算将采用的 Map /路径 . 它不包含任何值,也不执行任何操作 .
Session
另一方面,会话需要图形,数据和运行时间来执行 . 图形和会话的这个概念让TensorFolow将流定义或模型与实际计算运行时分开 .
Separating the run-time from flow graph
这很可能是为了将图形定义与运行时配置和实际执行数据分开 . 例如,运行时可以在群集上 . 因此,集群中的每个执行运行时都需要具有相同的图形定义 . 但是每个运行时可能在执行过程中本地具有不同的数据集 . 因此,在群集中的分布式执行期间提供输入和输出数据非常重要 .
Why Placeholders and Not Variables
占位符充当图形的输入/输出管道 . 如果将图形可视化为多个节点 - 占位符是输入或输出节点 .
真正的问题是为什么TensorFlow不为I / O节点使用正常变量?为什么还有其他类型?
在训练过程中(当程序在会话中执行时),需要确保使用实际值来训练模型 . 基本上,训练过程中的
feed_dict
只接受实际值,例如一个Numpy ndarry . 这些实际值不能由TensorFlow变量提供,因为除非使用eval()或session.run(),否则Tensorflow变量没有数据 . 但是,训练语句本身是session.run()函数的一部分 - 因此它不能在其中使用另一个session.run()来将张量变量解析为数据 . 到目前为止,session.run()已经必须绑定到特定的运行时配置和数据 .你没有正确地喂食字典 . 您需要将字典设置为占位符的名称 . 我还添加了一个名称,您可以使用“xmat_1_plh”作为名称,但我更喜欢添加自己的名称 . 我还认为你在my_matmult2()函数中有一些额外的行 . x_mat_1 / 2我不认为增加太多,但可能不会受到伤害(通过在图表中添加另一个OP可能会有一点性能 .
我不确定这个函数的最终目标是什么,但是你在图中创建节点 . 因此,您可能希望将“.run()”语句从此函数中移出(到您希望主动乘以2矩阵的位置),因为如果您只是在寻找,您不应该在循环中调用它一种乘以2矩阵的方法 .
如果这是对my_matmult2()的单个测试/调用,那么您应该使用对字典的更正 .
如果在创建占位符后删除
tf.Variable()
行(并相应地修改了fed变量的名称),则代码应该有效 .占位符用于您希望为模型提供的变量 . 变量用于模型的参数(如权重) .
因此,您正确创建了两个占位符,但随后您无缘无故地创建了其他变量,这可能会在Tensorflow图中出现问题 .
该功能看起来像: