
Multidimensional RNN on Tensorflow

I am trying to implement a 2D RNN in the context of human action classification (joints on one axis of the RNN and time on the other), and I have been looking for something in Tensorflow that could do the job.

I heard about GridLSTMCell (contributed both internally and externally), but I couldn't get it to work with dynamic_rnn, which accepts 3-D tensors, whereas I would have to supply a 4-D tensor [batch_size, max_time, num_joints, n_features].

In addition, ndlstm is also a (somewhat little-known) part of the TF library; it basically uses a plain 1-D LSTM and transposes the outputs to feed them into a second 1-D LSTM. This was also advocated here, but I am not quite sure whether it is the same idea as what I need.
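
If I understand the idea right, it amounts to something like the rough sketch below, written with two plain 1-D LSTMs rather than the actual ndlstm code; the [batch, time, joints, features] shape and the unit sizes are just placeholders from my setting:

    import tensorflow as tf

    # Rough sketch of the "1-D LSTM, transpose, second 1-D LSTM" idea,
    # not the actual ndlstm implementation. Sizes are arbitrary placeholders.
    batch, time, joints, feats = 2, 10, 18, 3
    x = tf.placeholder(tf.float32, [batch, time, joints, feats])

    # First pass: for every (batch, time) pair, run an LSTM along the joints.
    h = tf.reshape(x, [batch * time, joints, feats])
    with tf.variable_scope("over_joints"):
        h, _ = tf.nn.dynamic_rnn(tf.contrib.rnn.LSTMCell(16), h, dtype=tf.float32)

    # Transpose so the second pass runs along time for every (batch, joint) pair.
    h = tf.reshape(h, [batch, time, joints, 16])
    h = tf.transpose(h, [0, 2, 1, 3])             # -> [batch, joints, time, 16]
    h = tf.reshape(h, [batch * joints, time, 16])
    with tf.variable_scope("over_time"):
        h, _ = tf.nn.dynamic_rnn(tf.contrib.rnn.LSTMCell(16), h, dtype=tf.float32)

    h = tf.reshape(h, [batch, joints, time, 16])  # back to 4-D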

Any help would be appreciated.

1 Answer

    I have successfully tried using both GridLSTM and ndlstm in tensorflow.

    I am not sure how you would convert your 4-D tensor into the 3-D one that dynamic_rnn accepts (there is a short reshape sketch after these tests), but I think this can give you an idea of how to use GridLSTM:

    import tensorflow as tf
    from tensorflow.contrib import grid_rnn


    def reshape_to_rnn_dims(tensor, num_time_steps):
        # static_rnn expects a list of 2-D tensors, one per time step.
        return tf.unstack(tensor, num_time_steps, 1)
    
    
    class GridLSTMCellTest(tf.test.TestCase):
        def setUp(self):
            self.num_features = 1
            self.time_steps = 1
            self.batch_size = 1
            tf.reset_default_graph()
            self.input_layer = tf.placeholder(tf.float32, [self.batch_size, self.time_steps, self.num_features])
            self.cell = grid_rnn.Grid1LSTMCell(num_units=8)
    
        def test_simple_grid_rnn(self):
            self.input_layer = reshape_to_rnn_dims(self.input_layer, self.time_steps)
            tf.nn.static_rnn(self.cell, self.input_layer, dtype=tf.float32)
    
        def test_dynamic_grid_rnn(self):
            tf.nn.dynamic_rnn(self.cell, self.input_layer, dtype=tf.float32)
    
    
    class BidirectionalGridRNNCellTest(tf.test.TestCase):
        def setUp(self):
            self.num_features = 1
            self.time_steps = 1
            self.batch_size = 1
            tf.reset_default_graph()
            self.input_layer = tf.placeholder(tf.float32, [self.batch_size, self.time_steps, self.num_features])
            self.cell_fw = grid_rnn.Grid1LSTMCell(num_units=8)
            self.cell_bw = grid_rnn.Grid1LSTMCell(num_units=8)
    
        def test_simple_bidirectional_grid_rnn(self):
            self.input_layer = reshape_to_rnn_dims(self.input_layer, self.time_steps)
            tf.nn.static_bidirectional_rnn(self.cell_fw, self.cell_bw, self.input_layer, dtype=tf.float32)
    
        def test_bidirectional_dynamic_grid_rnn(self):
            tf.nn.bidirectional_dynamic_rnn(self.cell_fw, self.cell_bw, self.input_layer, dtype=tf.float32)
    
    if __name__ == '__main__':
        tf.test.main()
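
    On the 4-D to 3-D point from the question: if it is acceptable to fold the joints into the feature axis (which does give up the explicit 2-D structure), a minimal sketch would be to merge the last two dimensions before calling dynamic_rnn. The sizes below are made up; the shape names follow the question:

    import tensorflow as tf
    from tensorflow.contrib import grid_rnn

    # Hypothetical sizes for [batch_size, max_time, num_joints, n_features].
    batch_size, max_time, num_joints, n_features = 2, 100, 18, 3

    inputs_4d = tf.placeholder(tf.float32, [batch_size, max_time, num_joints, n_features])
    # Merge joints and features so dynamic_rnn sees the 3-D shape it expects.
    inputs_3d = tf.reshape(inputs_4d, [batch_size, max_time, num_joints * n_features])

    cell = grid_rnn.Grid1LSTMCell(num_units=8)
    outputs, state = tf.nn.dynamic_rnn(cell, inputs_3d, dtype=tf.float32)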
    

    Apparently, ndlstm accepts 4-D tensors of shape (batch_size, height, width, depth). I have these tests (one involving tensorflow's ctc_loss; I also found an example of it being used together with conv2d):

    import tensorflow as tf
    # ndlstm ships under tf.contrib in TF 1.x; the exact import path may vary between versions.
    from tensorflow.contrib.ndlstm.python import lstm2d


    class MultidimensionalRNNTest(tf.test.TestCase):
        def setUp(self):
            self.num_classes = 26
            self.num_features = 32
            self.time_steps = 64
            self.batch_size = 1 # Can't be dynamic, apparently.
            self.num_channels = 1
            self.num_filters = 16
            self.input_layer = tf.placeholder(tf.float32, [self.batch_size, self.time_steps, self.num_features, self.num_channels])
            self.labels = tf.sparse_placeholder(tf.int32)
    
        def test_simple_mdrnn(self):
            net = lstm2d.separable_lstm(self.input_layer, self.num_filters)
    
        def test_image_to_sequence(self):
            net = lstm2d.separable_lstm(self.input_layer, self.num_filters)
            net = lstm2d.images_to_sequence(net)
    
        def test_convert_to_ctc_dims(self):
            net = lstm2d.separable_lstm(self.input_layer, self.num_filters)
            net = lstm2d.images_to_sequence(net)

            # Collapse everything but the filter dimension before the projection.
            net = tf.reshape(net, [-1, self.num_filters])

            W = tf.Variable(tf.truncated_normal([self.num_filters,
                                                 self.num_classes],
                                                stddev=0.1, dtype=tf.float32), name='W')
            b = tf.Variable(tf.constant(0., dtype=tf.float32, shape=[self.num_classes]), name='b')

            net = tf.matmul(net, W) + b
            net = tf.reshape(net, [self.batch_size, -1, self.num_classes])

            # ctc_loss expects time-major inputs: [max_time, batch_size, num_classes].
            net = tf.transpose(net, (1, 0, 2))

            loss = tf.nn.ctc_loss(inputs=net, labels=self.labels, sequence_length=[2])

            print(net)
    
    
    if __name__ == '__main__':
        tf.test.main()
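
    To connect this back to the question: one way to feed a [batch_size, max_time, num_joints, n_features] tensor to separable_lstm would be to treat time as the image height, the joints as the width, and the per-joint features as the depth. This is only a sketch with made-up sizes, not something I have tried on real skeleton data:

    import tensorflow as tf
    from tensorflow.contrib.ndlstm.python import lstm2d  # contrib path in TF 1.x; may vary between versions

    # Hypothetical sizes for [batch_size, max_time, num_joints, n_features].
    batch_size, max_time, num_joints, n_features = 2, 100, 18, 3
    num_filters = 16

    # height = time, width = joints, depth = per-joint features
    skeleton = tf.placeholder(tf.float32, [batch_size, max_time, num_joints, n_features])

    net = lstm2d.separable_lstm(skeleton, num_filters)  # 2-D LSTM pass over the time and joint axes
    net = lstm2d.images_to_sequence(net)                # back to a sequence, as in the CTC test above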
    
