Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark: Add archive support (generation and training) #744

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions benchmarks/make_archives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import argparse
import io
import pickle
import tarfile
from math import ceil
from pathlib import Path

import torch
import torchvision
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument("--input-dir", default="/datasets01_ontap/tinyimagenet/081318/train/")
parser.add_argument("--output-dir", default="./tinyimagenet/081318/train")
parser.add_argument("--archiver", default="pickle", help="pickle or tar or torch")
parser.add_argument(
"--archive-content", default="BytesIo", help="BytesIO or tensor. Only valid for pickle or torch archivers"
)
parser.add_argument("--archive-size", type=int, default=500, help="Number of samples per archive")
parser.add_argument("--shuffle", type=bool, default=True, help="Whether to shuffle the samples within each archive")

# The archive parameter determines whether we use `tar.add`, `pickle.dump` or
# `torch.save` to write an archive. `torch.save` relies on pickle in the backend
# but has a special handling for tensors (which is maybe faster???):
# - tar: each tar file contains files. Each file is the original encoded jpeg
# file. To avoid storing labels in separate files, we write the corresponding
# label in each file name in the archive. This is ugly, but OK at this stage.
# - pickle or torch: in this case, each archive contains a list of tuples. Each
# tuple represents a sample in the form (img_data, label). label is always an
# int, and img_data is the *encoded* jpeg bytes which can be represented
# either as a tensor or a BytesIO object, depending on the archive-content
# parameter.


class Archiver:
def __init__(
self,
input_dir,
output_dir,
archiver="pickle",
archive_content="BytesIO",
archive_size=500,
shuffle=True,
):
self.input_dir = input_dir
self.archiver = archiver.lower()
self.archive_content = archive_content.lower()
self.archive_size = archive_size
self.shuffle = shuffle

self.output_dir = Path(output_dir).resolve()
self.output_dir.mkdir(parents=True, exist_ok=True)

def archive_dataset(self):
def loader(path):
# identity loader to avoid decoding images with PIL or something else
# This means the dataset will always return (path_to_image_file, int_label)
return path

dataset = torchvision.datasets.ImageFolder(self.input_dir, loader=loader)
self.num_samples = len(dataset)

if self.shuffle:
self.indices = torch.randperm(self.num_samples)
else:
self.indices = torch.arange(self.num_samples)

archive_samples = []
for i, idx in enumerate(tqdm(self.indices)):
archive_samples.append(dataset[idx])
if ((i + 1) % self.archive_size == 0) or (i == len(self.indices) - 1):
archive_path = self._get_archive_path(archive_samples, last_idx=i)
{"pickle": self._save_pickle, "torch": self._save_torch, "tar": self._save_tar}[self.archiver](
archive_samples, archive_path
)

archive_samples = []

def _get_archive_path(self, samples, last_idx):
current_archive_number = last_idx // self.archive_size
total_num_archives_needed = ceil(self.num_samples / self.archive_size)
zero_pad_fmt = len(str(total_num_archives_needed))
num_samples_in_archive = len(samples)

archive_content_str = "" if self.archiver == "tar" else f"{self.archive_content}_"
path = (
self.output_dir
/ f"archive_{self.archive_size}_{archive_content_str}{current_archive_number:0{zero_pad_fmt}d}"
)
print(f"Archiving {num_samples_in_archive} samples in {path}")
return path

def _make_content(self, samples):
archive_content = []
for sample_file_name, label in samples:
if self.archive_content == "bytesio":
with open(sample_file_name, "rb") as f:
img_data = io.BytesIO(f.read())
elif self.archive_content == "tensor": # Note: this doesn't decode anything
img_data = torchvision.io.read_file(sample_file_name)
else:
raise ValueError(f"Unsupported {self.archive_content = }")
archive_content.append((img_data, label))
return archive_content

def _save_pickle(self, samples, archive_name):
archive_content = self._make_content(samples)
archive_name = archive_name.with_suffix(".pkl")
with open(archive_name, "wb") as f:
pickle.dump(archive_content, f)

def _save_torch(self, samples, archive_name):
archive_content = self._make_content(samples)
archive_name = archive_name.with_suffix(".pt")
torch.save(archive_content, archive_name)

def _save_tar(self, samples, archive_path):
archive_path = archive_path.with_suffix(".tar")
with tarfile.open(archive_path, "w") as tar:
for sample_file_name, label in samples:
path = Path(sample_file_name)
tar.add(path, arcname=f"{label}/{path.name}")


args = parser.parse_args()
Archiver(
input_dir=args.input_dir,
output_dir=args.output_dir,
archiver=args.archiver,
archive_content=args.archive_content,
shuffle=args.shuffle,
archive_size=args.archive_size,
).archive_dataset()
126 changes: 107 additions & 19 deletions benchmarks/torchvision_classification/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import os
import pickle
import random
from functools import partial
from pathlib import Path
Expand All @@ -8,7 +9,7 @@
import torch.distributed as dist
import torchvision
from PIL import Image
from torchdata.datapipes.iter import FileLister, IterDataPipe
from torchdata.datapipes.iter import FileLister, FileOpener, IterDataPipe, TarArchiveLoader


# TODO: maybe infinite buffer can / is already natively supported by torchdata?
Expand All @@ -20,13 +21,19 @@

class _LenSetter(IterDataPipe):
# TODO: Ideally, we woudn't need this extra class
def __init__(self, dp, root):
def __init__(self, dp, root, args):
self.dp = dp

