首页 文章

对不同规模的小型车进行培训

提问于
浏览
0

我正在尝试在PyTorch中训练一个深度学习模型的图像,这些图像已经被特定维度所取代 . 我想用迷你批次训练我的模型,但是迷你批量大小并没有巧妙地划分每个桶中的例子数量 .

我在a previous post中看到的一个解决方案是用额外的空格填充图像(在训练开始时即时或全部一次),但我不想这样做 . 相反,我希望在培训期间允许批量大小灵活 .

具体来说,如果 N 是存储桶中的图像数量且 B 是批量大小,那么对于该存储桶,如果 BN ,则需要 N // B 批次,否则 N // B + 1 批次 . 最后一批可以包含少于 B 个示例 .

举个例子,假设我有索引[0,1,...,19],包括在内,我想使用批量大小为3 .

索引[0,9]对应于桶0中的图像(形状(C,W1,H1))
索引[10,19]对应于桶1中的图像(形状(C,W2,H2))

(所有图像的通道深度相同) . 然后是可接受的索引分区

batches = [
    [0, 1, 2], 
    [3, 4, 5], 
    [6, 7, 8], 
    [9], 
    [10, 11, 12], 
    [13, 14, 15], 
    [16, 17, 18], 
    [19]
]

我更愿意分别处理分别为9和19的图像,因为它们具有不同的尺寸 .

通过PyTorch的文档,我找到了生成小批量索引列表的BatchSampler类 . 我创建了一个自定义 Sampler 类,它模拟上述索引的分区 . 如果它有帮助,这是我的实现:

class CustomSampler(Sampler):

    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.buckets = self._get_buckets(dataset)
        self.num_examples = len(dataset)

    def __iter__(self):
        batch = []
        # Process buckets in random order
        dims = random.sample(list(self.buckets), len(self.buckets))
        for dim in dims:
            # Process images in buckets in random order
            bucket = self.buckets[dim]
            bucket = random.sample(bucket, len(bucket))
            for idx in bucket:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            # Yield half-full batch before moving to next bucket
            if len(batch) > 0:
                yield batch
                batch = []

    def __len__(self):
        return self.num_examples

    def _get_buckets(self, dataset):
        buckets = defaultdict(list)
        for i in range(len(dataset)):
            img, _ = dataset[i]
            dims = img.shape
            buckets[dims].append(i)
        return buckets

但是,当我使用自定义 Sampler 类时,我生成以下错误:

Traceback (most recent call last):
    File "sampler.py", line 143, in <module>
        for i, batch in enumerate(dataloader):
    File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 263, in __next__
        indices = next(self.sample_iter)  # may raise StopIteration
    File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 139, in __iter__
        batch.append(int(idx))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'

DataLoader 类似乎期望传递索引,而不是索引列表 .

我不应该使用自定义 Sampler 类来执行此任务吗?我还考虑将自定义 collate_fn 传递给 DataLoader ,但是使用这种方法我不相信我可以控制允许哪些索引在同一个小批量中 . 任何指导将不胜感激 .

1 回答

  • 0

    每个样本有2个网络(必须修复cnn内核大小) . 如果是,只需将上述 custom_sampler 传递给DataLoader类的batch_sampler args即可 . 这将解决问题 .

相关问题