
Implementing a CRF loss in TensorFlow


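I am trying to use a conditional random field (CRF) loss in a TensorFlow graph.

I am working on a sequence labeling task:

My input is a sequence of elements [A, B, C, D]. Each element belongs to one of 3 classes. Classes are one-hot encoded: an element of class 0 is represented by the vector [1, 0, 0].

My input labels (y) have shape (batch_size x sequence_length x num_classes).

My network produces logits of the same shape.

Assume all my sequences have length 4.

Here is my code: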

import tensorflow as tf

sequence_length = 4
num_classes = 3
input_y = tf.placeholder(tf.int32, shape=[None, sequence_length, num_classes])
logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)

log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)

I get the following error:

File "", line 1, in
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 182, in crf_log_likelihood
    transition_params)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 109, in crf_sequence_score
    false_fn=_multi_seq_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/utils.py", line 206, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/smart_cond.py", line 59, in smart_cond
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2063, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1913, in BuildCondBranch
    original_result = fn()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 95, in _single_seq_fn
    array_ops.concat([example_inds, tag_indices], axis=1))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2975, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1734, in __init__
    control_input_ops)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1570, in _create_c_op
    raise ValueError(str(e))
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,5] and params shape: [?,3] for 'cond/GatherNd' (op: 'GatherNd') with input shapes: [?,3], [?,5]
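
As a reading aid for the last line of the error, here is a minimal sketch of the shape rule it quotes: the last dimension of `indices` may not exceed the rank of `params`. The helper `check_gather_nd_shapes` is hypothetical and written in plain Python; it is not a TensorFlow function and only mimics the check described in the message. The note that a depth of 5 matches one batch-index column plus 4 tag columns is an inference from the `concat([example_inds, tag_indices], axis=1)` frame in the traceback, not something the traceback states outright.

```python
def check_gather_nd_shapes(params_shape, indices_shape):
    """Mimic the static shape rule quoted in the error message.

    Shapes are lists of ints, with None standing for an unknown dimension.
    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    params_rank = len(params_shape)
    index_depth = indices_shape[-1]
    # An unknown depth cannot be rejected statically, so only check known values.
    if index_depth is not None and index_depth > params_rank:
        raise ValueError(
            "indices.shape[-1] must be <= params.rank, but saw indices shape: "
            "%s and params shape: %s" % (indices_shape, params_shape)
        )
    return True


# Shapes from the error: params [?, 3] has rank 2, indices [?, 5] has depth 5.
try:
    check_gather_nd_shapes([None, 3], [None, 5])
    failed = False
except ValueError:
    failed = True

# Assumed reading: 1 batch-index column + 4 tag columns gives the depth of 5.
index_depth = 1 + 4
```

With these shapes the check fails because 5 is greater than 2, which matches the message TensorFlow reports while building the graph.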

1 Answer

  • 1

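    The error is caused by the sequence-length argument having the wrong dimensionality. It must be a vector with one length per example (shape [batch_size]), not a scalar.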

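    import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

    num_classes = 3
    # Padded token ids, shape [batch_size, max_seq_len]; id 0 is assumed to be padding
    input_x = tf.placeholder(tf.int32, shape=[None, None], name="input_x")
    # One-hot labels, shape [batch_size, max_seq_len, num_classes]
    input_y = tf.placeholder(tf.int32, shape=[None, None, num_classes])
    # Vector of true lengths, shape [batch_size]: counts the non-zero ids in each row
    sequence_lengths = tf.reduce_sum(tf.sign(input_x), 1)

    # After some network operations you will come up with logits
    logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
    # Convert one-hot labels to dense tag indices, shape [batch_size, max_seq_len]
    dense_y = tf.argmax(input_y, -1, output_type=tf.int32)
    log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_lengths)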
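
To make the shape of the length vector concrete, here is a minimal sketch in plain Python that mirrors what `tf.reduce_sum(tf.sign(input_x), 1)` computes, so it runs without TensorFlow. It assumes token ids are non-negative and that id 0 marks padding; the helper name `sequence_lengths_from_padded` and the sample batch are made up for illustration.

```python
def sequence_lengths_from_padded(batch, pad_id=0):
    """Return one length per row by counting the ids that differ from pad_id."""
    return [sum(1 for token_id in row if token_id != pad_id) for row in batch]


# Hypothetical batch of 3 sequences padded to length 4 with id 0.
padded_batch = [
    [7, 2, 9, 4],
    [5, 3, 0, 0],
    [8, 0, 0, 0],
]

lengths = sequence_lengths_from_padded(padded_batch)
print(lengths)  # [4, 2, 1]

# If every sequence has length 4, as in the question, the vector still needs
# one entry per example rather than a single scalar.
fixed_lengths = [4] * len(padded_batch)
print(fixed_lengths)  # [4, 4, 4]
```

For the fixed-length case inside the graph, something like `tf.fill([tf.shape(logits)[0]], 4)` should produce the same kind of vector; this is a suggestion and was not part of the original answer.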
    