if "train" in str(root):
self.size = IMAGENET_TRAIN_LEN
if args.tiny:
self.size = 100_000
else:
self.size = IMAGENET_TRAIN_LEN
elif "val" in str(root):
self.size = IMAGENET_TEST_LEN
if args.tiny:
self.size = 10_000
else:
self.size = IMAGENET_TEST_LEN
else:
raise ValueError("oops?")

Expand All @@ -35,36 +42,117 @@ def __iter__(self):

def __len__(self):
# TODO The // world_size part shouldn't be needed. See https://github.com/pytorch/data/issues/533
return self.size // dist.get_world_size()
if dist.is_initialized():
return self.size // dist.get_world_size()
else:
return self.size


def _decode(path, root, category_to_int):
category = Path(path).relative_to(root).parts[0]
def _apply_tranforms(img_and_label, transforms):
img, label = img_and_label
return transforms(img), label

image = Image.open(path).convert("RGB")
label = category_to_int(category)

return image, label
class ArchiveLoader(IterDataPipe):
def __init__(self, source_datapipe, loader):
self.loader = pickle.load if loader == "pickle" else torch.load
self.source_datapipe = source_datapipe

def __iter__(self):
for filename in self.source_datapipe:
with open(filename, "rb") as f:
yield self.loader(f)


def _apply_tranforms(img_and_label, transforms):
img, label = img_and_label
return transforms(img), label
class ConcaterIterable(IterDataPipe):
# TODO: This should probably be a built-in: https://github.com/pytorch/data/issues/648
def __init__(self, source_datapipe):
self.source_datapipe = source_datapipe

def __iter__(self):
for iterable in self.source_datapipe:
yield from iterable


def _decode_path(data, root, category_to_int):
path = data
category = Path(path).relative_to(root).parts[0]
image = Image.open(path).convert("RGB")
label = category_to_int[category]
return image, label

def make_dp(root, transforms):

def _make_dp_from_image_folder(root):
root = Path(root).expanduser().resolve()
categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
category_to_int = {category: i for (i, category) in enumerate(categories)}

dp = FileLister(str(root), recursive=True, masks=["*.JPEG"])

dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
dp = dp.map(partial(_decode, root=root, category_to_int=category_to_int))
dp = dp.map(partial(_apply_tranforms, transforms=transforms))
dp = dp.map(partial(_decode_path, root=root, category_to_int=category_to_int))
return dp


def _decode_bytesio(data):
image, label = data
image = Image.open(image).convert("RGB")
return image, label


def _decode_tensor(data):
image, label = data
image = torchvision.io.decode_jpeg(image, mode=torchvision.io.ImageReadMode.RGB)
return image, label

dp = _LenSetter(dp, root=root)

def _make_dp_from_archive(root, args):
ext = "pt" if args.archive == "torch" else "pkl"
dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*{args.archive_content}*.{ext}"])
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) # inter-archive shuffling
dp = ArchiveLoader(dp, loader=args.archive)
dp = ConcaterIterable(dp)
dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False) # intra-archive shuffling

# TODO: we're sharding here but the big BytesIO or Tensors have already been
# loaded by all workers, possibly in vain. Hopefully the new experimental MP
# reading service will improve this?
dp = dp.sharding_filter()
decode = {"bytesio": _decode_bytesio, "tensor": _decode_tensor}[args.archive_content]
return dp.map(decode)


def _decode_tar_entry(data):
# Note on how we retrieve the label: each file name in the archive (the
# "arcnames" as from the tarfile docs) looks like "label/some_name.jpg".
# It's somewhat hacky and will obviously change, but it's OK for now.
filename, io_stream = data
label = int(Path(filename).parent.name)
image = Image.open(io_stream).convert("RGB")
return image, label


def _make_dp_from_tars(root, args):

dp = FileLister(str(root), masks=[f"archive_{args.archive_size}*.tar"])
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) # inter-archive shuffling
dp = FileOpener(dp, mode="b")
dp = TarArchiveLoader(dp)
dp = dp.shuffle(buffer_size=args.archive_size).set_shuffle(False) # intra-archive shuffling
dp = dp.sharding_filter()
return dp.map(_decode_tar_entry)


def make_dp(root, transforms, args):
if args.archive in ("pickle", "torch"):
dp = _make_dp_from_archive(root, args)
elif args.archive == "tar":
dp = _make_dp_from_tars(root, args)
else:
dp = _make_dp_from_image_folder(root)

dp = dp.map(partial(_apply_tranforms, transforms=transforms))
dp = _LenSetter(dp, root=root, args=args)
return dp


Expand Down Expand Up @@ -97,10 +185,10 @@ def __iter__(self):
yield self.samples[idx % len(self.samples)]


def make_pre_loaded_dp(root, transforms):
def make_pre_loaded_dp(root, transforms, args):
dp = _PreLoadedDP(root=root, transforms=transforms)
dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False).sharding_filter()
dp = _LenSetter(dp, root=root)
dp = _LenSetter(dp, root=root, args=args)
return dp


Expand Down
6 changes: 4 additions & 2 deletions benchmarks/torchvision_classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class ClassificationPresetTrain:
def __init__(
self,
*,
on_pil_images,
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
Expand All @@ -17,7 +18,7 @@ def __init__(

trans.extend(
[
transforms.PILToTensor(),
transforms.PILToTensor() if on_pil_images else torch.nn.Identity(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
Expand All @@ -32,6 +33,7 @@ def __call__(self, img):
class ClassificationPresetEval:
def __init__(
self,
on_pil_images,
*,
crop_size,
resize_size=256,
Expand All @@ -43,7 +45,7 @@ def __init__(
[
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size),
transforms.PILToTensor(),
transforms.PILToTensor() if on_pil_images else torch.nn.Identity(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
Expand Down
Loading