为了账号安全,请及时绑定邮箱和手机立即绑定

CIFAR10 数据加载器采样器拆分

CIFAR10 数据加载器采样器拆分

PHP
拉丁的传说 2023-11-09 21:18:47
我正在尝试分割 CIFAR10 的训练数据,因此训练集的最后 5000 个用于验证。我的代码size = len(CIFAR10_training)dataset_indices = list(range(size))val_index = int(np.floor(0.9 * size))train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]train_sampler = SubsetRandomSampler(train_idx)val_sampler = SubsetRandomSampler(val_idx)train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,                                          batch_size=config['batch_size'],                                          shuffle=False,  sampler = train_sampler)valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,                                           batch_size=config['batch_size'],                                           shuffle=False,  sampler = val_sampler)print(len(train_dataloader.dataset),len(valid_dataloader.dataset),但最后一个打印语句打印 50000 和 10000。当我打印 train_idx 和 val_idx 时,它不应该是 45000 和 5000 它打印正确的值([0:44999],[45000:49999] 我的代码有什么问题吗
查看完整描述

1 回答

?
阿波罗的战车

TA贡献1862条经验 获得超6个赞

我无法复制您的结果,当我执行您的代码时,打印语句输出相同数字的两倍:train_CIFAR10valid_dataloaderCIFAR10_test(50000, 50000)

train_dataloader.datasetvalid_dataloader.datasetCIFAR10_training

您不能要求len(train_dataloader)45000/batch_size

如果您需要知道分割的大小,那么您必须计算采样器的长度:

print(len(train_dataloader.sampler), len(valid_dataloader.sampler))

除此之外,您的代码很好,您正确地分割了数据。


查看完整回答
反对 回复 2023-11-09
  • 1 回答
  • 0 关注
  • 75 浏览

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信