首页 文章

如何在计算损失函数时使用keras层的权重?

提问于
浏览
-1

我正在尝试构建一个只有一层的自动编码器:

from keras import backend as K

def cost2(y_true, y_pred):
    print "shapes:", model.get_weights()[0].shape
    yy = K.dot( y_pred, model.get_weights()[0].T )
    return np.sum((y_true - yy)**2)

x = Input(shape=(original_dim,))
y = Dense(latent_dim)(x)
model = Model(inputs=x, outputs=y)
model.summary()
model.compile(optimizer='adagrad', loss=cost2)

这给了我错误:

Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 784)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 1570      
=================================================================
Total params: 1,570
Trainable params: 1,570
Non-trainable params: 0
_________________________________________________________________

shapes: (784, 2)

回溯(最近一次调用最后一次):文件“vae_kears_gidital_mnist3.py”,第45行,在model.compile中(optimizer ='adagrad',loss = cost2)文件“/Users/asgharrazavi/anaconda/lib/python2.7/site- packages / keras / engine / training.py“,第830行,编译sample_weight,mask)文件”/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/engine/training.py“,第429行,在加权score_array = fn(y_true,y_pred)文件“vae_kears_gidital_mnist3.py”,第18行,在cost2 yy = K.dot(y_pred,model.get_weights()[0] .T)文件“/ Users / asgharrazavi / anaconda /lib/python2.7/site-packages/keras/backend/tensorflow_backend.py“,第1048行,如果ndim(x)不是None且(ndim(x)> 2或ndim(y)> 2),则为dot:文件“/Users/asgharrazavi/anaconda/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py”,第606行,在ndim dims = x.get_shape()._ dims AttributeError:'numpy.ndarray'对象没有属性'get_shape'

我只是试图将模型的输出乘以模型的转置权重以返回输入维度 . 有任何想法吗?

1 回答

  • 1

    你的成本函数应该返回一个keras张量而不是 numpdy ndarray . 您应该仅使用 keras.backend 函数或您在客户丢失函数中的特定后端函数(例如 tf.something )(即 K.sum 而不是 np.sum

    这是您在问题中提到的错误的原因,但更重要的是,您没有以keras的方式创建自动编码器 . 在keras中,您的模型将使用两层(编码器和解码器)创建,其中图层通过转置和标准MSE损耗共享权重 . 我建议你阅读keras博客的this post,看看this issue .

相关问题