-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: dataset.py
21 lines (14 loc) · 1.37 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.utils.data
from torchvision import datasets, transforms
def get_dataloaders(fabric, config):
    """Build the MNIST train/val dataloaders for a Lightning Fabric run.

    Args:
        fabric: Lightning Fabric object; used for rank-ordered dataset
            download (``rank_zero_first``, ``is_global_zero``) and to derive
            the per-process batch size from ``world_size``.
        config: Config namespace with ``dataset_path`` and ``eff_batch_size``;
            optional ``fast_run`` flag shrinks the datasets for smoke tests.
            NOTE: ``config.batch_size`` is written as a side effect (the
            per-GPU batch size) — callers rely on this.

    Returns:
        Tuple ``(train_dataloader, val_dataloader)``.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # Rank 0 enters first and performs the download; other ranks wait, then
    # reuse the files. Set `local=True` if the filesystem is not shared
    # between machines.
    with fabric.rank_zero_first():
        train_dataset = datasets.MNIST(
            config.dataset_path, download=fabric.is_global_zero, train=True, transform=transform
        )
        val_dataset = datasets.MNIST(
            config.dataset_path, download=fabric.is_global_zero, train=False, transform=transform
        )
    if getattr(config, 'fast_run', False):
        # Tiny subsets for smoke tests; clamp to the dataset length so a
        # large eff_batch_size cannot produce out-of-range subset indices.
        train_dataset = torch.utils.data.Subset(
            train_dataset, list(range(min(len(train_dataset), config.eff_batch_size * 4)))
        )
        val_dataset = torch.utils.data.Subset(
            val_dataset, list(range(min(len(val_dataset), config.eff_batch_size * 2)))
        )
    # Per-GPU batch size from the effective (global) batch size. Guard with
    # max(1, ...) so world_size > 2 * eff_batch_size cannot round to 0,
    # which DataLoader rejects.
    config.batch_size = max(1, round(config.eff_batch_size / fabric.world_size))
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=2,
        multiprocessing_context='spawn',
        pin_memory=True,
    )
    # NOTE(review): drop_last=True on validation silently skips the tail
    # batch of the val set — confirm this is intentional.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        drop_last=True,
        num_workers=2,
        multiprocessing_context='spawn',
        pin_memory=True,
    )
    return train_dataloader, val_dataloader