diff --git a/benchmarks/make_archives.py b/benchmarks/make_archives.py
new file mode 100644
index 000000000..e5c87c19b
--- /dev/null
+++ b/benchmarks/make_archives.py
@@ -0,0 +1,134 @@
+import argparse
+import io
+import pickle
+import tarfile
+from math import ceil
+from pathlib import Path
+
+import torch
+import torchvision
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--input-dir", default="/datasets01_ontap/tinyimagenet/081318/train/")
+parser.add_argument("--output-dir", default="./tinyimagenet/081318/train")
+parser.add_argument("--archiver", default="pickle", help="pickle or tar or torch")
+parser.add_argument(
+    "--archive-content", default="BytesIO", help="BytesIO or tensor. Only valid for pickle or torch archivers"
+)
+parser.add_argument("--archive-size", type=int, default=500, help="Number of samples per archive")
+# Note: type=bool would parse any non-empty string (even "False") as True, hence BooleanOptionalAction.
+parser.add_argument("--shuffle", action=argparse.BooleanOptionalAction, default=True, help="Whether to shuffle the samples within each archive")
+
+# The archiver parameter determines whether we use `tar.add`, `pickle.dump` or
+# `torch.save` to write an archive. `torch.save` relies on pickle under the hood
+# but has special handling for tensors (which may be faster):
+# - tar: each tar file contains files. Each file is the original encoded jpeg
+#   file. To avoid storing labels in separate files, we write the corresponding
+#   label in each file name within the archive. This is ugly, but OK at this stage.
+# - pickle or torch: in this case, each archive contains a list of tuples. Each
+#   tuple represents a sample in the form (img_data, label). label is always an
+#   int, and img_data is the *encoded* jpeg bytes, which can be represented
+#   either as a tensor or a BytesIO object, depending on the archive-content
+#   parameter.
+
+
+class Archiver:
+    def __init__(
+        self,
+        input_dir,
+        output_dir,
+        archiver="pickle",
+        archive_content="BytesIO",
+        archive_size=500,
+        shuffle=True,
+    ):
+        self.input_dir = input_dir
+        self.archiver = archiver.lower()
+        self.archive_content = archive_content.lower()
+        self.archive_size = archive_size
+        self.shuffle = shuffle
+
+        self.output_dir = Path(output_dir).resolve()
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+    def archive_dataset(self):
+        def loader(path):
+            # identity loader to avoid decoding images with PIL or something else
+            # This means the dataset will always return (path_to_image_file, int_label)
+            return path
+
+        dataset = torchvision.datasets.ImageFolder(self.input_dir, loader=loader)
+        self.num_samples = len(dataset)
+
+        if self.shuffle:
+            self.indices = torch.randperm(self.num_samples)
+        else:
+            self.indices = torch.arange(self.num_samples)
+
+        archive_samples = []
+        for i, idx in enumerate(tqdm(self.indices)):
+            archive_samples.append(dataset[idx])
+            if ((i + 1) % self.archive_size == 0) or (i == len(self.indices) - 1):
+                archive_path = self._get_archive_path(archive_samples, last_idx=i)
+                {"pickle": self._save_pickle, "torch": self._save_torch, "tar": self._save_tar}[self.archiver](
+                    archive_samples, archive_path
+                )
+
+                archive_samples = []
+
+    def _get_archive_path(self, samples, last_idx):
+        current_archive_number = last_idx // self.archive_size
+        total_num_archives_needed = ceil(self.num_samples / self.archive_size)
+        zero_pad_fmt = len(str(total_num_archives_needed))
+        num_samples_in_archive = len(samples)
+
+        archive_content_str = "" if self.archiver == "tar" else f"{self.archive_content}_"
+        path = (
+            self.output_dir
+            / f"archive_{self.archive_size}_{archive_content_str}{current_archive_number:0{zero_pad_fmt}d}"
+        )
+        print(f"Archiving {num_samples_in_archive} samples in {path}")
+        return path
+
+    def _make_content(self, samples):
+        archive_content = []
+        for sample_file_name, label in samples:
+            if self.archive_content == "bytesio":
+                with open(sample_file_name, "rb") as f:
+                    img_data = io.BytesIO(f.read())
+            elif self.archive_content == "tensor":  # Note: this doesn't decode anything
+                img_data = torchvision.io.read_file(sample_file_name)
+            else:
+                raise ValueError(f"Unsupported {self.archive_content = }")
+            archive_content.append((img_data, label))
+        return archive_content
+
+    def _save_pickle(self, samples, archive_name):
+        archive_content = self._make_content(samples)
+        archive_name = archive_name.with_suffix(".pkl")
+        with open(archive_name, "wb") as f:
+            pickle.dump(archive_content, f)
+
+    def _save_torch(self, samples, archive_name):
+        archive_content = self._make_content(samples)
+        archive_name = archive_name.with_suffix(".pt")
+        torch.save(archive_content, archive_name)
+
+    def _save_tar(self, samples, archive_path):
+        archive_path = archive_path.with_suffix(".tar")
+        with tarfile.open(archive_path, "w") as tar:
+            for sample_file_name, label in samples:
+                path = Path(sample_file_name)
+                tar.add(path, arcname=f"{label}/{path.name}")
+
+
+args = parser.parse_args()
+Archiver(
+    input_dir=args.input_dir,
+    output_dir=args.output_dir,
+    archiver=args.archiver,
+    archive_content=args.archive_content,
+    shuffle=args.shuffle,
+    archive_size=args.archive_size,
+).archive_dataset()
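Reviewer note, not part of the patch: the `(img_data, label)` layout described in the header comment can be sanity-checked by reading one pickle shard back with plain `pickle.load`. A minimal sketch, assuming a shard produced with `--archiver pickle --archive-content BytesIO` (the shard path below is hypothetical; the zero-padding depends on the total number of archives):

```python
import pickle

from PIL import Image

# Hypothetical shard path; make_archives.py names shards
# "archive_<size>_<content>_<number>.pkl" under --output-dir.
shard = "./tinyimagenet/081318/train/archive_500_bytesio_000.pkl"

with open(shard, "rb") as f:
    samples = pickle.load(f)  # list of (img_data, label) tuples

img_data, label = samples[0]
img = Image.open(img_data).convert("RGB")  # img_data is a BytesIO of encoded JPEG bytes
print(f"{len(samples)} samples; first label={label}, first image size={img.size}")
```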
diff --git a/benchmarks/torchvision_classification/helpers.py b/benchmarks/torchvision_classification/helpers.py
index 383bc554a..230b89c97 100644
--- a/benchmarks/torchvision_classification/helpers.py
+++ b/benchmarks/torchvision_classification/helpers.py
@@ -1,5 +1,6 @@
 import itertools
 import os
+import pickle
 import random
 from functools import partial
 from pathlib import Path
@@ -8,7 +9,7 @@
 import torch.distributed as dist
 import torchvision
 from PIL import Image
-from torchdata.datapipes.iter import FileLister, IterDataPipe
+from torchdata.datapipes.iter import FileLister, FileOpener, IterDataPipe, TarArchiveLoader
 
 
 # TODO: maybe infinite buffer can / is already natively supported by torchdata?
@@ -20,13 +21,19 @@ class _LenSetter(IterDataPipe):
     # TODO: Ideally, we woudn't need this extra class
 
-    def __init__(self, dp, root):
+    def __init__(self, dp, root, args):
         self.dp = dp
 
         if "train" in str(root):
-            self.size = IMAGENET_TRAIN_LEN
+            if args.tiny:
+                self.size = 100_000
+            else:
+                self.size = IMAGENET_TRAIN_LEN
         elif "val" in str(root):
-            self.size = IMAGENET_TEST_LEN
+            if args.tiny:
+                self.size = 10_000
+            else:
+                self.size = IMAGENET_TEST_LEN
         else:
             raise ValueError("oops?")
@@ -35,25 +42,47 @@ def __iter__(self):
 
     def __len__(self):
         # TODO The // world_size part shouldn't be needed. See https://github.com/pytorch/data/issues/533
-        return self.size // dist.get_world_size()
+        if dist.is_initialized():
+            return self.size // dist.get_world_size()
+        else:
+            return self.size
 
 
-def _decode(path, root, category_to_int):
-    category = Path(path).relative_to(root).parts[0]
+def _apply_tranforms(img_and_label, transforms):
+    img, label = img_and_label
+    return transforms(img), label
 
-    image = Image.open(path).convert("RGB")
-    label = category_to_int(category)
 
-    return image, label
+class ArchiveLoader(IterDataPipe):
+    def __init__(self, source_datapipe, loader):
+        self.loader = pickle.load if loader == "pickle" else torch.load
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self):
+        for filename in self.source_datapipe:
+            with open(filename, "rb") as f:
+                yield self.loader(f)
 
 
-def _apply_tranforms(img_and_label, transforms):
-    img, label = img_and_label
-    return transforms(img), label
+class ConcaterIterable(IterDataPipe):
+    # TODO: This should probably be a built-in: https://github.com/pytorch/data/issues/648
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
 
+    def __iter__(self):
+        for iterable in self.source_datapipe:
+            yield from iterable
+
+
+def _decode_path(data, root, category_to_int):
+    path = data
+    category = Path(path).relative_to(root).parts[0]
+    image = Image.open(path).convert("RGB")
+    label = category_to_int[category]
+    return image, label
 
-def make_dp(root, transforms):
+
+def _make_dp_from_image_folder(root):
     root = Path(root).expanduser().resolve()
     categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
     category_to_int = {category: i for (i, category) in enumerate(categories)}
@@ -61,10 +90,69 @@
     dp = FileLister(str(root), recursive=True, masks=["*.JPEG"])
 
     dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
-    dp = dp.map(partial(_decode, root=root, category_to_int=category_to_int))
-    dp = dp.map(partial(_apply_tranforms, transforms=transforms))
+    dp = dp.map(partial(_decode_path, root=root, category_to_int=category_to_int))
+    return dp
+
+
+def _decode_bytesio(data):
+    image, label = data
+    image = Image.open(image).convert("RGB")
+    return image, label
+
+
+def _decode_tensor(data):
+    image, label = data
+    image = torchvision.io.decode_jpeg(image, mode=torchvision.io.ImageReadMode.RGB)
+    return image, label
 
-    dp = _LenSetter(dp, root=root)
+
+def _make_dp_from_archive(root, args):
+    ext = "pt" if args.archive == "torch" else "pkl"
+    dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*{args.archive_content}*.{ext}"])
+    dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)  # inter-archive shuffling
+    dp = ArchiveLoader(dp, loader=args.archive)
+    dp = ConcaterIterable(dp)
+    dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False)  # intra-archive shuffling
+
+    # TODO: we're sharding here but the big BytesIO or Tensors have already been
+    # loaded by all workers, possibly in vain. Hopefully the new experimental MP
+    # reading service will improve this?
+    dp = dp.sharding_filter()
+    decode = {"bytesio": _decode_bytesio, "tensor": _decode_tensor}[args.archive_content]
+    return dp.map(decode)
+
+
+def _decode_tar_entry(data):
+    # Note on how we retrieve the label: each file name in the archive (the
+    # "arcnames", as the tarfile docs call them) looks like "label/some_name.jpg".
+    # It's somewhat hacky and will obviously change, but it's OK for now.
+    filename, io_stream = data
+    label = int(Path(filename).parent.name)
+    image = Image.open(io_stream).convert("RGB")
+    return image, label
+
+
+def _make_dp_from_tars(root, args):
+
+    dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*.tar"])
+    dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False)  # inter-archive shuffling
+    dp = FileOpener(dp, mode="b")
+    dp = TarArchiveLoader(dp)
+    dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False)  # intra-archive shuffling
+    dp = dp.sharding_filter()
+    return dp.map(_decode_tar_entry)
+
+
+def make_dp(root, transforms, args):
+    if args.archive in ("pickle", "torch"):
+        dp = _make_dp_from_archive(root, args)
+    elif args.archive == "tar":
+        dp = _make_dp_from_tars(root, args)
+    else:
+        dp = _make_dp_from_image_folder(root)
+
+    dp = dp.map(partial(_apply_tranforms, transforms=transforms))
+    dp = _LenSetter(dp, root=root, args=args)
     return dp
@@ -97,10 +185,10 @@ def __iter__(self):
             yield self.samples[idx % len(self.samples)]
 
 
-def make_pre_loaded_dp(root, transforms):
+def make_pre_loaded_dp(root, transforms, args):
     dp = _PreLoadedDP(root=root, transforms=transforms)
     dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
-    dp = _LenSetter(dp, root=root)
+    dp = _LenSetter(dp, root=root, args=args)
     return dp
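Reviewer note, not part of the patch: the shuffle/shard ordering in `_make_dp_from_archive` (shuffle whole archives, flatten, shuffle within a buffer of one archive, then `sharding_filter`) can be illustrated on a toy pipeline. A minimal sketch using the built-in `unbatch` in place of `ConcaterIterable`; the patch additionally calls `.set_shuffle(False)` so that the DataLoader's `shuffle=` flag decides whether the Shufflers are active, which is omitted here so the example actually shuffles:

```python
from torchdata.datapipes.iter import IterableWrapper

# Each inner list plays the role of one decoded archive of 4 samples.
archives = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

dp = IterableWrapper(archives)
dp = dp.shuffle(buffer_size=100)  # inter-archive shuffling (whole shards)
dp = dp.unbatch()                 # same role as ConcaterIterable: flatten the shards
dp = dp.shuffle(buffer_size=4)    # intra-archive shuffling (buffer ~ one shard)
dp = dp.sharding_filter()         # workers drop non-assigned samples *after* the shuffles

print(list(dp))  # e.g. [5, 7, 4, 6, 9, 8, 0, 11, 1, 3, 10, 2]
```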
== "torch" else "pkl" + dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*{args.archive_content}*.{ext}"]) + dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) # inter-archive shuffling + dp = ArchiveLoader(dp, loader=args.archive) + dp = ConcaterIterable(dp) + dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False) # intra-archive shuffling + + # TODO: we're sharding here but the big BytesIO or Tensors have already been + # loaded by all workers, possibly in vain. Hopefully the new experimental MP + # reading service will improve this? + dp = dp.sharding_filter() + decode = {"bytesio": _decode_bytesio, "tensor": _decode_tensor}[args.archive_content] + return dp.map(decode) + + +def _decode_tar_entry(data): + # Note on how we retrieve the label: each file name in the archive (the + # "arcnames" as from the tarfile docs) looks like "label/some_name.jpg". + # It's somewhat hacky and will obviously change, but it's OK for now. + filename, io_stream = data + label = int(Path(filename).parent.name) + image = Image.open(io_stream).convert("RGB") + return image, label + + +def _make_dp_from_tars(root, args): + + dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*.tar"]) + dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) # inter-archive shuffling + dp = FileOpener(dp, mode="b") + dp = TarArchiveLoader(dp) + dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False) # intra-archive shuffling + dp = dp.sharding_filter() + return dp.map(_decode_tar_entry) + + +def make_dp(root, transforms, args): + if args.archive in ("pickle", "torch"): + dp = _make_dp_from_archive(root, args) + elif args.archive == "tar": + dp = _make_dp_from_tars(root, args) + else: + dp = _make_dp_from_image_folder(root) + + dp = dp.map(partial(_apply_tranforms, transforms=transforms)) + dp = _LenSetter(dp, root=root, args=args) return dp @@ -97,10 +185,10 @@ def __iter__(self): yield self.samples[idx % len(self.samples)] -def make_pre_loaded_dp(root, transforms): +def make_pre_loaded_dp(root, transforms, args): dp = _PreLoadedDP(root=root, transforms=transforms) dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter() - dp = _LenSetter(dp, root=root) + dp = _LenSetter(dp, root=root, args=args) return dp diff --git a/benchmarks/torchvision_classification/presets.py b/benchmarks/torchvision_classification/presets.py index 00c7bfa8b..eaab88454 100644 --- a/benchmarks/torchvision_classification/presets.py +++ b/benchmarks/torchvision_classification/presets.py @@ -6,6 +6,7 @@ class ClassificationPresetTrain: def __init__( self, *, + on_pil_images, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), @@ -17,7 +18,7 @@ def __init__( trans.extend( [ - transforms.PILToTensor(), + transforms.PILToTensor() if on_pil_images else torch.nn.Identity(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] @@ -32,6 +33,7 @@ def __call__(self, img): class ClassificationPresetEval: def __init__( self, + on_pil_images, *, crop_size, resize_size=256, @@ -43,7 +45,7 @@ def __init__( [ transforms.Resize(resize_size), transforms.CenterCrop(crop_size), - transforms.PILToTensor(), + transforms.PILToTensor() if on_pil_images else torch.nn.Identity(), transforms.ConvertImageDtype(torch.float), transforms.Normalize(mean=mean, std=std), ] diff --git a/benchmarks/torchvision_classification/train.py b/benchmarks/torchvision_classification/train.py index ff12e902f..35ed72b31 100644 --- 
diff --git a/benchmarks/torchvision_classification/train.py b/benchmarks/torchvision_classification/train.py
index ff12e902f..35ed72b31 100644
--- a/benchmarks/torchvision_classification/train.py
+++ b/benchmarks/torchvision_classification/train.py
@@ -6,6 +6,7 @@
 import helpers
 import presets
 import torch
+import torch.distributed as dist
 import torch.utils.data
 import torchvision
 import utils
@@ -73,6 +74,7 @@ def evaluate(model, criterion, data_loader, device, args, print_freq=100, log_su
         and hasattr(data_loader.dataset, "__len__")
         and len(data_loader.dataset) != num_processed_samples
         and torch.distributed.get_rank() == 0
+        and not args.data_loading_only
     ):
         warnings.warn(
             f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
@@ -93,15 +95,22 @@ def create_data_loaders(args):
     if args.fs == "fsx":
         dataset_dir = "/datasets01"
     elif args.fs == "fsx_isolated":
-        dataset_dir = "/fsx_isolated"
+        dataset_dir = "/fsx_isolated/nicolashug"
     elif args.fs == "ontap":
         dataset_dir = "/datasets01_ontap"
     elif args.fs == "ontap_isolated":
-        dataset_dir = "/ontap_isolated"
+        dataset_dir = "/ontap_isolated/nicolashug"
     else:
         raise ValueError(f"bad args.fs, got {args.fs}")
 
-    dataset_dir += "/imagenet_full_size/061417/"
+    if args.tiny:
+        dataset_dir += "/tinyimagenet/081318/"
+    else:
+        dataset_dir += "/imagenet_full_size/061417/"
+
+    if args.archive is not None:
+        dataset_dir += "archives/"
+
     train_dir = os.path.join(dataset_dir, "train")
     val_dir = os.path.join(dataset_dir, "val")
@@ -110,13 +119,16 @@
     if args.no_transforms:
         train_preset = val_preset = helpers.no_transforms
     else:
-        train_preset = presets.ClassificationPresetTrain(crop_size=train_crop_size)
-        val_preset = presets.ClassificationPresetEval(crop_size=val_crop_size, resize_size=val_resize_size)
+        on_pil_images = args.archive_content != "tensor"
+        train_preset = presets.ClassificationPresetTrain(crop_size=train_crop_size, on_pil_images=on_pil_images)
+        val_preset = presets.ClassificationPresetEval(
+            crop_size=val_crop_size, resize_size=val_resize_size, on_pil_images=on_pil_images
+        )
 
     if args.ds_type == "dp":
         builder = helpers.make_pre_loaded_dp if args.preload_ds else helpers.make_dp
-        train_dataset = builder(train_dir, transforms=train_preset)
-        val_dataset = builder(val_dir, transforms=val_preset)
+        train_dataset = builder(train_dir, transforms=train_preset, args=args)
+        val_dataset = builder(val_dir, transforms=val_preset, args=args)
 
         train_sampler = val_sampler = None
         train_shuffle = True
@@ -136,8 +148,12 @@
         train_dataset = builder(train_dir, transform=train_preset)
         val_dataset = builder(val_dir, transform=val_preset)
 
-        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
-        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
+        if dist.is_initialized():
+            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
+            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
+        else:
+            train_sampler = torch.utils.data.RandomSampler(train_dataset)
+            val_sampler = torch.utils.data.SequentialSampler(val_dataset)
         train_shuffle = None  # but actually True
 
     else:
@@ -203,7 +219,7 @@ def main(args):
     train_data_loader, val_data_loader, train_sampler = create_data_loaders(args)
 
-    num_classes = 1000  # I'm lazy. TODO change this
+    num_classes = 200 if args.tiny else 1000  # I'm lazy. TODO change this
 
     print("Creating model")
     model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
@@ -231,7 +247,8 @@ def main(args):
         if args.distributed and train_sampler is not None:
             train_sampler.set_epoch(epoch)
         train_one_epoch(model, criterion, optimizer, train_data_loader, device, epoch, args)
-        lr_scheduler.step()
+        if not args.data_loading_only:
+            lr_scheduler.step()
         evaluate(model, criterion, val_data_loader, device=device, args=args)
 
         if args.output_dir:
@@ -338,6 +355,13 @@ def get_args_parser(add_help=True):
         help="'V1' or 'V2'. V2 only works for datapipes",
     )
 
+    parser.add_argument("--tiny", action="store_true", help="Use tinyimagenet instead of full imagenet")
+    parser.add_argument("--archive", default=None, help="tar or pickle or torch")
+    parser.add_argument(
+        "--archive-content", default=None, help="tensor or bytesio. Only for pickle and torch archives."
+    )
+    parser.add_argument("--archive-size", type=int, default=None, help="Number of samples in each archive.")
+
     return parser
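Reviewer note, not part of the patch: the `dist.is_initialized()` fallbacks above let the benchmark run as a plain single-process script. The sampler selection then degenerates as in this standalone sketch (toy dataset):

```python
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

if dist.is_initialized():
    # Each rank sees a disjoint 1/world_size slice of the data.
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
else:
    # Single-process fallback, as in the patch.
    sampler = RandomSampler(dataset)

loader = DataLoader(dataset, batch_size=4, sampler=sampler)
for (batch,) in loader:
    print(batch)
```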
diff --git a/benchmarks/torchvision_classification/utils.py b/benchmarks/torchvision_classification/utils.py
index b41bb8971..99fc1e0d8 100644
--- a/benchmarks/torchvision_classification/utils.py
+++ b/benchmarks/torchvision_classification/utils.py
@@ -112,57 +112,42 @@ def log_every(self, iterable, print_freq, header=None):
             header = ""
         start_time = time.time()
         end = time.time()
-        iter_time = SmoothedValue(fmt="{avg:.4f}")
-        data_time = SmoothedValue(fmt="{avg:.4f}")
-        model_time = SmoothedValue(fmt="{avg:.4f}")
+        iter_time = SmoothedValue(fmt="{avg:.2f}")
+        data_time = SmoothedValue(fmt="{avg:.3f}")
+        model_time = SmoothedValue(fmt="{avg:.2f}")
         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
-        if torch.cuda.is_available():
-            log_msg = self.delimiter.join(
-                [
-                    header,
-                    "[{0" + space_fmt + "}/{1}]",
-                    "eta: {eta}",
-                    "{meters}",
-                    "time: {time}",
-                    "data: {data}",
-                    "model: {model}",
-                    "max mem: {memory:.0f}",
-                ]
-            )
-        else:
-            log_msg = self.delimiter.join(
-                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
-            )
-        MB = 1024.0 * 1024.0
-        for obj in iterable:
+        log_msg = self.delimiter.join(
+            [
+                header,
+                "[{0" + space_fmt + "}/{1}]",
+                "time: {time}",
+                "data: {data}",
+                "model: {model}",
+                "qs: {qs}",
+            ]
+        )
+        dl_iterator = iter(iterable)
+        q = getattr(dl_iterator, "_data_queue", None)
+        for obj in dl_iterator:
             dtime = time.time() - end
             data_time.update(dtime)
+
             yield obj
+
             ttime = time.time() - end
             iter_time.update(ttime)
             model_time.update(ttime - dtime)
             if i % print_freq == 0:
-                eta_seconds = iter_time.global_avg * (len(iterable) - i)
-                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
-                if torch.cuda.is_available():
-                    print(
-                        log_msg.format(
-                            i,
-                            len(iterable),
-                            eta=eta_string,
-                            meters=str(self),
-                            time=str(iter_time),
-                            data=str(data_time),
-                            model=str(model_time),
-                            memory=torch.cuda.max_memory_allocated() / MB,
-                        )
-                    )
-                else:
-                    print(
-                        log_msg.format(
-                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
-                        )
+                print(
+                    log_msg.format(
+                        i,
+                        len(iterable),
+                        time=str(iter_time),
+                        data=str(data_time),
+                        model=str(model_time),
+                        qs=(q.qsize() if q else 0),
                     )
+                )
             i += 1
             end = time.time()
         total_time = time.time() - start_time
"RANK" in os.environ and "WORLD_SIZE" in os.environ: + called_from = "torchrun" args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) args.gpu = int(os.environ["LOCAL_RANK"]) - elif "SLURM_PROCID" in os.environ: - args.rank = int(os.environ["SLURM_PROCID"]) - args.gpu = args.rank % torch.cuda.device_count() - elif hasattr(args, "rank"): - pass else: print("Not using distributed mode") args.distributed = False @@ -266,8 +249,8 @@ def init_distributed_mode(args): backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank ) torch.distributed.barrier() - if args.data_loader.lower() != "ffcv": - setup_for_distributed(args.rank == 0) + setup_for_distributed(args.rank == 0) + print(f"Called from {called_from}") def reduce_across_processes(val):