是否有任何一般性建议可以在自定义数据集中有效地处理数据,以便它与minibatch eval / train循环很好地配合?为了更具体地说明我的意思,假设我定义了将x映射到x 1的合成玩具数据集:
import torch.utils.data as data
class Dataset(data.Dataset):
def __init__(self):
super(Dataset, self).__init__()
# list of [x, y]
self.dataset = [
[1, 2],
[2, 3],
[3, 4],
[4, 5]
]
def __getitem__(self, index):
item = self.dataset[index]
return item[0], item[1]
def __len__(self):
return len(self.dataset)
在实践中,这将包装在DataLoader中并在eval / train循环内访问,如下所示:
dataset = Dataset()
data_loader = data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
epochs = 100
for i_epoch in range(epochs):
for i_minibatch, minibatch in enumerate(data_loader):
x, y = minibatch
# predict and train
数据集对象可能返回原始Python对象,如数字或列表,就像在我的示例实现中一样,但在最后一个代码片段的“预测和训练”部分,我们需要一些特定的数据类型来计算内容,比如torch.FloatTensor(似乎数据加载器可以隐式地执行此操作),甚至可能包装为torch.autograd.Variable,并且可能还需要一些.cuda()调用 . 我的问题是关于何时进行这些数据转换和函数调用的一般建议 .
例如,一个选项是将所有内容保存为数据集内的torch.FloatTensor,并在data_loader循环中添加Variable包装器并调用.cuda() . 我们也可以通过在数据集构造函数或getitem方法中调用.cuda()来获取GPU上的全部或部分数据 . 我认为所有这些方法可能都有利有弊 . 如果我正在训练几个时代的模型,我不想在每个时期或小批量迭代中引入不必要的开销,这可以通过预先计算数据集中的内容来避免 . 可能有更多关于pytorch内部知识的人(可能与一些缓存或jit编译相关的内容)可能会指出更具体的理由来选择一种方法而不是另一种方法 .
2 回答
您是否阅读过一些官方示例,例如imagenet train?在这些示例中,它们首先获取数据 . 正如您所说,数据已被隐式转换为火炬张量 . 然后,如果你有GPU,将cpu张量转换为GPU张量 . 最后,将GPU上的普通张量转换为火炬
Variable
以使autograd工作 .我认为这是做这些事情的规范和标准方式 . 至少,我到目前为止看到的所有pytorch代码都是这样做的 . 如果你想提高速度,你可以考虑
使用dataloader中的多个worker来获取数据
使用多个GPU进行培训
如果您有多个安装了多个GPU的服务器,则
甚至是分布式培训
通常,数据集以对磁盘上存储更友好的格式存储在文件中 . 加载数据集时,您希望数据类型对PyTorch更友好 . 这是由torchvision库的transformations interface完成的 . 例如,对于MNIST,下面是标准转换:
这里
ToTensor
将张量中的所有值除以255,这样如果数据是RGB图像,那么张量中的值将是0.0到1.0 .关键在于,您的磁盘数据理想情况下应该与您可能想要做的事情无关(训练,可视化,计算统计数据等)以及所使用的框架的不可知性 . 在加载与正在执行的操作相关的数据后,应用转换 .
我要提到的另一件事是处理像ImageNet这样的非常大的数据集 . 有一些重要的事情:
您应该避免使用单独的图像文件作为数据集,因为这在群集中不能很好地工作 . 相反,您可以打包所有文件,如LMDB或未压缩的zip(使用Python ZipFile模块),然后只能按顺序访问这些文件 . 大文件中的随机访问会极大地降低您的速度 .
您应该避免在DataLoader类中为大型数据集使用shuffle选项 . 如果你这样做,那么你再次访问大文件与随机访问,性能将坦克 . 相反,你可以做的是顺序读取
K = C * total_epochs * batch_size
记录,其中C
是你选择的常量> = 1.然后将K记录在内存中随机分组然后分批分割 . 不幸的是,您现在必须手动执行此操作 .