[Feature] Multithread memmap #592

Merged: 20 commits, Dec 12, 2023
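
At a high level, the PR threads a `num_threads` argument through the `memmap`/`memmap_` machinery so that tensor writes and metadata dumps can be dispatched to a thread pool. A minimal usage sketch, assuming the public API matches what the new benchmarks below exercise (the checkpoint path is a placeholder):

```python
from torch import nn
from tensordict import TensorDict

# Capture a module's parameters and buffers as a TensorDict, then persist
# them as memory-mapped tensors using 32 writer threads.
t = nn.Transformer()
weights = TensorDict.from_module(t)
weights.memmap("/path/to/ckpt", num_threads=32)

# Round-trip: load the memmapped weights back from the same prefix.
loaded = TensorDict.load_memmap("/path/to/ckpt")
```
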
51 changes: 51 additions & 0 deletions benchmarks/common/memmap_benchmarks_test.py
@@ -1,9 +1,11 @@
import argparse
from pathlib import Path

import pytest
import torch

from tensordict import MemmapTensor, TensorDict
from torch import nn


def get_available_devices():
@@ -77,6 +79,55 @@ def test_memmaptd_index_op(benchmark, td_memmap):
)


def test_serialize_model(benchmark, tmpdir):
"""Tests efficiency of saving weights as memmap tensors, including TD construction."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()
benchmark(lambda: TensorDict.from_module(t).memmap(tmpdir, num_threads=32))


def test_serialize_model_filesystem(benchmark):
"""Tests efficiency of saving weights as memmap tensors in file system, including TD construction."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()
benchmark(lambda: TensorDict.from_module(t).memmap(num_threads=32))


def test_serialize_model_pickle(benchmark, tmpdir):
"""Tests efficiency of pickling a model state-dict, including state-dict construction."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()
path = Path(tmpdir) / "file.t"
benchmark(lambda: torch.save(t.state_dict(), path))


def test_serialize_weights(benchmark, tmpdir):
"""Tests efficiency of saving weights as memmap tensors."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()

weights = TensorDict.from_module(t)
benchmark(lambda: weights.memmap(tmpdir, num_threads=32))


def test_serialize_weights_filesystem(benchmark):
"""Tests efficiency of saving weights as memmap tensors."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()

weights = TensorDict.from_module(t)
benchmark(lambda: weights.memmap(num_threads=32))


def test_serialize_weights_pickle(benchmark, tmpdir):
"""Tests efficiency of pickling a model state-dict."""
with torch.device("cuda" if torch.cuda.device_count() else "cpu"):
t = nn.Transformer()
path = Path(tmpdir) / "file.t"
weights = t.state_dict()
benchmark(lambda: torch.save(weights, path))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
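
For a rough sense of the scaling outside pytest-benchmark, a standalone timing sketch along the same lines; the output prefixes and thread counts below are illustrative, not part of the PR:

```python
import time

from torch import nn
from tensordict import TensorDict

t = nn.Transformer()
weights = TensorDict.from_module(t)

# Time the memmap write at a few thread counts; each run gets its own
# prefix so it does not clash with files left by a previous run.
for num_threads in (1, 8, 32):
    start = time.perf_counter()
    weights.memmap(f"/tmp/ckpt_{num_threads}", num_threads=num_threads)
    print(f"num_threads={num_threads}: {time.perf_counter() - start:.3f}s")
```
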
192 changes: 121 additions & 71 deletions tensordict/_lazy.py
@@ -25,6 +25,7 @@
from tensordict.base import (
_ACCEPTED_CLASSES,
_is_tensor_collection,
_register_tensor_class,
BEST_ATTEMPT_INPLACE,
CompatibleType,
is_tensor_collection,
@@ -1522,55 +1523,60 @@ def detach_(self) -> T:
td.detach_()
return self

def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T:
if prefix is not None:
prefix = Path(prefix)
if not prefix.exists():
os.makedirs(prefix, exist_ok=True)
torch.save({"stack_dim": self.stack_dim}, prefix / "meta.pt")
for i, td in enumerate(self.tensordicts):
td.memmap_(
prefix=(prefix / str(i)) if prefix is not None else None,
copy_existing=copy_existing,
)
self._is_memmap = True
self._is_shared = False
self._device = torch.device("cpu")
self.lock_()
return self

def memmap_like(
def _memmap_(
self,
prefix: str | None = None,
copy_existing: bool = False,
executor=None,
futures=None,
inplace=True,
like=False,
) -> T:
tds = []
if prefix is not None:
prefix = Path(prefix)
if not prefix.exists():
os.makedirs(prefix, exist_ok=True)
torch.save({"stack_dim": self.stack_dim}, prefix / "meta.pt")

def save_metadata(prefix=prefix, self=self):
prefix = Path(prefix)
if not prefix.exists():
os.makedirs(prefix, exist_ok=True)
with open(prefix / "meta.json", "w") as f:
json.dump(
{"_type": str(self.__class__), "stack_dim": self.stack_dim}, f
)

if executor is None:
save_metadata()
else:
futures.append(executor.submit(save_metadata))

results = []
for i, td in enumerate(self.tensordicts):
td_like = td.memmap_like(
prefix=(prefix / str(i)) if prefix is not None else None,
results.append(
td._memmap_(
prefix=(prefix / str(i)) if prefix is not None else None,
copy_existing=copy_existing,
executor=executor,
futures=futures,
inplace=inplace,
like=like,
)
)
tds.append(td_like)
td_out = torch.stack(tds, self.stack_dim)
self._is_memmap = True
self._is_shared = False
self._device = torch.device("cpu")
td_out.lock_()
return td_out
if not inplace:
results = torch.stack(results, dim=self.stack_dim)
else:
results = self
results._is_memmap = True
results._is_shared = False
results._device = torch.device("cpu")
return results

@classmethod
def load_memmap(cls, prefix: str) -> LazyStackedTensorDict:
prefix = Path(prefix)
def _load_memmap(cls, prefix: str, metadata: dict) -> LazyStackedTensorDict:
tensordicts = []
i = 0
while (prefix / str(i)).exists():
tensordicts.append(TensorDict.load_memmap(prefix / str(i)))
i += 1

metadata = torch.load(prefix / "meta.pt")
return cls(*tensordicts, stack_dim=metadata["stack_dim"])

def expand(self, *args: int, inplace: bool = False) -> T:
@@ -2009,8 +2015,14 @@ def __init__(
if batch_size is not None and batch_size != self.batch_size:
raise RuntimeError("batch_size does not match self.batch_size.")

# def __hash__(self):
# return hash((self._source, self.custom_op, self.inv_op, self.custom_op_kwargs, self.inv_op_kwargs))
def is_empty(self) -> bool:
return self._source.is_empty()

def is_memmap(self) -> bool:
return self._source.is_memmap()

def is_shared(self) -> bool:
return self._source.is_shared()

def _update_custom_op_kwargs(self, source_tensor: Tensor) -> dict[str, Any]:
"""Allows for a transformation to be customized for a certain shape, device or dtype.
@@ -2296,51 +2308,81 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> T:
td_copy = self.clone()
return td_copy.masked_fill_(mask, value)

def memmap_(
self, prefix: str | None = None, copy_existing: bool = False
) -> _CustomOpTensorDict:
def load_metadata(filepath):
with open(filepath) as json_metadata:
metadata = json.load(json_metadata)
return metadata

def save_metadata(filepath, metadata=None):
def _memmap_(
self,
*,
prefix: str | None,
copy_existing: bool,
executor,
futures,
inplace,
like,
) -> T:
def save_metadata(data: TensorDictBase, filepath, metadata=None):
if metadata is None:
metadata = {}
metadata.update(
{
"shape": list(data.shape),
"device": str(data.device),
"_type": str(data.__class__),
"custom_op": data.custom_op,
"inv_op": data.inv_op,
"custom_op_kwargs": data.custom_op_kwargs,
"inv_op_kwargs": data.inv_op_kwargs,
}
)
with open(filepath, "w") as json_metadata:
json.dump(metadata, json_metadata)

self._source.memmap_(prefix=prefix, copy_existing=copy_existing)
if prefix is not None:

prefix = Path(prefix)
metadata = load_metadata(prefix / "meta.json")
metadata["custom_op"] = self.custom_op
metadata["inv_op"] = self.inv_op
metadata["custom_op_kwargs"] = self.custom_op_kwargs
metadata["inv_op_kwargs"] = self.inv_op_kwargs
save_metadata(prefix / "meta.json", metadata)

self._is_memmap = True
self._is_shared = False
self._device = torch.device("cpu")
self.lock_()
return self

@classmethod
def load_memmap(cls, prefix: str) -> _CustomOpTensorDict:
def load_metadata(filepath):
with open(filepath) as json_metadata:
metadata = json.load(json_metadata)
return metadata
if not prefix.exists():
os.makedirs(prefix, exist_ok=True)
metadata = {}

dest_source = self._source._memmap_(
prefix=None if prefix is None else prefix / "_source",
copy_existing=copy_existing,
executor=executor,
futures=futures,
inplace=inplace,
like=like,
)
if not inplace:
dest = type(self)(
dest_source,
custom_op=self.custom_op,
inv_op=self.inv_op,
custom_op_kwargs=self.custom_op_kwargs,
inv_op_kwargs=self.inv_op_kwargs,
batch_size=self.batch_size,
)
else:
dest = self

prefix = Path(prefix)
if prefix is not None:
if executor is None:
save_metadata(
dest,
prefix / "meta.json",
metadata=metadata,
)
else:
futures.append(
executor.submit(save_metadata, dest, prefix / "meta.json", metadata)
)
return dest

metadata = load_metadata(prefix / "meta.json")
@classmethod
def _load_memmap(cls, prefix: str, metadata: dict) -> _CustomOpTensorDict:
custom_op = metadata.pop("custom_op")
inv_op = metadata.pop("inv_op")
custom_op_kwargs = metadata.pop("custom_op_kwargs")
inv_op_kwargs = metadata.pop("inv_op_kwargs")

source = TensorDict.load_memmap(prefix)
source = TensorDict.load_memmap(prefix / "_source")

return cls(
source,
custom_op=custom_op,
@@ -2419,7 +2461,6 @@ def sorted_keys(self):
all = TensorDict.all
any = TensorDict.any
expand = TensorDict.expand
memmap_like = TensorDict.memmap_like
unbind = TensorDict.unbind
_permute = TensorDict._permute
_transpose = TensorDict._transpose
@@ -2775,3 +2816,12 @@ def _iter_items_lazystack(
else:
raise err
yield key, value


_register_tensor_class(LazyStackedTensorDict)
_register_tensor_class(_CustomOpTensorDict)
_register_tensor_class(_PermutedTensorDict)
_register_tensor_class(_SqueezedTensorDict)
_register_tensor_class(_UnsqueezedTensorDict)
_register_tensor_class(_TransposedTensorDict)
_register_tensor_class(_ViewedTensorDict)
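
Taken together, the refactor replaces the recursive `memmap_`/`memmap_like` pair with a single `_memmap_` that accepts an optional `executor` and a shared `futures` list: recursive calls only schedule work, while the public entry point owns the thread pool and waits for completion. A schematic of that submit-and-collect pattern, using a nested dict as a stand-in for a TensorDict (`write_leaf` and the traversal are illustrative, not the library's actual internals):

```python
from concurrent.futures import Future, ThreadPoolExecutor, wait
from pathlib import Path

import torch

def write_leaf(tensor, filepath: Path) -> None:
    # Stand-in for the real memmap write of a single tensor.
    filepath.parent.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, filepath)

def save_tree(node: dict, prefix: Path, executor, futures: list) -> None:
    # Recurse over the tree: leaves are handed to worker threads,
    # sub-trees only schedule more work.
    for name, child in node.items():
        if isinstance(child, dict):
            save_tree(child, prefix / str(name), executor, futures)
        else:
            futures.append(executor.submit(write_leaf, child, prefix / str(name)))

def memmap_with_threads(root: dict, prefix: Path, num_threads: int = 32) -> None:
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures: list[Future] = []
        save_tree(root, prefix, executor, futures)
        wait(futures)
        for f in futures:
            f.result()  # re-raise any exception from a worker thread
```

The single-threaded path in the PR follows the same shape: it passes `executor=None` and performs each write inline instead of submitting it, as the `if executor is None` branches in the diff show.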