首页 文章

从 TensorFlow 的 Tensor / Variable 中获取值

提问于
浏览
0

我修改了 autoencoder 的示例代码以适配我的数据。我想知道如何获取编码器 layer_1 的值:我想把编码器输出的编码值作为输入数据的特征,用于进一步分析。输入为 200 维,编码器有 100 个隐藏节点。

from __future__ import division, print_function, absolute_import

import tensorflow as tf
import numpy as np
import word2vec

# Training hyper-parameters.
learning_rate = 0.01
training_epochs = 200

# Network shape: the encoder maps n_input -> n_hidden_1 and the decoder
# maps n_hidden_1 -> n_input, so the hidden layer is the compressed code.
n_hidden_1 = 100  # width of the encoder's hidden layer (1st layer num features)
n_input = 200  # concatenation of 2 word vectors

# Input batch; the leading (batch) dimension is left unspecified.
X = tf.placeholder("float", [None, n_input])

# Trainable parameters, randomly initialised. Creation order is unchanged:
# encoder weights, decoder weights, encoder bias, decoder bias.
weights = dict(
    encoder_h1=tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    decoder_h1=tf.Variable(tf.random_normal([n_hidden_1, n_input])),
)
biases = dict(
    encoder_b1=tf.Variable(tf.random_normal([n_hidden_1])),
    decoder_b1=tf.Variable(tf.random_normal([n_input])),
)

# Building the encoder
def encoder(x):
    """Build the encoder layer: sigmoid(x @ W_enc + b_enc).

    Uses the module-level ``weights['encoder_h1']`` and
    ``biases['encoder_b1']``. The returned tensor is the hidden
    representation (``layer_1``). It is a symbolic graph node; pass it to
    ``sess.run`` (with a ``feed_dict`` for ``X``) to get concrete values.
    """
    w_enc = weights['encoder_h1']
    b_enc = biases['encoder_b1']
    pre_activation = tf.add(tf.matmul(x, w_enc), b_enc)
    return tf.nn.sigmoid(pre_activation)

# Building the decoder
def decoder(x):
    """Build the decoder layer: sigmoid(x @ W_dec + b_dec).

    Uses the module-level ``weights['decoder_h1']`` and
    ``biases['decoder_b1']`` to map a hidden code back to input space.
    Returns a symbolic tensor (the reconstruction).
    """
    w_dec = weights['decoder_h1']
    b_dec = biases['decoder_b1']
    pre_activation = tf.add(tf.matmul(x, w_dec), b_dec)
    return tf.nn.sigmoid(pre_activation)


 # Construct model
encoder_op = encoder(X)  # hidden code (layer_1)
decoder_op = decoder(encoder_op)  # reconstruction of X

# An autoencoder learns to reproduce its own input: the reconstruction is
# the prediction and the input batch itself is the target.
y_pred, y_true = decoder_op, X

# Mean squared reconstruction error, minimised with RMSProp.
cost = tf.reduce_mean(
    tf.pow(y_true - y_pred, 2)
)
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost)

# Created after minimize() so the optimizer's own variables are covered too.
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)

    # Build one training row by concatenating two word vectors.
    # Assumes each vector has n_input / 2 dimensions so the row matches the
    # placeholder width -- TODO confirm against vectors.bin.
    model = word2vec.load('./vectors.bin')
    vector1 = list(model['word1'])
    vector2 = list(model['word2'])
    # Batch of a single example; named so it does not shadow builtin `input`.
    input_np = np.array([vector1 + vector2])

    for _epoch in range(training_epochs):
        _, c = sess.run([optimizer, cost], feed_dict={X: input_np})
    print("final cost:", c)

    # layer_1 is the tensor returned by encoder(). Fetching it with the same
    # feed yields its values as an array of shape (batch, n_hidden_1), which
    # can be used as features for further analysis.
    encoded = sess.run(encoder_op, feed_dict={X: input_np})
    print(encoded)

1 回答

相关问题