首页 文章

获取具有可变序列长度的激活时的Tensorflow GRU单元错误

提问于
浏览
4

我想在一些时间序列数据上运行GRU单元格,以根据最后一层中的激活对它们进行聚类 . 我对GRU单元实现做了一个小改动

def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
  with vs.variable_scope("Gates"):  # Reset gate and update gate.
    # We start with bias of 1.0 to not reset and not update.
    r, u = array_ops.split(1, 2, linear([inputs, state], 2 * self._num_units, True, 1.0))
    r, u = sigmoid(r), sigmoid(u)
  with vs.variable_scope("Candidate"):
    c = tanh(linear([inputs, r * state], self._num_units, True))
  new_h = u * state + (1 - u) * c

  # store the activations, everything else is the same
  self.activations = [r,u,c]
return new_h, new_h

在此之后,我将以下面的方式连接激活,然后在调用此GRU单元的脚本中返回它们

@property
def activations(self):
    return self._activations


@activations.setter
def activations(self, activations_array):
    print "PRINT THIS"         
    concactivations = tf.concat(concat_dim=0, values=activations_array, name='concat_activations')
    self._activations = tf.reshape(tensor=concactivations, shape=[-1], name='flatten_activations')

我以下面的方式调用GRU单元

outputs, state = rnn.rnn(cell=cell, inputs=x, initial_state=initial_state, sequence_length=s)

其中 s 是批处理长度数组,其中包含输入批处理的每个元素中的时间戳数 .

最后我拿到了

fetched = sess.run(fetches=cell.activations, feed_dict=feed_dict)

执行时我收到以下错误

回溯(最近一次调用最后一次):文件“xxx.py”,第162行,在fetched = sess.run(fetches = cell.activations,feed_dict = feed_dict)文件“/xxx/local/lib/python2.7/site- packages / tensorflow / python / client / session.py“,第315行,在运行中返回self._run(None,fetches,feed_dict)文件”/xxx/local/lib/python2.7/site-packages/tensorflow/python/ client / session.py“,第511行,在_run feed_dict_string中)文件”/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py“,第564行,在_do_run target_list中)文件“/xxx/local/lib/python2.7/site-packages/tensorflow/python/client/session.py”,第588行,在_do_call six.reraise(e_type,e_value,e_traceback)文件“/ xxx / local / lib /python2.7/site-packages/tensorflow/python/client/session.py“,第571行,在_do_call中返回fn(* args)文件”/xxx/local/lib/python2.7/site-packages/tensorflow/在_run_fn中的python / client / session.py“,第555行

return tf_session.TF_Run(session, feed_dict, fetch_list, target_list) tensorflow.python.pywrap_tensorflow.StatusNotOK: Invalid argument: The tensor returned for RNN/cond_396/ClusterableGRUCell/flatten_activations:0 was not valid.

有人可以通过传递可变长度序列来了解如何在最后一步从GRU单元获取激活吗?谢谢 .

1 回答

  • 0

    要从最后一步获取激活,您需要将激活作为状态的一部分,由tf.rnn返回 .

相关问题