首页 文章

使用Lambda图层的Keras与K.ctc_decode错误

提问于
浏览
2

看来,在谈到CTC功能时,Keras已经为你做了很多繁重的工作 . 但是我发现构建一个我不想在神经网络中运行的解码函数很棘手 . 我有一个在epoch端执行的自定义函数,然后我遍历所有我的测试数据并评估指标,我目前正在手动执行此操作但是想要使用k.ctc_decode函数(贪婪和梁)但是我发现很难访问并合并到我的自定义功能中 .

我有一个模特:

# Define CTC loss
    def ctc_lambda_func(args):
        y_pred, labels, input_length, label_length = args
        return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

def ctc_decode(args):
     y_pred, input_length =args
     seq_len = tf.squeeze(input_length,axis=1)

     return K.ctc_decode(y_pred=y_pred, input_length=seq_len, greedy=True, beam_width=100, top_paths=1)

input_data = Input(name='the_input', shape=(None,mfcc_features))  
x = TimeDistributed(Dense(fc_size, name='fc1', activation='relu'))(input_data) 
y_pred = TimeDistributed(Dense(num_classes, name="y_pred", activation="softmax"))(x)

labels = Input(name='the_labels', shape=[None,], dtype='int32')
input_length = Input(name='input_length', shape=[1], dtype='int32')
label_length = Input(name='label_length', shape=[1], dtype='int32')

loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred,labels,input_length,label_length])

dec = Lambda(ctc_decode, output_shape=[None,], name='decoder')([y_pred,input_length])

model = Model(inputs=[input_data, labels, input_length, label_length], outputs=[loss_out])



iterate = K.function([input_data, K.learning_phase()], [y_pred])
decode = K.function([y_pred, input_length], [dec])

目前的错误是:

dec = Lambda(ctc_decode,name ='decoder')([y_pred,input_length])文件“/home/rob/py27/local/lib/python2.7/site-packages/keras/engine/topology.py”,第604行,在调用output_shape = self.compute_output_shape(input_shape)文件“/home/rob/py27/local/lib/python2.7/site-packages/keras/layers/core.py”,第631行,在compute_output_shape中返回K .int_shape(x)文件“/home/rob/py27/local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py”,第451行,in int_shape shape = x.get_shape()AttributeError:'元组'对象没有属性'get_shape'

我有什么想法可以做到这一点?

1 回答

  • 1

    一个棘手的部分是 K.ctc_decode 返回单个张量列表的元组,而不是单个张量,因此您无法直接创建图层 . 而是尝试使用 K.function 创建解码器:

    top_k_decoded, _ = K.ctc_decode(y_pred, input_lengths)
    decoder = K.function([input_data, input_lengths], [top_k_decoded[0]])
    

    稍后您可以拨打您的解码器:

    decoded_sequences = decoder([test_input_data, test_input_lengths])
    

    您可能需要进行一些整形,因为 K.ctc_decoder 要求长度具有类似(样本)的形状,而长度张量是形状的(样本,1) .

相关问题