首页 文章

pytorch数据加载器多次迭代

提问于
浏览
2

我使用iris-dataset训练一个带有pytorch的简单网络 .

trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
                                          shuffle=True, num_workers=2)

dataiter = iter(trainloader)

数据集本身只有150个数据点,并且pytorch dataloader在整个数据集上迭代一次,因为批量大小为150 .

我现在的问题是,通常有什么方法可以告诉pytorch的dataloader重复数据集,如果它曾经迭代过一次吗?

thnaks

update

得到它runnning :)刚刚创建了一个dataloader的子类并实现了自己的 __next__()

2 回答

  • 1

    最简单的选择是使用嵌套循环:

    for i in range(10):
        for batch in trainloader:
            do_something(batch)
    

    另一个选择是使用itertools.cycle,也许与itertools.take结合使用 .

    当然,使用批量大小等于整个数据集的DataLoader有点不寻常 . 您也不需要在trainloader上调用iter() .

  • 1

    使用itertools.cycle有一个重要的缺点,因为它不会在每次迭代后重排数据:

    当iterable耗尽时,返回保存副本中的元素 .

    在某些情况下,这会对模型的性能产生负面影响 . 解决这个问题的方法是编写自己的循环生成器:

    def cycle(iterable):
        while True:
            for x in iterable:
                yield x
    

    您将使用哪个:

    dataiter = iter(cycle(trainloader))
    

相关问题