diff --git a/benchmarks/common/memmap_benchmarks_test.py b/benchmarks/common/memmap_benchmarks_test.py index 9dc388039..a3a819f13 100644 --- a/benchmarks/common/memmap_benchmarks_test.py +++ b/benchmarks/common/memmap_benchmarks_test.py @@ -1,9 +1,11 @@ import argparse +from pathlib import Path import pytest import torch from tensordict import MemmapTensor, TensorDict +from torch import nn def get_available_devices(): @@ -77,6 +79,55 @@ def test_memmaptd_index_op(benchmark, td_memmap): ) +def test_serialize_model(benchmark, tmpdir): + """Tests efficiency of saving weights as memmap tensors, including TD construction.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + benchmark(lambda: TensorDict.from_module(t).memmap(tmpdir, num_threads=32)) + + +def test_serialize_model_filesystem(benchmark): + """Tests efficiency of saving weights as memmap tensors in file system, including TD construction.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + benchmark(lambda: TensorDict.from_module(t).memmap(num_threads=32)) + + +def test_serialize_model_pickle(benchmark, tmpdir): + """Tests efficiency of pickling a model state-dict, including state-dict construction.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + path = Path(tmpdir) / "file.t" + benchmark(lambda: torch.save(t.state_dict(), path)) + + +def test_serialize_weights(benchmark, tmpdir): + """Tests efficiency of saving weights as memmap tensors.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + + weights = TensorDict.from_module(t) + benchmark(lambda: weights.memmap(tmpdir, num_threads=32)) + + +def test_serialize_weights_filesystem(benchmark): + """Tests efficiency of saving weights as memmap tensors.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + + weights = TensorDict.from_module(t) + benchmark(lambda: weights.memmap(num_threads=32)) + + +def test_serialize_weights_pickle(benchmark, tmpdir): + """Tests efficiency of pickling a model state-dict.""" + with torch.device("cuda" if torch.cuda.device_count() else "cpu"): + t = nn.Transformer() + path = Path(tmpdir) / "file.t" + weights = t.state_dict() + benchmark(lambda: torch.save(weights, path)) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 131790a9d..065400e6f 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -25,6 +25,7 @@ from tensordict.base import ( _ACCEPTED_CLASSES, _is_tensor_collection, + _register_tensor_class, BEST_ATTEMPT_INPLACE, CompatibleType, is_tensor_collection, @@ -1522,55 +1523,60 @@ def detach_(self) -> T: td.detach_() return self - def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T: - if prefix is not None: - prefix = Path(prefix) - if not prefix.exists(): - os.makedirs(prefix, exist_ok=True) - torch.save({"stack_dim": self.stack_dim}, prefix / "meta.pt") - for i, td in enumerate(self.tensordicts): - td.memmap_( - prefix=(prefix / str(i)) if prefix is not None else None, - copy_existing=copy_existing, - ) - self._is_memmap = True - self._is_shared = False - self._device = torch.device("cpu") - self.lock_() - return self - - def memmap_like( + def _memmap_( self, prefix: str | None = None, + copy_existing: bool = False, + 
executor=None, + futures=None, + inplace=True, + like=False, ) -> T: - tds = [] if prefix is not None: - prefix = Path(prefix) - if not prefix.exists(): - os.makedirs(prefix, exist_ok=True) - torch.save({"stack_dim": self.stack_dim}, prefix / "meta.pt") + + def save_metadata(prefix=prefix, self=self): + prefix = Path(prefix) + if not prefix.exists(): + os.makedirs(prefix, exist_ok=True) + with open(prefix / "meta.json", "w") as f: + json.dump( + {"_type": str(self.__class__), "stack_dim": self.stack_dim}, f + ) + + if executor is None: + save_metadata() + else: + futures.append(executor.submit(save_metadata)) + + results = [] for i, td in enumerate(self.tensordicts): - td_like = td.memmap_like( - prefix=(prefix / str(i)) if prefix is not None else None, + results.append( + td._memmap_( + prefix=(prefix / str(i)) if prefix is not None else None, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + ) ) - tds.append(td_like) - td_out = torch.stack(tds, self.stack_dim) - self._is_memmap = True - self._is_shared = False - self._device = torch.device("cpu") - td_out.lock_() - return td_out + if not inplace: + results = torch.stack(results, dim=self.stack_dim) + else: + results = self + results._is_memmap = True + results._is_shared = False + results._device = torch.device("cpu") + return results @classmethod - def load_memmap(cls, prefix: str) -> LazyStackedTensorDict: - prefix = Path(prefix) + def _load_memmap(cls, prefix: str, metadata: dict) -> LazyStackedTensorDict: tensordicts = [] i = 0 while (prefix / str(i)).exists(): tensordicts.append(TensorDict.load_memmap(prefix / str(i))) i += 1 - metadata = torch.load(prefix / "meta.pt") return cls(*tensordicts, stack_dim=metadata["stack_dim"]) def expand(self, *args: int, inplace: bool = False) -> T: @@ -2009,8 +2015,14 @@ def __init__( if batch_size is not None and batch_size != self.batch_size: raise RuntimeError("batch_size does not match self.batch_size.") - # def __hash__(self): - # return hash((self._source, self.custom_op, self.inv_op, self.custom_op_kwargs, self.inv_op_kwargs)) + def is_empty(self) -> bool: + return self._source.is_empty() + + def is_memmap(self) -> bool: + return self._source.is_memmap() + + def is_shared(self) -> bool: + return self._source.is_shared() def _update_custom_op_kwargs(self, source_tensor: Tensor) -> dict[str, Any]: """Allows for a transformation to be customized for a certain shape, device or dtype. 
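A note on the threading convention introduced here: the public entry points own a `ThreadPoolExecutor`, and every `_memmap_` node either runs its write inline when `executor is None` or submits it and appends the future to the shared `futures` list, so the caller can block on all writes at once. A minimal standalone sketch of this convention — `save_leaf` and `walk` are illustrative helpers, not tensordict API:

```python
import os
from concurrent.futures import ThreadPoolExecutor

def save_leaf(path, payload):
    with open(path, "w") as f:
        f.write(payload)

def walk(tree, prefix, executor=None, futures=None):
    os.makedirs(prefix, exist_ok=True)      # mirrors the makedirs in save_metadata above
    for key, value in tree.items():
        if isinstance(value, dict):         # recurse into sub-trees with the same executor/futures
            walk(value, f"{prefix}/{key}", executor, futures)
        elif executor is None:
            save_leaf(f"{prefix}/{key}.txt", value)   # sequential path
        else:                               # parallel path: submit and record the future
            futures.append(executor.submit(save_leaf, f"{prefix}/{key}.txt", value))

data = {"a": "1", "sub": {"b": "2"}}
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = []
    walk(data, "/tmp/demo", executor, futures)
    for fut in futures:
        fut.result()    # waiting re-raises any exception from a worker thread
```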
@@ -2296,51 +2308,81 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> T: td_copy = self.clone() return td_copy.masked_fill_(mask, value) - def memmap_( - self, prefix: str | None = None, copy_existing: bool = False - ) -> _CustomOpTensorDict: - def load_metadata(filepath): - with open(filepath) as json_metadata: - metadata = json.load(json_metadata) - return metadata - - def save_metadata(filepath, metadata=None): + def _memmap_( + self, + *, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, + ) -> T: + def save_metadata(data: TensorDictBase, filepath, metadata=None): + if metadata is None: + metadata = {} + metadata.update( + { + "shape": list(data.shape), + "device": str(data.device), + "_type": str(data.__class__), + "custom_op": data.custom_op, + "inv_op": data.inv_op, + "custom_op_kwargs": data.custom_op_kwargs, + "inv_op_kwargs": data.inv_op_kwargs, + } + ) with open(filepath, "w") as json_metadata: json.dump(metadata, json_metadata) - self._source.memmap_(prefix=prefix, copy_existing=copy_existing) if prefix is not None: - prefix = Path(prefix) - metadata = load_metadata(prefix / "meta.json") - metadata["custom_op"] = self.custom_op - metadata["inv_op"] = self.inv_op - metadata["custom_op_kwargs"] = self.custom_op_kwargs - metadata["inv_op_kwargs"] = self.inv_op_kwargs - save_metadata(prefix / "meta.json", metadata) - - self._is_memmap = True - self._is_shared = False - self._device = torch.device("cpu") - self.lock_() - return self - - @classmethod - def load_memmap(cls, prefix: str) -> _CustomOpTensorDict: - def load_metadata(filepath): - with open(filepath) as json_metadata: - metadata = json.load(json_metadata) - return metadata + if not prefix.exists(): + os.makedirs(prefix, exist_ok=True) + metadata = {} + + dest_source = self._source._memmap_( + prefix=None if prefix is None else prefix / "_source", + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + ) + if not inplace: + dest = type(self)( + dest_source, + custom_op=self.custom_op, + inv_op=self.inv_op, + custom_op_kwargs=self.custom_op_kwargs, + inv_op_kwargs=self.inv_op_kwargs, + batch_size=self.batch_size, + ) + else: + dest = self - prefix = Path(prefix) + if prefix is not None: + if executor is None: + save_metadata( + dest, + prefix / "meta.json", + metadata=metadata, + ) + else: + futures.append( + executor.submit(save_metadata, dest, prefix / "meta.json", metadata) + ) + return dest - metadata = load_metadata(prefix / "meta.json") + @classmethod + def _load_memmap(cls, prefix: str, metadata: dict) -> _CustomOpTensorDict: custom_op = metadata.pop("custom_op") inv_op = metadata.pop("inv_op") custom_op_kwargs = metadata.pop("custom_op_kwargs") inv_op_kwargs = metadata.pop("inv_op_kwargs") - source = TensorDict.load_memmap(prefix) + source = TensorDict.load_memmap(prefix / "_source") + return cls( source, custom_op=custom_op, @@ -2419,7 +2461,6 @@ def sorted_keys(self): all = TensorDict.all any = TensorDict.any expand = TensorDict.expand - memmap_like = TensorDict.memmap_like unbind = TensorDict.unbind _permute = TensorDict._permute _transpose = TensorDict._transpose @@ -2775,3 +2816,12 @@ def _iter_items_lazystack( else: raise err yield key, value + + +_register_tensor_class(LazyStackedTensorDict) +_register_tensor_class(_CustomOpTensorDict) +_register_tensor_class(_PermutedTensorDict) +_register_tensor_class(_SqueezedTensorDict) +_register_tensor_class(_UnsqueezedTensorDict) 
+_register_tensor_class(_TransposedTensorDict) +_register_tensor_class(_ViewedTensorDict) diff --git a/tensordict/_td.py b/tensordict/_td.py index c2f493274..7171807ab 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -810,80 +810,6 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBas for ss, bs in zip(split_sizes, batch_sizes) ) - def memmap_like(self, prefix: str | None = None) -> T: - def save_metadata(data: TensorDictBase, filepath, metadata=None): - if metadata is None: - metadata = {} - metadata.update( - { - "shape": list(data.shape), - "device": str(data.device), - "_type": str(data.__class__), - } - ) - with open(filepath, "w") as json_metadata: - json.dump(metadata, json_metadata) - - if prefix is not None: - prefix = Path(prefix) - if not prefix.exists(): - os.makedirs(prefix, exist_ok=True) - metadata = {} - if not self.keys(): - raise Exception( - "memmap_like() must be called when the TensorDict is (partially) " - "populated. Set a tensor first." - ) - tensordict = TensorDict( - {}, - self.batch_size, - device=self.device, - names=self.names if self._has_names() else None, - ) - for key, value in self.items(): - if _is_tensor_collection(value.__class__): - if prefix is not None: - # ensure subdirectory exists - os.makedirs(prefix / key, exist_ok=True) - tensordict._set_str( - key, - value.memmap_like( - prefix=prefix / key, - ), - inplace=False, - validated=True, - ) - else: - tensordict._set_str( - key, value.memmap_like(), inplace=False, validated=True - ) - continue - else: - if prefix is not None: - metadata[key] = { - "dtype": str(value.dtype), - "shape": value.shape, - "device": str(value.device), - } - tensordict._set_str( - key, - MemoryMappedTensor.empty_like( - value, - filename=str(prefix / f"{key}.memmap") - if prefix is not None - else None, - ), - inplace=False, - validated=True, - ) - tensordict._is_memmap = True - tensordict._is_shared = False - tensordict._device = torch.device("cpu") - tensordict.lock_() - if prefix is not None: - save_metadata(self, filepath=prefix / "meta.json", metadata=metadata) - return tensordict - def masked_select(self, mask: Tensor) -> T: d = {} mask_expand = mask @@ -1473,10 +1399,14 @@ def detach_(self) -> T: value.detach_() return self - def memmap_( + def _memmap_( self, - prefix: str | None = None, - copy_existing: bool = False, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, ) -> T: def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: @@ -1496,35 +1426,51 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): if not prefix.exists(): os.makedirs(prefix, exist_ok=True) metadata = {} - if self.is_shared() and self.device.type == "cpu": + if inplace and self.is_shared() and self.device.type == "cpu": raise RuntimeError( "memmap and shared memory are mutually exclusive features." 
) + dest = ( + self + if inplace + else TensorDict( + {}, + batch_size=self.batch_size, + _is_memmap=True, + _is_shared=False, + names=self.names if self._has_names() else None, + device=torch.device("cpu"), + ) + ) for key, value in self.items(): - if value.requires_grad: - raise Exception( - "memmap is not compatible with gradients, one of Tensors has requires_grad equals True" - ) if _is_tensor_collection(value.__class__): - if prefix is not None: - # ensure subdirectory exists - os.makedirs(prefix / key, exist_ok=True) - self._tensordict[key] = value.memmap_( - prefix=prefix / key, copy_existing=copy_existing - ) - else: - self._tensordict[key] = value.memmap_() + dest._tensordict[key] = value._memmap_( + prefix=prefix / key if prefix is not None else None, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + ) continue else: # user did specify location and memmap is in wrong place, so we copy - self._tensordict[key] = MemoryMappedTensor.from_tensor( - value, - filename=str(prefix / f"{key}.memmap") - if prefix is not None - else None, - copy_existing=copy_existing, - # copy_existing=copy_existing, - ) + def _populate( + dest=dest, value=value, key=key, copy_existing=copy_existing + ): + filename = None if prefix is None else str(prefix / f"{key}.memmap") + dest._tensordict[key] = MemoryMappedTensor.from_tensor( + value.data if value.requires_grad else value, + filename=filename, + copy_existing=copy_existing, + existsok=True, + copy_data=not like, + ) + + if executor is None: + _populate() + else: + futures.append(executor.submit(_populate)) if prefix is not None: metadata[key] = { "device": str(value.device), @@ -1533,48 +1479,31 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): } if prefix is not None: - save_metadata( - self, - prefix / "meta.json", - metadata=metadata, - ) - self._is_memmap = True - self._is_shared = False # since they are mutually exclusive - self._device = torch.device("cpu") - self.lock_() - return self - - @classmethod - def load_memmap(cls, prefix: str) -> T: - prefix = Path(prefix) - - def load_metadata(filepath): - with open(filepath) as json_metadata: - metadata = json.load(json_metadata) - if metadata["_type"] != str(cls): - # return early to load from another cls - return metadata - if metadata["device"] == "None": - metadata["device"] = None - else: - metadata["device"] = torch.device(metadata["device"]) - metadata["shape"] = torch.Size(metadata["shape"]) - # if 'dtype' in metadata: - # metadata.dtype = to - return metadata - - metadata = load_metadata(prefix / "meta.json") - type_name = metadata["_type"] - if type_name != str(cls): - import tensordict - - for other_cls in tensordict.base._ACCEPTED_CLASSES: - if str(other_cls) == type_name: - return other_cls.load_memmap(prefix) + if executor is None: + save_metadata( + dest, + prefix / "meta.json", + metadata=metadata, + ) else: - raise RuntimeError( - f"Could not find name {type_name} in {tensordict.base._ACCEPTED_CLASSES}." 
+ futures.append( + executor.submit(save_metadata, dest, prefix / "meta.json", metadata) ) + if inplace: + self._is_memmap = True + self._is_shared = False # since they are mutually exclusive + self._device = torch.device("cpu") + dest._is_locked = True + return dest + + @classmethod + def _load_memmap(cls, prefix: str, metadata: dict) -> T: + if metadata["device"] == "None": + metadata["device"] = None + else: + metadata["device"] = torch.device(metadata["device"]) + metadata["shape"] = torch.Size(metadata["shape"]) + out = cls({}, batch_size=metadata.pop("shape"), device=metadata.pop("device")) for key, entry_metadata in metadata.items(): @@ -1590,14 +1519,15 @@ def load_metadata(filepath): ): # invalid dict means continue - out.set( + out._set_str( key, MemoryMappedTensor.from_filename( dtype=_STRDTYPE2DTYPE[dtype], shape=torch.Size(entry_metadata["shape"]), filename=str(prefix / f"{key}.memmap"), - # device=str(entry_metadata["device"]), ), + validated=True, + inplace=False, ) # iterate over folders and load them for path in prefix.iterdir(): @@ -2373,11 +2303,66 @@ def masked_fill(self, mask: Tensor, value: float | bool) -> T: td_copy = self.clone() return td_copy.masked_fill_(mask, value) - def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T: + def memmap_( + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, + ) -> T: raise RuntimeError( "Converting a sub-tensordict values to memmap cannot be done." ) + def _memmap_( + self, + *, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, + ) -> T: + if prefix is not None: + + def save_metadata(prefix=prefix, self=self): + prefix = Path(prefix) + if not prefix.exists(): + os.makedirs(prefix, exist_ok=True) + with open(prefix / "meta.json", "w") as f: + json.dump( + { + "_type": str(self.__class__), + "index": _index_to_str(self.idx), + }, + f, + ) + + if executor is None: + save_metadata() + else: + futures.append(executor.submit(save_metadata)) + + _source = self._source._memmap_( + prefix=prefix / "_source" if prefix is not None else None, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + ) + if not inplace: + result = _SubTensorDict(_source, idx=self.idx) + else: + result = self + return result + + def _load_memmap(cls, prefix: Path, metadata: dict): + index = metadata["index"] + return _SubTensorDict( + TensorDict.load_memmap(prefix / "_source"), _str_to_index(index) + ) + def share_memory_(self) -> T: raise RuntimeError( "Casting a sub-tensordict values to shared memory cannot be done." 
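Putting the pieces above together, a plain `TensorDict` saved under a prefix becomes one `meta.json` (shape, device, `_type`, plus per-leaf dtype/shape/device) next to one `<key>.memmap` file per leaf, with one subdirectory per nested tensordict. A round-trip sketch of the resulting public API, with an illustrative path:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3, 4), "nested": {"b": torch.ones(3, 4, 5)}},
    batch_size=[3],
)
td.memmap_(prefix="/tmp/td", num_threads=2)
# expected layout, per the metadata code above:
#   /tmp/td/meta.json          shape/device/_type + per-leaf dtype/shape/device
#   /tmp/td/a.memmap
#   /tmp/td/nested/meta.json
#   /tmp/td/nested/b.memmap
loaded = TensorDict.load_memmap("/tmp/td")
assert (loaded["nested", "b"] == 1).all()
```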
@@ -2655,6 +2640,35 @@ def __call__(self, *args, **kwargs): return instance +def _index_to_str(index): + if isinstance(index, tuple): + return tuple(_index_to_str(elt) for elt in index) + if isinstance(index, slice): + return ("slice", {"start": index.start, "stop": index.stop, "step": index.step}) + if isinstance(index, range): + return ("range", {"start": index.start, "stop": index.stop, "step": index.step}) + if isinstance(index, Tensor): + return ("tensor", index.tolist(), str(index.device)) + return index + + +def _str_to_index(index): + if isinstance(index, tuple): + if not len(index): + return index + if index[0] == "slice": + index = index[1] + return slice(index["start"], index["stop"], index["step"]) + if index[0] == "range": + index = index[1] + return range(index["start"], index["stop"], index["step"]) + if index[0] == "tensor": + index, device = index[1:] + return torch.tensor(index, device=device) + return tuple(_index_to_str(elt) for elt in index) + return index + + class SubTensorDict(_SubTensorDict, metaclass=_subtd_meta_deprec): """Deprecated public version of _SubTensorDict class. @@ -2662,3 +2676,7 @@ class SubTensorDict(_SubTensorDict, metaclass=_subtd_meta_deprec): """ ... + + +_register_tensor_class(TensorDict) +_register_tensor_class(_SubTensorDict) diff --git a/tensordict/base.py b/tensordict/base.py index 99406c957..33a64c56b 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7,11 +7,13 @@ import abc import collections +import json import numbers import warnings import weakref from collections.abc import MutableMapping from copy import copy +from pathlib import Path from textwrap import indent from typing import ( Any, @@ -1525,8 +1527,25 @@ def share_memory_(self) -> T: ... @abc.abstractmethod - def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T: - """Writes all tensors onto a corresponding memory-mapped Tensor. + def _memmap_( + self, + *, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, + ) -> T: + ... + + def memmap_( + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, + ) -> T: + """Writes all tensors onto a corresponding memory-mapped Tensor, in-place. Args: prefix (str): directory prefix where the memory-mapped tensors will @@ -1550,10 +1569,92 @@ def memmap_(self, prefix: str | None = None, copy_existing: bool = False) -> T: Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop. """ - ... + if num_threads > 1: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + result = self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=True, + like=False, + ) + for future in futures: + future.result() + return result + return self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + inplace=True, + futures=None, + executor=None, + like=False, + ).lock_() + + def memmap( + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, + ) -> T: + """Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict. - @abc.abstractmethod - def memmap_like(self, prefix: str | None = None) -> T: + Args: + prefix (str): directory prefix where the memory-mapped tensors will + be stored. The directory tree structure will mimic the tensordict's. 
+ copy_existing (bool): If False (default), an exception will be raised if an + entry in the tensordict is already a tensor stored on disk + with an associated file, but is not saved in the correct + location according to prefix. + If ``True``, any existing Tensor will be copied to the new location. + + The TensorDict is then locked, meaning that any writing operations that + isn't in-place will throw an exception (eg, rename, set or remove an + entry). + Once the tensordict is unlocked, the memory-mapped attribute is turned to ``False``, + because cross-process identity is not guaranteed anymore. + + Returns: + A new tensordict with the tensors stored on disk. + + Note: + Serialising in this fashion might be slow with deeply nested tensordicts, so + it is not recommended to call this method inside a training loop. + """ + if num_threads > 1: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + result = self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=False, + like=False, + ) + for future in futures: + future.result() + return result + return self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + inplace=False, + executor=None, + like=False, + futures=None, + ).lock_() + + def memmap_like( + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, + ) -> T: """Creates a contentless Memory-mapped tensordict with the same shapes as the original one. Args: @@ -1581,6 +1682,57 @@ def memmap_like(self, prefix: str | None = None) -> T: >>> buffer = td.memmap_like("/path/to/dataset") """ + if num_threads > 1: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [] + result = self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + executor=executor, + futures=futures, + inplace=False, + like=True, + ) + for future in futures: + future.result() + return result + return self._memmap_( + prefix=prefix, + copy_existing=copy_existing, + inplace=False, + like=True, + executor=None, + futures=None, + ).lock_() + + @classmethod + def load_memmap(cls, prefix: str | Path) -> T: + prefix = Path(prefix) + + def load_metadata(filepath): + with open(filepath) as json_metadata: + metadata = json.load(json_metadata) + return metadata + + metadata = load_metadata(prefix / "meta.json") + type_name = metadata["_type"] + if type_name != str(cls): + import tensordict + + for other_cls in tensordict.base._ACCEPTED_CLASSES: + if str(other_cls) == type_name: + return other_cls._load_memmap(prefix, metadata) + else: + raise RuntimeError( + f"Could not find name {type_name} in {tensordict.base._ACCEPTED_CLASSES}. Did you call _register_tensor_class(cls) on {type_name}?" + ) + return cls._load_memmap(prefix, metadata) + + @classmethod + @abc.abstractmethod + def _load_memmap(cls, prefix: Path, metadata: dict): ... 
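`load_memmap` thus becomes the single public entry point: it reads `meta.json`, matches the stored `_type` string against every registered class, and delegates to that class's `_load_memmap`. A toy reduction of this dispatch, with a local registry standing in for `_ACCEPTED_CLASSES`:

```python
import json
from pathlib import Path

_REGISTRY = []  # toy stand-in for tensordict.base._ACCEPTED_CLASSES

def _register_tensor_class(cls):
    _REGISTRY.append(cls)
    return cls

def load_memmap(prefix):
    prefix = Path(prefix)
    with open(prefix / "meta.json") as f:
        metadata = json.load(f)
    type_name = metadata["_type"]
    for cls in _REGISTRY:
        if str(cls) == type_name:   # classes are keyed by their str() representation
            return cls._load_memmap(prefix, metadata)
    raise RuntimeError(f"Could not find {type_name}; was _register_tensor_class called?")
```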
# Key functionality: set, get, set_, set_at_, update, update_ diff --git a/tensordict/memmap.py b/tensordict/memmap.py index 44464ad02..5d87e75c6 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -133,7 +133,7 @@ def from_tensor( """ if isinstance(input, MemoryMappedTensor): - if filename is None or ( + if (filename is None and input._filename is None) or ( input._filename is not None and filename is not None and Path(filename).absolute() == Path(input.filename).absolute() @@ -174,7 +174,7 @@ def from_tensor( size = torch.iinfo(input.dtype).bits // 8 * shape.numel() handler = _FileHandler(size) out = torch.frombuffer(memoryview(handler.buffer), dtype=input.dtype) - out = torch.reshape(out, shape) + out = out.view(shape) out = cls(out) else: handler = None diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 925b5ebc0..96bdf7ede 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -15,12 +15,14 @@ import torch from functorch import dim as ftdim + from tensordict._lazy import _CustomOpTensorDict, LazyStackedTensorDict from tensordict._td import _SubTensorDict, TensorDict from tensordict._torch_func import TD_HANDLED_FUNCTIONS from tensordict.base import ( _is_tensor_collection, + _register_tensor_class, CompatibleType, NO_DEFAULT, T, @@ -695,9 +697,18 @@ def masked_fill_(self, *args, **kwargs): ... def memmap_( - self, prefix: str | None = None, copy_existing: bool = False + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, ) -> TensorDictBase: - raise RuntimeError("Cannot build a memmap TensorDict in-place.") + raise RuntimeError( + "Cannot build a memmap TensorDict in-place. Use memmap or memmap_like instead." + ) + + _memmap_ = TensorDict._memmap_ + + _load_memmap = TensorDict._load_memmap @_fallback_property def names(self): @@ -849,7 +860,12 @@ def masked_select(self, mask: Tensor) -> T: ... @_fallback - def memmap_like(self, prefix: str | None = None) -> T: + def memmap_like( + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, + ) -> T: ... @_fallback @@ -1067,3 +1083,6 @@ def _empty_like(td: TensorDictBase, *args, **kwargs) -> TensorDictBase: "Consider calling tensordict.to_tensordict() first." ) from err return tdclone.data.apply_(lambda x: torch.empty_like(x, *args, **kwargs)) + + +_register_tensor_class(TensorDictParams) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index e8aaf25fe..a098c2d0c 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -518,24 +518,26 @@ def masked_fill_(self, mask, value): return self def memmap_( - self, prefix: str | None = None, copy_existing: bool = False + self, + prefix: str | None = None, + copy_existing: bool = False, + num_threads: int = 0, ) -> PersistentTensorDict: raise RuntimeError( "Cannot build a memmap TensorDict in-place from a PersistentTensorDict. Use `td.memmap()` instead." 
) - def memmap( + def _memmap_( self, - prefix: str | None = None, - ) -> TensorDict: - """Converts the PersistentTensorDict to a memmap equivalent.""" - mm_like = self.memmap_like(prefix) - for key in self.keys(include_nested=True, leaves_only=True): - mm_val = mm_like[key] - mm_val._memmap_array[:] = self._get_array(key) - return mm_like + *, + prefix: str | None, + copy_existing: bool, + executor, + futures, + inplace, + like, + ) -> T: - def memmap_like(self, prefix: str | None = None) -> T: # re-implements this to make it faster using the meta-data def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: @@ -544,6 +546,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): { "shape": list(data.shape), "device": str(data.device), + "_type": str(self.__class__), } ) with open(filepath, "w") as json_metadata: @@ -559,58 +562,78 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): "memmap_like() must be called when the TensorDict is (partially) " "populated. Set a tensor first." ) - tensordict = TensorDict( - {}, - self.batch_size, - device=self.device, - names=self.names if self._has_names() else None, + dest = ( + self + if inplace + else TensorDict( + {}, + batch_size=self.batch_size, + _is_memmap=True, + _is_shared=False, + names=self.names if self._has_names() else None, + device=torch.device("cpu"), + ) ) for key, value in self._items_metadata(): if not value["array"]: - value = self.get(key) - if prefix is not None: - # ensure subdirectory exists - os.makedirs(prefix / key, exist_ok=True) - tensordict._set_str( - key, - value.memmap_like( - prefix=prefix / key, - ), - inplace=False, - validated=True, - ) - else: - tensordict._set_str( - key, value.memmap_like(), inplace=False, validated=True - ) + value = self._get_str(key) + dest._set_str( + key, + value._memmap_( + prefix=prefix / key if prefix is not None else None, + executor=executor, + like=like, + copy_existing=copy_existing, + futures=futures, + inplace=inplace, + ), + inplace=False, + validated=True, + ) continue else: + value = self._get_str(key) if prefix is not None: metadata[key] = { "dtype": str(value.dtype), "shape": value.shape, "device": str(value.device), } - tensordict._set_str( - key, - MemoryMappedTensor.empty( - value["shape"], - # device="cpu", - dtype=value["dtype"], + + def _populate( + tensordict=dest, key=key, value=value, prefix=prefix, like=like + ): + val = MemoryMappedTensor.from_tensor( + value, filename=str(prefix / f"{key}.memmap") if prefix is not None else None, - ), - inplace=False, - validated=True, - ) - tensordict._is_memmap = True - tensordict._is_shared = False - tensordict._device = torch.device("cpu") - tensordict.lock_() + copy_data=not like, + copy_existing=copy_existing, + existsok=True, + ) + tensordict._set_str( + key, + val, + inplace=False, + validated=True, + ) + + if executor is None: + _populate() + else: + futures.append(executor.submit(_populate)) + if prefix is not None: - save_metadata(self, prefix=prefix, metadata=metadata) - return tensordict + if executor is None: + save_metadata(dest, prefix / "meta.json", metadata) + else: + futures.append( + executor.submit(save_metadata, dest, prefix / "meta.json", metadata) + ) + return dest + + _load_memmap = TensorDict._load_memmap def pin_memory(self): """Returns a new PersistentTensorDict where any given Tensor key returns a tensor with pin_memory=True. 
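Because the leaf writers above pass `copy_data=not like`, `memmap_like` allocates the on-disk structure without copying any data, which makes it usable as a preallocated buffer to be filled in-place later. A usage sketch, assuming `/tmp/dataset` is writable (initial contents of the buffer are unspecified):

```python
import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.randn(64, 84, 84)}, batch_size=[64])
buffer = td.memmap_like("/tmp/dataset", num_threads=4)  # same keys/shapes/dtypes, no data copied
buffer[0:8] = td[0:8]   # in-place writes land directly in the memory-mapped files
assert buffer.is_memmap() and buffer is not td
```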
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 88a3012c5..7c059308f 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -10,6 +10,7 @@ import inspect import json import numbers +import os import re import sys import warnings @@ -185,9 +186,14 @@ def __torch_function__( cls.unbind = _unbind cls.state_dict = _state_dict cls.load_state_dict = _load_state_dict - cls.memmap_ = _memmap - cls.memmap_like = _memmap_like - cls.load_memmap = classmethod(_load_memmap) + cls._memmap_ = _memmap_ + + # Memmap + cls.memmap_like = TensorDictBase.memmap_like + cls.memmap_ = TensorDictBase.memmap_ + cls.memmap = TensorDictBase.memmap + cls._load_memmap = classmethod(_load_memmap) + cls.load_memmap = TensorDictBase.load_memmap for attr in TensorDict.__dict__.keys(): func = getattr(TensorDict, attr) @@ -340,56 +346,60 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 return wrapper -def _memmap(self, prefix: str | None = None, copy_existing: bool = False): +def _memmap_( + self, + prefix: str | None = None, + copy_existing: bool = False, + executor=None, + futures=None, + inplace=True, + like=False, +): + _non_tensordict = self._non_tensordict + cls = self.__class__ + if prefix is not None: prefix = Path(prefix) - with open(prefix / "meta.json", "w") as f: - metadata = {"_type": str(self.__class__)} - for key, value in self._non_tensordict.items(): - if ( - isinstance(value, (int, bool, str, float, dict, tuple, list)) - or value is None - ): - metadata[key] = value - json.dump(metadata, f) + if not prefix.exists(): + os.makedirs(prefix, exist_ok=True) + + def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): + with open(prefix / "meta.json", "w") as f: + metadata = {"_type": str(cls)} + for key, value in _non_tensordict.items(): + if ( + isinstance(value, (int, bool, str, float, dict, tuple, list)) + or value is None + ): + metadata[key] = value + json.dump(metadata, f) + + if executor is None: + save_metadata() + else: + futures.append(executor.submit(save_metadata)) + prefix = prefix / "_tensordict" - self._tensordict.memmap_(prefix=prefix, copy_existing=copy_existing) - return self + td = self._tensordict._memmap_( + prefix=prefix, + executor=executor, + futures=futures, + inplace=inplace, + like=like, + copy_existing=copy_existing, + ) -def _memmap_like(self, prefix: str | None = None): - metadata = {"_type": str(self.__class__)} - for key, value in self._non_tensordict.items(): - if ( - isinstance(value, (int, bool, str, float, dict, tuple, list)) - or value is None - ): - metadata[key] = value - if prefix is not None: - prefix = Path(prefix) - with open(prefix / "meta.json", "w") as f: - json.dump(metadata, f) - prefix = prefix / "_tensordict" - metadata.pop("_type") - td = self._tensordict.memmap_like(prefix=prefix) - return self.__class__._from_tensordict(td, metadata) - - -def _load_memmap(cls, prefix: str): - prefix = Path(prefix) - with open(prefix / "meta.json", "r") as f: - metadata = json.load(f) - type_name = metadata.pop("_type") - if type_name != str(cls): - # try to get the type from register - for other_cls in tensordict_lib.base._ACCEPTED_CLASSES: - if str(other_cls) == type_name: - return other_cls.load_memmap(prefix) - else: - raise RuntimeError( - f"Could not find name '{type_name}' in {tensordict_lib.base._ACCEPTED_CLASSES}." 
- ) - non_tensordict = copy(metadata) + if not inplace: + result = cls._from_tensordict(td, _non_tensordict) + else: + result = self + return result + + +def _load_memmap(cls, prefix: Path, metadata: dict): + del metadata["_type"] + non_tensordict = copy(metadata) td = TensorDict.load_memmap(prefix / "_tensordict") return cls._from_tensordict(td, non_tensordict) diff --git a/test/test_memmap.py b/test/test_memmap.py index 1f1f797fd..a820dd56e 100644 --- a/test/test_memmap.py +++ b/test/test_memmap.py @@ -496,14 +496,16 @@ def test_filename(tmp_path): assert str(mt._filename) == str(tmp_path / "test.memmap") # memmap -> memmap keeps the filename - mt2 = MemoryMappedTensor.from_tensor(mt) + mt2 = MemoryMappedTensor.from_tensor(mt, filename=mt.filename) assert str(mt2._filename) == str(tmp_path / "test.memmap") assert mt2 is mt - # memmap -> memmap keeps the filename - mt3 = MemoryMappedTensor.from_tensor(mt, filename=tmp_path / "test.memmap") - assert str(mt3._filename) == str(tmp_path / "test.memmap") - assert mt3 is mt + # memmap -> memmap keeps id if no filename + mt0 = MemoryMappedTensor.from_tensor( + torch.zeros((), dtype=torch.float32).expand(10) + ) + mt3 = MemoryMappedTensor.from_tensor(mt0) + assert mt3 is mt0 # memmap -> memmap with a new filename with pytest.raises(RuntimeError, match="copy_existing"): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index bb6ef79a8..7b37b7fe8 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2618,27 +2618,51 @@ def test_repr(self, td_name, device): td = getattr(self, td_name)(device) _ = str(td) - def test_memmap_(self, td_name, device): + @pytest.mark.parametrize("use_dir", [True, False]) + @pytest.mark.parametrize("num_threads", [0, 2]) + def test_memmap_(self, td_name, device, use_dir, tmpdir, num_threads): td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( RuntimeError, match="Converting a sub-tensordict values to memmap cannot be done", ): - td.memmap_() + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + return elif td_name in ("td_h5", "td_params"): with pytest.raises( RuntimeError, match="Cannot build a memmap TensorDict in-place", ): - td.memmap_() + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) + return else: - td.memmap_() + td.memmap_( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) assert td.is_memmap() + if use_dir: + assert_allclose_td(TensorDict.load_memmap(tmpdir), td) - def test_memmap_like(self, td_name, device): + @pytest.mark.parametrize("use_dir", [True, False]) + @pytest.mark.parametrize("num_threads", [0, 2]) + def test_memmap_like(self, td_name, device, use_dir, tmpdir, num_threads): td = getattr(self, td_name)(device) - tdmemmap = td.memmap_like() + tdmemmap = td.memmap_like( + prefix=tmpdir if use_dir else None, + num_threads=num_threads, + copy_existing=True, + ) assert tdmemmap is not td for key in td.keys(True): assert td[key] is not tdmemmap[key] @@ -2676,8 +2700,6 @@ def test_memmap_prefix(self, td_name, device, tmp_path): metadata = json.load(file) if td_name in ("stacked_td", "nested_stacked_td"): assert metadata["shape"] == list(td.tensordicts[0].batch_size) - elif td_name in ("unsqueezed_td", "squeezed_td", "permute_td"): - assert metadata["shape"] == list(td._source.batch_size) else: assert metadata["shape"] == list(td.batch_size)
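For completeness, the `use_dir`/`num_threads` parametrization above reduces to round trips of the following shape; this is a condensed sketch in the style of the updated tests, not one of them:

```python
import pytest
import torch
from tensordict import TensorDict

@pytest.mark.parametrize("num_threads", [0, 2])
def test_memmap_roundtrip(tmpdir, num_threads):
    td = TensorDict({"a": torch.randn(3, 4)}, batch_size=[3])
    # write in-place, optionally with a thread pool, then reload from disk
    td.memmap_(prefix=tmpdir, num_threads=num_threads, copy_existing=True)
    assert td.is_memmap()
    loaded = TensorDict.load_memmap(tmpdir)
    torch.testing.assert_close(loaded["a"], td["a"])
```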