首页 文章

在党的数据上训练Keras模型

提问于
浏览
1

我生成训练数据我的整个RAM(亚马逊AWS上的32 GB - https://aws.amazon.com/marketplace/pp/B077GCH38C?qid=1527197041958&sr=0-1&ref_=srh_res_product_title)正在被吃掉,进程正在被杀死 .

为了构建训练数据,我迭代文章列表(每个都有500-1000个字符) . 在每篇文章中,我将前100个字符作为输入,下一个字符作为输出,然后我移动一个字符并重复此操作直到文本结尾 . 这种方法产生了很多训练向量,即具有500个字符的文章将产生大约400个测试数据,这就是问题 .

随着15k文章和滑动窗口100将有数百万的训练数据和我的AWS机器(32 GB RAM t2.2xlarge - https://aws.amazon.com/marketplace/pp/B077GCH38C?qid=1527197041958&sr=0-1&ref_=srh_res_product_title)死亡约79%-3,500万训练数据 .

所以我的问题是 - 在Keras有没有办法开始学习模型,让我们说25%的数据,然后加载25%并执行此操作直到所有内容消耗完毕?

我的学习伪代码:

with open(articles_path, 'rt', encoding="UTF-8") as file:
    for line in file:
        article = line[0:-1]
        article_length = len(article)
        # here is the problematic code 
        for i in range(0, article_length - seq_length, 1):
            seq_in = article[i:i + seq_length]
            seq_out = article[i + seq_length]
            dataX.append([tokens[char] for char in seq_in])
            dataY.append(tokens[seq_out])

model = Sequential()
model.add(LSTM(256, input_shape=(seq_length, 1)))
model.add(Dropout(0.2))
model.add(Dense(len(tokens), activation=activation))
model.compile(loss=loss, optimizer=optimizer)

model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=callbacks_list)

注意:当我编写自己的程序时,我正在使用本教程https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/

2 回答

  • 1

    这看起来是切换到生成器的好时机,基本上你会一次吐出一批而不是加载整个数据集:

    def data_gen(batch_size=32):
      """Yield single batch at a time."""
      dataX, dataY = list(), list()
      while True: # the generator yields forever
        # here is the problematic code 
        for i in range(0, article_length - seq_length, 1):
          for _ in range(batch_size):
            seq_in = article[i:i + seq_length]
            seq_out = article[i + seq_length]
            dataX.append([tokens[char] for char in seq_in])
            dataY.append(tokens[seq_out])
          yield np.array(dataX), np.array(dataY)
          dataX, dataY = list(), list()
    

    您现在可以使用 fit_generatorref)进行训练,这将从您的发电机进行批量 生产环境 . 因此,您只处理 batch_size 个样本而不是整个集合 . 您可能希望使用NumPy数组而不是Python列表 .

    对于更有组织的版本,您可以实现Sequence class,它封装数据并充当生成器 .

  • 1

    您的数据生成方法很有趣,但您不必从文本中生成 every 100字节的样本 . 用以下代码替换有问题的代码:

    for i in range(0, article_length - seq_length, 1):
            if random.randint(1,10) not in [5, 6] : continue   # this will skip 80% of the samples
            seq_in = article[i:i + seq_length]
            seq_out = article[i + seq_length]
            dataX.append([tokens[char] for char in seq_in])
            dataY.append(tokens[seq_out])
    

    import random 放在文件开头的某处 . 一旦你把它放到你的代码中,只有5个序列中的1个将进入你的训练数据,有效地减小了大小 .

    有一种方法可以更有效地生成随机采样的字符串,但它需要重写代码,这种方法只需添加一行 .

相关问题