From 67ff25e04374e0fa68240e33c16207c7bc8bbbda Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 9 Sep 2024 21:21:55 +0100 Subject: [PATCH 01/11] [Feature] str2td ghstack-source-id: c7737107beff69fa6f5438067c795d2e1a98d45f Pull Request resolved: https://github.com/pytorch/tensordict/pull/953 --- docs/source/reference/tensordict.rst | 1 + tensordict/utils.py | 88 ++++++++++++++++++++++++++++ test/test_utils.py | 28 +++++++++ 3 files changed, 117 insertions(+) diff --git a/docs/source/reference/tensordict.rst b/docs/source/reference/tensordict.rst index 668d31e9b..6369a3bdc 100644 --- a/docs/source/reference/tensordict.rst +++ b/docs/source/reference/tensordict.rst @@ -129,3 +129,4 @@ Utils dense_stack_tds set_lazy_legacy lazy_legacy + parse_tensor_dict_string diff --git a/tensordict/utils.py b/tensordict/utils.py index dd1669e05..ecf584d02 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -12,6 +12,7 @@ import math import os +import re import sys import time @@ -2560,3 +2561,90 @@ def _infer_size_impl(shape: List[int], numel: int) -> List[int]: if infer_dim is not None: out[infer_dim] = numel // newsize return out + + +def parse_tensor_dict_string(s: str): + """Parse a TensorDict repr to a TensorDict. + + .. note:: This functions is intended to be used for debugging, to reproduce a tensordict + given its printed version, and should not be used in real applications. + + """ + from tensordict import TensorDict + + # Regular expression patterns + field_pattern = r"(\w+): Tensor\(shape=torch.Size\((\[(.*?)\])\), device=(\w+), dtype=torch.(\w+), is_shared=(\w+)\)" + nested_field_pattern = r"(\w+): TensorDict\(" + batch_size_pattern = r"batch_size=torch.Size\((\[(.*?)\])\)" + device_pattern = r"device=(\w+)(?=,|$)" + + # Find all nested TensorDicts first + nested_dict_ranges = [] + for match in re.finditer(nested_field_pattern, s): + start_idx = match.start() + depth = 1 + for i in range(start_idx + len(match.group(0)), len(s)): + if s[i] == "(": + depth += 1 + elif s[i] == ")": + depth -= 1 + if depth == 0: + end_idx = i + break + nested_dict_ranges.append((start_idx, end_idx)) + + # Find all fields in the string that are not part of a nested TensorDict + fields = {} + for match in re.finditer(field_pattern, s): + name, _, shape, device, dtype, is_shared = match.groups() + field_start = match.start() + field_end = match.end() + if any( + field_start >= start and field_end <= end + for start, end in nested_dict_ranges + ): + continue # skip if this field is inside a nested TensorDict + shape = [int(x) for x in shape.split(", ")] if shape else [] + fields[name] = torch.zeros( + tuple(shape), device=torch.device(device), dtype=getattr(torch, dtype) + ) + + # Now find nested TensorDicts and add them to the fields + for match in re.finditer(nested_field_pattern, s): + name = match.group(1) + start_idx = match.end() + depth = 1 + for i in range(start_idx, len(s)): + if s[i] == "(": + depth += 1 + elif s[i] == ")": + depth -= 1 + if depth == 0: + end_idx = i + break + content = s[start_idx:end_idx] + nested_fields = parse_tensor_dict_string(f"TensorDict({content})") + fields[name] = nested_fields + + # Parse batch size + batch_size_matches = re.findall(batch_size_pattern, s) + if batch_size_matches: + batch_size_match = batch_size_matches[-1] # Take the last match + if batch_size_match[1]: + batch_size = [int(x) for x in batch_size_match[1].split(", ")] + else: + batch_size = [] + else: + raise ValueError("Batch size not found in the string") + # Parse device + device_matches = 
re.findall(device_pattern, s) + if device_matches: + device = device_matches[-1] # Take the last match + if device == "None": + device = None + else: + device = torch.device(device) + else: + raise ValueError("Device not found in the string") + tensor_dict = TensorDict(fields, batch_size=torch.Size(batch_size), device=device) + return tensor_dict diff --git a/test/test_utils.py b/test/test_utils.py index 88313eeae..1e0f5caca 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -16,6 +16,7 @@ _make_cache_key, convert_ellipsis_to_idx, isin, + parse_tensor_dict_string, remove_duplicates, ) @@ -327,6 +328,33 @@ def test_remove_duplicates_2dims(dim, device): assert (output_tensordict == expected_output).all() +def test_parse_tensor_dict_string(): + td = TensorDict(a=0) + assert str(td) == str(parse_tensor_dict_string(str(td))) + td = TensorDict(a=[0], batch_size=[1]) + assert str(td) == str(parse_tensor_dict_string(str(td))) + td = TensorDict(a=[[0] * 2], batch_size=[1, 2]) + assert str(td) == str(parse_tensor_dict_string(str(td))) + td = TensorDict( + a=[[0] * 2], b=TensorDict(c=[[1] * 2], batch_size=[1, 2]), batch_size=[1] + ) + assert str(td) == str(parse_tensor_dict_string(str(td))) + + td = TensorDict( + a=[[0] * 2], + b=TensorDict(c=[[1] * 2], batch_size=[1, 2]), + batch_size=[1], + device="cpu", + ) + assert str(td) == str(parse_tensor_dict_string(str(td))) + td = TensorDict( + a=[[0] * 2], + b=TensorDict(c=[[1] * 2], batch_size=[1, 2], device="cpu"), + batch_size=[1], + ) + assert str(td) == str(parse_tensor_dict_string(str(td))) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From f154093943d947e5320ad131009b84fd76905a6e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Sep 2024 16:27:26 +0100 Subject: [PATCH 02/11] [Feature] Better typing for tensorclass ghstack-source-id: 1000aeadc067a9c5f9f206f652abd74172012e29 Pull Request resolved: https://github.com/pytorch/tensordict/pull/983 --- tensordict/base.py | 4 +- tensordict/tensorclass.py | 95 +++++++++++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 20 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index b33dce266..51d5180a7 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10023,7 +10023,9 @@ def _sync_all(self): if not is_dynamo_compiling() and torch.cuda.is_initialized(): torch.cuda.synchronize() elif _has_mps: - torch.mps.synchronize() + mps = getattr(torch, "mps", None) + if mps is not None: + mps.synchronize() def is_floating_point(self): for item in self.values(include_nested=True, leaves_only=True): diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 9a8c6cf3f..f28094e48 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -24,7 +24,16 @@ from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Any, Callable, get_type_hints, List, Sequence, Type, TypeVar +from typing import ( + Any, + Callable, + get_type_hints, + List, + overload, + Sequence, + Type, + TypeVar, +) import numpy as np import orjson as json @@ -63,6 +72,18 @@ except ImportError: # torch 2.0 from torch._dynamo import is_compiling as is_dynamo_compiling +try: + from typing import dataclass_transform +except ImportError: + + def dataclass_transform(*args, **kwargs): + """No-op. + + Placeholder for dataclass_transform (python<3.11). 
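+
+        A minimal sketch of how the placeholder behaves (the decorator name
+        below is illustrative, not part of the API): it accepts any arguments
+        and returns an identity decorator, so decorated callables are left
+        unchanged on python<3.11.
+
+            >>> @dataclass_transform()
+            ... def my_decorator(cls):
+            ...     return cls
+            >>> my_decorator(int) is int  # identity decorator, no-op
+            True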
+ """ + return lambda cls: cls + + T = TypeVar("T", bound=TensorDictBase) # We use an abstract AnyType instead of Any because Any isn't recognised as a type for python < 3.10 major, minor = sys.version_info[:2] @@ -327,14 +348,64 @@ def is_non_tensor(obj): return isinstance(obj, (NonTensorData, NonTensorStack)) -class tensorclass: +class _tensorclass_dec: + def __new__(cls, autocast: bool = False): + if not isinstance(autocast, bool): + clz = autocast + self = super().__new__(cls) + self.__init__(autocast=False) + return self.__call__(clz) + return super().__new__(cls) + + def __init__(self, autocast: bool): + self.autocast = autocast + + @dataclass_transform() + def __call__(self, cls): + clz = _tensorclass(cls) + clz.autocast = self.autocast + return clz + + +@overload +def tensorclass(autocast: bool = False) -> _tensorclass_dec: ... + + +@overload +def tensorclass(cls: T) -> T: ... + + +@dataclass_transform() +def tensorclass(*args, **kwargs): """A decorator to create :obj:`tensorclass` classes. - :obj:`tensorclass` classes are specialized :obj:`dataclass` instances that + ``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that can execute some pre-defined tensor operations out of the box, such as indexing, item assignment, reshaping, casting to device or storage and many others. + Args: + autocast (bool, optional): if ``True``, the types indicated will be enforced when an argument is set. + Defaults to ``False``. + + tensorclass can be used with or without arguments: + Examples: + >>> @tensorclass + ... class X: + ... y: torch.Tensor + >>> X(1).y + 1 + >>> @tensorclass(autocast=False) + ... class X: + ... y: torch.Tensor + >>> X(1).y + 1 + >>> @tensorclass(autocast=True) + ... class X: + ... y: torch.Tensor + >>> X(1).y + torch.tensor(1) + Examples: >>> from tensordict import tensorclass >>> import torch @@ -383,24 +454,10 @@ class tensorclass: """ - - def __new__(cls, autocast: bool = False): - if not isinstance(autocast, bool): - clz = autocast - self = super().__new__(cls) - self.__init__(autocast=False) - return self.__call__(clz) - return super().__new__(cls) - - def __init__(self, autocast: bool): - self.autocast = autocast - - def __call__(self, cls): - clz = _tensorclass(cls) - clz.autocast = self.autocast - return clz + return _tensorclass_dec(*args, **kwargs) +@dataclass_transform() def _tensorclass(cls: T) -> T: def __torch_function__( cls, From 53747380f21acd2b1303100e3cfd60dd19c48d41 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Sep 2024 16:27:27 +0100 Subject: [PATCH 03/11] [Feature] Frozen tensorclass ghstack-source-id: ae7e6170eebeb7026a230596f39f704971e0fc06 Pull Request resolved: https://github.com/pytorch/tensordict/pull/984 --- tensordict/tensorclass.py | 36 ++++++++++++++++++++++-------- test/test_tensorclass.py | 46 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 9 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index f28094e48..aaccb8869 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -349,26 +349,31 @@ def is_non_tensor(obj): class _tensorclass_dec: - def __new__(cls, autocast: bool = False): + def __new__(cls, autocast: bool = False, frozen: bool = False): if not isinstance(autocast, bool): clz = autocast self = super().__new__(cls) - self.__init__(autocast=False) + self.__init__(autocast=False, frozen=False) return self.__call__(clz) return super().__new__(cls) - def __init__(self, autocast: bool): + def __init__(self, autocast: bool = 
False, frozen: bool = False): self.autocast = autocast + self.frozen = frozen @dataclass_transform() def __call__(self, cls): - clz = _tensorclass(cls) + clz = _tensorclass(cls, frozen=self.frozen) clz.autocast = self.autocast return clz @overload -def tensorclass(autocast: bool = False) -> _tensorclass_dec: ... +def tensorclass(autocast: bool = False, frozen: bool = False) -> _tensorclass_dec: ... + + +@overload +def tensorclass(cls: T) -> T: ... @overload @@ -387,6 +392,9 @@ def tensorclass(*args, **kwargs): Args: autocast (bool, optional): if ``True``, the types indicated will be enforced when an argument is set. Defaults to ``False``. + frozen (bool, optional): if ``True``, the content of the tensorclass cannot be modified. This argument is + provided to dataclass-compatibility, a similar behavior can be obtained through the `lock` argument in + the class constructor. Defaults to ``False``. tensorclass can be used with or without arguments: Examples: @@ -458,7 +466,7 @@ def tensorclass(*args, **kwargs): @dataclass_transform() -def _tensorclass(cls: T) -> T: +def _tensorclass(cls: T, *, frozen) -> T: def __torch_function__( cls, func: Callable, @@ -494,7 +502,7 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) - cls = dataclass(cls) + cls = dataclass(cls, frozen=frozen) expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) for attr in expected_keys: @@ -509,7 +517,7 @@ def __torch_function__( delattr(cls, field.name) _get_type_hints(cls) - cls.__init__ = _init_wrapper(cls.__init__) + cls.__init__ = _init_wrapper(cls.__init__, frozen) cls._from_tensordict = classmethod(_from_tensordict) cls.from_tensordict = cls._from_tensordict if not hasattr(cls, "__torch_function__"): @@ -672,7 +680,7 @@ def _from_tensordict_with_none(tc, tensordict): ) -def _init_wrapper(__init__: Callable) -> Callable: +def _init_wrapper(__init__: Callable, frozen) -> Callable: init_sig = inspect.signature(__init__) params = list(init_sig.parameters.values()) # drop first entry of params which corresponds to self and isn't passed by the user @@ -685,8 +693,11 @@ def wrapper( batch_size: Sequence[int] | torch.Size | int = None, device: DeviceType | None = None, names: List[str] | None = None, + lock: bool | None = None, **kwargs, ): + if lock is None: + lock = frozen if not is_dynamo_compiling(): # zip not supported by dynamo @@ -741,6 +752,13 @@ def wrapper( for key, value in kwargs.items() } __init__(self, **kwargs) + if frozen: + local_setattr = _setattr_wrapper(self.__setattr__, self.__expected_keys__) + for key, val in kwargs.items(): + local_setattr(self, key, val) + del self.__dict__[key] + if lock: + self._tensordict.lock_() new_params = [ inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 7bbd69342..2e6dce038 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -559,6 +559,52 @@ class MyDataNested: assert (full_like_tc.y.X == 9).all() assert full_like_tc.z == data.z == z + def test_frozen(self): + + @tensorclass(frozen=True, autocast=True) + class X: + y: torch.Tensor + + x = X(y=1) + assert isinstance(x.y, torch.Tensor) + _ = {x: 0} + assert x.is_locked + with pytest.raises(RuntimeError, match="locked"): + x.y = 0 + + @tensorclass(frozen=False, autocast=True) + class X: + y: torch.Tensor + + x = X(y=1) + assert isinstance(x.y, torch.Tensor) + with pytest.raises(TypeError, match="unhashable"): + _ = {x: 0} + assert not x.is_locked + x.y = 0 + + 
@tensorclass(frozen=True, autocast=False) + class X: + y: torch.Tensor + + x = X(y="a string!") + assert isinstance(x.y, str) + _ = {x: 0} + assert x.is_locked + with pytest.raises(RuntimeError, match="locked"): + x.y = 0 + + @tensorclass(frozen=False, autocast=False) + class X: + y: torch.Tensor + + x = X(y="a string!") + assert isinstance(x.y, str) + with pytest.raises(TypeError, match="unhashable"): + _ = {x: 0} + assert not x.is_locked + x.y = 0 + @pytest.mark.parametrize("from_torch", [True, False]) def test_gather(self, from_torch): @tensorclass From da98f1ef6e3241e2dc433e39d7395443f23eeb09 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 11 Sep 2024 11:26:47 +0100 Subject: [PATCH 04/11] [Performance] Faster `__setitem__` (#985) --- tensordict/_td.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index d18e37e39..619c06b2b 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -860,7 +860,11 @@ def __setitem__( if isinstance(value, (TensorDictBase, dict)): indexed_bs = _getitem_batch_size(self.batch_size, index) if isinstance(value, dict): - value = self.from_dict_instance(value, batch_size=indexed_bs) + value = self.from_dict_instance( + value, batch_size=indexed_bs, device=self.device + ) + elif value.device != self.device: + value = value.to(self.device) # value = self.empty(recurse=True)[index].update(value) if value.batch_size != indexed_bs: if value.shape == indexed_bs[-len(value.shape) :]: @@ -883,7 +887,7 @@ def __setitem__( for value_key, item in value.items(): if value_key in keys: self._set_at_str( - value_key, item, index, validated=False, non_blocking=False + value_key, item, index, validated=True, non_blocking=False ) else: if subtd is None: @@ -3129,6 +3133,7 @@ def _exclude( # self._maybe_set_shared_attributes(result) return result + # @cache def keys( self, include_nested: bool = False, From f96e40ce8c951c5894b1887a6c939cae94ad9eba Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Sep 2024 10:03:19 +0100 Subject: [PATCH 05/11] [CI] Add aarch64-linux wheels (#987) --- .../workflows/build-wheels-aarch64-linux.yml | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .github/workflows/build-wheels-aarch64-linux.yml diff --git a/.github/workflows/build-wheels-aarch64-linux.yml b/.github/workflows/build-wheels-aarch64-linux.yml new file mode 100644 index 000000000..6bd4fd1ac --- /dev/null +++ b/.github/workflows/build-wheels-aarch64-linux.yml @@ -0,0 +1,51 @@ +name: Build Aarch64 Linux Wheels + +on: + pull_request: + push: + branches: + - nightly + - main + - release/* + tags: + # NOTE: Binary build pipelines should only get triggered on release candidate builds + # Release candidate tags look like: v1.11.0-rc1 + - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ + workflow_dispatch: + +permissions: + id-token: write + contents: read + +jobs: + generate-matrix: + uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main + with: + package-type: wheel + os: linux-aarch64 + test-infra-repository: pytorch/test-infra + test-infra-ref: main + with-cuda: disable + build: + needs: generate-matrix + strategy: + fail-fast: false + matrix: + include: + - repository: pytorch/tensordict + smoke-test-script: test/smoke_test.py + package-name: tensordict + name: pytorch/tensordict + uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main + with: + repository: ${{ matrix.repository }} + ref: "" + test-infra-repository: pytorch/test-infra + test-infra-ref: main + 
build-matrix: ${{ needs.generate-matrix.outputs.matrix }} + package-name: ${{ matrix.package-name }} + smoke-test-script: ${{ matrix.smoke-test-script }} + trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/version_script.sh + architecture: aarch64 + setup-miniconda: false From d4f8eeee516fdf4a0b42e4adf89e2b8565acc1e4 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Thu, 12 Sep 2024 02:46:01 -0700 Subject: [PATCH 06/11] [Feature] Unify composite dist method signatures with other dists (#981) Co-authored-by: Vincent Moens Co-authored-by: Vincent Moens --- tensordict/nn/distributions/composite.py | 143 +++++++++++++++++++++-- tensordict/nn/probabilistic.py | 14 ++- test/test_nn.py | 124 +++++++++++++++++++- 3 files changed, 269 insertions(+), 12 deletions(-) diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py index 3b59ab32f..629164b16 100644 --- a/tensordict/nn/distributions/composite.py +++ b/tensordict/nn/distributions/composite.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import warnings + import torch from tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey, unravel_key, unravel_keys @@ -13,9 +15,9 @@ class CompositeDistribution(d.Distribution): """A composition of distributions. - Groups distributions together with the TensorDict interface. All methods - (``log_prob``, ``cdf``, ``icdf``, ``rsample``, ``sample`` etc.) will return a - tensordict, possibly modified in-place if the input was a tensordict. + Groups distributions together with the TensorDict interface. Methods + (``log_prob_composite``, ``entropy_composite``, ``cdf``, ``icdf``, ``rsample``, ``sample`` etc.) + will return a tensordict, possibly modified in-place if the input was a tensordict. Args: params (TensorDictBase): a nested key-tensor map where the root entries @@ -32,8 +34,17 @@ class CompositeDistribution(d.Distribution): will be used. extra_kwargs (Dict[NestedKey, Dict]): a possibly incomplete dictionary of extra keyword arguments for the distributions to be built. + aggregate_probabilities (bool): if ``True``, the :meth:`~.log_prob` and :meth:`~.entropy` methods will + sum the probabilities and entropies of the individual distributions and return a single tensor. + If ``False``, the single log-probabilities will be registered in the input tensordict (for :meth:`~.log_prob`) + or retuned as leaves of the output tensordict (for :meth:`~.entropy`). + This parameter can be overridden at runtime by passing the ``aggregate_probabilities`` argument to + ``log_prob`` and ``entropy``. + Defaults to ``False``. log_prob_key (NestedKey, optional): key where to write the log_prob. Defaults to `'sample_log_prob'`. + entropy_key (NestedKey, optional): key where to write the entropy. + Defaults to `'entropy'`. .. note:: In this distribution class, the batch-size of the input tensordict containing the params (``params``) is indicative of the batch_shape of the distribution. 
For instance, @@ -73,7 +84,9 @@ def __init__( *, name_map: dict | None = None, extra_kwargs=None, + aggregate_probabilities: bool | None = None, log_prob_key: NestedKey = "sample_log_prob", + entropy_key: NestedKey = "entropy", ): self._batch_shape = params.shape if extra_kwargs is None: @@ -105,6 +118,25 @@ def __init__( dists[write_name] = dist self.dists = dists self.log_prob_key = log_prob_key + self.entropy_key = entropy_key + + self.aggregate_probabilities = aggregate_probabilities + + @property + def aggregate_probabilities(self): + aggregate_probabilities = self._aggregate_probabilities + if aggregate_probabilities is None: + warnings.warn( + "The default value of `aggregate_probabilities` will change from `False` to `True` in v0.7. " + "Please pass this value explicitly to avoid this warning.", + FutureWarning, + ) + aggregate_probabilities = self._aggregate_probabilities = False + return aggregate_probabilities + + @aggregate_probabilities.setter + def aggregate_probabilities(self, value): + self._aggregate_probabilities = value def sample(self, shape=None) -> TensorDictBase: if shape is None: @@ -147,19 +179,114 @@ def rsample(self, shape=None) -> TensorDictBase: shape + self.batch_shape, ) - def log_prob(self, sample: TensorDictBase) -> TensorDictBase: - """Writes a ``_log_prob entry`` for each sample in the input tensordict, along with a ``"sample_log_prob"`` entry with the summed log-prob.""" + def log_prob( + self, sample: TensorDictBase, *, aggregate_probabilities: bool | None = None + ) -> torch.Tensor | TensorDictBase: # noqa: D417 + """Computes and returns the summed log-prob. + + Args: + sample (TensorDictBase): the sample to compute the log probability. + + Keyword Args: + aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities`` + from the class. + + If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with + the summed log-probabilities. If ``self.aggregate_probabilities`` is ``False``, this method will + call the `:meth:`~.log_prob_composite` method and return a tensordict with the log-probabilities + of each sample in the input tensordict along with a ``sample_log_prob`` entry with the summed + log-prob. In both cases, the output shape will be the shape of the input tensordict. + """ + if aggregate_probabilities is None: + aggregate_probabilities = self.aggregate_probabilities + if not aggregate_probabilities: + return self.log_prob_composite(sample, include_sum=True) + slp = 0.0 + for name, dist in self.dists.items(): + lp = dist.log_prob(sample.get(name)) + if lp.ndim > sample.ndim: + lp = lp.flatten(sample.ndim, -1).sum(-1) + slp = slp + lp + return slp + + def log_prob_composite( + self, sample: TensorDictBase, include_sum=True + ) -> TensorDictBase: + """Writes a ``_log_prob`` entry for each sample in the input tensordict, along with a ``"sample_log_prob"`` entry with the summed log-prob. + + This method is called by the :meth:`~.log_prob` method when ``self.aggregate_probabilities`` is ``False``. 
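+
+        A minimal usage sketch (assuming ``dist`` is a ``CompositeDistribution``
+        with a ``"cont"`` entry in its ``distribution_map``, as in the tests
+        below):
+
+            >>> sample = dist.rsample((4,))
+            >>> sample = dist.log_prob_composite(sample, include_sum=True)
+            >>> assert "cont_log_prob" in sample.keys()
+            >>> assert "sample_log_prob" in sample.keys()  # the summed log-prob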
+        """
         slp = 0.0
         d = {}
         for name, dist in self.dists.items():
             d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
-            while lp.ndim > sample.ndim:
-                lp = lp.sum(-1)
+            if lp.ndim > sample.ndim:
+                lp = lp.flatten(sample.ndim, -1).sum(-1)
             slp = slp + lp
-        d[self.log_prob_key] = slp
+        if include_sum:
+            d[self.log_prob_key] = slp
         sample.update(d)
         return sample
 
+    def entropy(
+        self, samples_mc: int = 1, *, aggregate_probabilities: bool | None = None
+    ) -> torch.Tensor | TensorDictBase:  # noqa: D417
+        """Computes and returns the summed entropies.
+
+        Args:
+            samples_mc (int): the number of samples to draw if the entropy does not have a closed-form formula.
+                Defaults to ``1``.
+
+        Keyword Args:
+            aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
+                from the class.
+
+        If ``self.aggregate_probabilities`` is ``True``, this method will return a single tensor with
+        the summed entropies. If ``self.aggregate_probabilities`` is ``False``, this method will call
+        the :meth:`~.entropy_composite` method and return a tensordict with the entropies of each sample
+        in the input tensordict along with an ``entropy`` entry with the summed entropy. In both cases,
+        the output shape will match the shape of the distribution ``batch_shape``.
+        """
+        if aggregate_probabilities is None:
+            aggregate_probabilities = self.aggregate_probabilities
+        if not aggregate_probabilities:
+            return self.entropy_composite(samples_mc, include_sum=True)
+        se = 0.0
+        for _, dist in self.dists.items():
+            try:
+                e = dist.entropy()
+            except NotImplementedError:
+                x = dist.rsample((samples_mc,))
+                e = -dist.log_prob(x).mean(0)
+            if e.ndim > len(self.batch_shape):
+                e = e.flatten(len(self.batch_shape), -1).sum(-1)
+            se = se + e
+        return se
+
+    def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
+        """Writes a ``_entropy`` entry for each sample in the input tensordict, along with an ``"entropy"`` entry with the summed entropies.
+
+        This method is called by the :meth:`~.entropy` method when ``self.aggregate_probabilities`` is ``False``.
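+
+        A minimal usage sketch (assuming ``dist`` is a ``CompositeDistribution``
+        with a ``"cont"`` entry in its ``distribution_map``, as in the tests
+        below):
+
+            >>> ent = dist.entropy_composite(samples_mc=1, include_sum=True)
+            >>> assert "cont_entropy" in ent.keys()
+            >>> assert "entropy" in ent.keys()  # the summed entropy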
+ """ + se = 0.0 + d = {} + for name, dist in self.dists.items(): + try: + e = dist.entropy() + except NotImplementedError: + x = dist.rsample((samples_mc,)) + e = -dist.log_prob(x).mean(0) + d[_add_suffix(name, "_entropy")] = e + if e.ndim > len(self.batch_shape): + e = e.flatten(len(self.batch_shape), -1).sum(-1) + se = se + e + if include_sum: + d[self.entropy_key] = se + return TensorDict( + d, + self.batch_shape, + ) + def cdf(self, sample: TensorDictBase) -> TensorDictBase: cdfs = { _add_suffix(name, "_cdf"): dist.cdf(sample.get(name)) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index ec61ac344..91fb24490 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -408,8 +408,8 @@ def log_prob(self, tensordict): """Writes the log-probability of the distribution sample.""" dist = self.get_dist(tensordict) if isinstance(dist, CompositeDistribution): - tensordict = dist.log_prob(tensordict) - return tensordict.get("sample_log_prob") + td = dist.log_prob(tensordict, aggregate_probabilities=False) + return td.get(dist.log_prob_key) else: return dist.log_prob(tensordict.get(self.out_keys[0])) @@ -439,7 +439,15 @@ def forward( if isinstance(out_tensors, TensorDictBase): tensordict_out.update(out_tensors) if self.return_log_prob: - tensordict_out = dist.log_prob(tensordict_out) + kwargs = {} + if isinstance(dist, CompositeDistribution): + kwargs = {"aggregate_probabilities": False} + log_prob = dist.log_prob(tensordict_out, **kwargs) + if log_prob is not tensordict_out: + # Composite dists return the tensordict_out directly when aggrgate_prob is False + tensordict_out.set(self.log_prob_key, log_prob) + else: + tensordict_out.rename_key_(dist.log_prob_key, self.log_prob_key) else: if isinstance(out_tensors, Tensor): out_tensors = (out_tensors,) diff --git a/test/test_nn.py b/test/test_nn.py index a3d4bd2a8..7523d0d35 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -3299,6 +3299,7 @@ def test_log_prob(self): }, [3], ) + # Capture the warning for upcoming changes in aggregate_probabilities dist = CompositeDistribution( params, distribution_map={ @@ -3307,10 +3308,131 @@ def test_log_prob(self): }, extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, ) + + sample = dist.rsample((4,)) + with pytest.warns(FutureWarning, match="aggregate_probabilities"): + lp = dist.log_prob(sample) + + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.RelaxedOneHotCategorical, + }, + extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, + aggregate_probabilities=True, + ) + sample = dist.rsample((4,)) - sample = dist.log_prob(sample) + lp = dist.log_prob(sample) + assert isinstance(lp, torch.Tensor) + assert lp.requires_grad + + def test_log_prob_composite(self): + params = TensorDict( + { + "cont": { + "loc": torch.randn(3, 4, requires_grad=True), + "scale": torch.rand(3, 4, requires_grad=True), + }, + ("nested", "disc"): {"logits": torch.randn(3, 10, requires_grad=True)}, + }, + [3], + ) + # Capture the warning for upcoming changes in aggregate_probabilities + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.RelaxedOneHotCategorical, + }, + extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, + ) + with pytest.warns(FutureWarning, match="aggregate_probabilities"): + dist.log_prob(dist.sample()) + dist = CompositeDistribution( + params, + 
distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.RelaxedOneHotCategorical, + }, + extra_kwargs={("nested", "disc"): {"temperature": torch.tensor(1.0)}}, + aggregate_probabilities=False, + ) + sample = dist.rsample((4,)) + sample = dist.log_prob_composite(sample, include_sum=True) assert sample.get("cont_log_prob").requires_grad assert sample.get(("nested", "disc_log_prob")).requires_grad + assert "sample_log_prob" in sample.keys() + + def test_entropy(self): + params = TensorDict( + { + "cont": { + "loc": torch.randn(3, 4, requires_grad=True), + "scale": torch.rand(3, 4, requires_grad=True), + }, + ("nested", "disc"): {"logits": torch.randn(3, 10, requires_grad=True)}, + }, + [3], + ) + # Capture the warning for upcoming changes in aggregate_probabilities + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, + ) + with pytest.warns(FutureWarning, match="aggregate_probabilities"): + dist.log_prob(dist.sample()) + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, + aggregate_probabilities=True, + ) + ent = dist.entropy() + assert ent.shape == params.shape == dist._batch_shape + assert isinstance(ent, torch.Tensor) + assert ent.requires_grad + + def test_entropy_composite(self): + params = TensorDict( + { + "cont": { + "loc": torch.randn(3, 4, requires_grad=True), + "scale": torch.rand(3, 4, requires_grad=True), + }, + ("nested", "disc"): {"logits": torch.randn(3, 10, requires_grad=True)}, + }, + [3], + ) + # Capture the warning for upcoming changes in aggregate_probabilities + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, + ) + with pytest.warns(FutureWarning, match="aggregate_probabilities"): + dist.log_prob(dist.sample()) + dist = CompositeDistribution( + params, + distribution_map={ + "cont": distributions.Normal, + ("nested", "disc"): distributions.Categorical, + }, + aggregate_probabilities=False, + ) + sample = dist.entropy() + assert sample.shape == params.shape == dist._batch_shape + assert sample.get("cont_entropy").requires_grad + assert sample.get(("nested", "disc_entropy")).requires_grad + assert "entropy" in sample.keys() def test_cdf(self): params = TensorDict( From 187541b3fd391ebd54e4d4d8e7a6bb2f6d511cc9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Sep 2024 09:20:22 +0100 Subject: [PATCH 07/11] [Refactor] Use IntEnum for interaction types (#989) --- tensordict/nn/probabilistic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 91fb24490..4c33990b0 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -7,7 +7,7 @@ import re import warnings -from enum import auto, Enum +from enum import auto, IntEnum from textwrap import indent from typing import Any, Callable, Dict, List, Optional from warnings import warn @@ -30,7 +30,7 @@ __all__ = ["ProbabilisticTensorDictModule", "ProbabilisticTensorDictSequential"] -class InteractionType(Enum): +class InteractionType(IntEnum): """A list of possible interaction types with a distribution. MODE, MEDIAN and MEAN point to the property / attribute with the same name. 
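+
+    Since :class:`~enum.IntEnum` members are also plain integers, interaction
+    types can now be compared to and serialized as ints. A minimal sketch of
+    what the refactor enables (member values are an implementation detail,
+    shown only for illustration):
+
+        >>> from tensordict.nn.probabilistic import InteractionType
+        >>> isinstance(InteractionType.MODE, int)  # IntEnum members are ints
+        True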
From bbcb9fc7e7ec254beec37bdbf0395022224c8c3c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Sep 2024 11:17:06 +0100 Subject: [PATCH 08/11] [BugFix] Fix td device sync when error is raised ghstack-source-id: d0e810c71ca1c9945561ca5a9e71cb71445095e4 Pull Request resolved: https://github.com/pytorch/tensordict/pull/988 --- tensordict/_td.py | 47 ++++++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 619c06b2b..0d9cff420 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -272,29 +272,34 @@ def __init__( call_sync = non_blocking is None if call_sync: _device_recorder.mark() - self._device = device + try: + self._device = device - if source is None: - source = {} - if not isinstance(source, (TensorDictBase, dict)): - raise ValueError( - "A TensorDict source is expected to be a TensorDictBase " - f"sub-type or a dictionary, found type(source)={type(source)}." - ) - self._batch_size = self._parse_batch_size(source, batch_size) - # TODO: this breaks when stacking tensorclasses with dynamo - if not is_dynamo_compiling(): - self.names = names - - for key, value in source.items(): - self.set(key, value, non_blocking=sub_non_blocking) - if call_sync: - if _device_recorder.has_transfer(): - self._sync_all() - _device_recorder.unmark() + if source is None: + source = {} + if not isinstance(source, (TensorDictBase, dict)): + raise ValueError( + "A TensorDict source is expected to be a TensorDictBase " + f"sub-type or a dictionary, found type(source)={type(source)}." + ) + self._batch_size = self._parse_batch_size(source, batch_size) + # TODO: this breaks when stacking tensorclasses with dynamo + if not is_dynamo_compiling(): + self.names = names - if lock: - self.lock_() + for key, value in source.items(): + self.set(key, value, non_blocking=sub_non_blocking) + if call_sync: + if _device_recorder.has_transfer(): + self._sync_all() + _device_recorder.unmark() + call_sync = False + + if lock: + self.lock_() + finally: + if call_sync: + _device_recorder.unmark() @classmethod def _new_unsafe( From fc323e58ac7e7572b06f6d62e3207a09b81d8e9d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 23:47:43 +0100 Subject: [PATCH 09/11] [Feature] Cudagraphs (#986) --- benchmarks/compile/compile_td_test.py | 16 +- benchmarks/compile/tensordict_nn_test.py | 18 +- docs/source/reference/nn.rst | 1 + tensordict/nn/__init__.py | 2 + tensordict/nn/common.py | 36 ++- tensordict/nn/cudagraphs.py | 362 +++++++++++++++++++++++ tensordict/nn/functional_modules.py | 13 +- test/test_compile.py | 6 +- test/test_nn.py | 189 +++++++++++- 9 files changed, 614 insertions(+), 29 deletions(-) create mode 100644 tensordict/nn/cudagraphs.py diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py index 297ffcf33..eb98fac8d 100644 --- a/benchmarks/compile/compile_td_test.py +++ b/benchmarks/compile/compile_td_test.py @@ -106,7 +106,7 @@ def get_flat_tc(): # Tests runtime of a simple arithmetic op over a highly nested tensordict -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_add_one_nested(mode, dict_type, benchmark): @@ -128,7 +128,7 @@ def test_compile_add_one_nested(mode, dict_type, benchmark): # Tests the speed of copying a nested tensordict 
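+# The pytest-benchmark ``benchmark`` fixture used throughout this file calls
+# the function it is given in a timed loop, e.g. ``benchmark(copy, td)``
+# repeatedly executes ``copy(td)`` and records timing statistics.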
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_copy_nested(mode, dict_type, benchmark): @@ -150,7 +150,7 @@ def test_compile_copy_nested(mode, dict_type, benchmark): # Tests runtime of a simple arithmetic op over a flat tensordict -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) def test_compile_add_one_flat(mode, dict_type, benchmark): @@ -177,7 +177,7 @@ def test_compile_add_one_flat(mode, dict_type, benchmark): benchmark(func, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile"]) @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) def test_compile_add_self_flat(mode, dict_type, benchmark): @@ -207,7 +207,7 @@ def test_compile_add_self_flat(mode, dict_type, benchmark): # Tests the speed of copying a flat tensordict -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_copy_flat(mode, dict_type, benchmark): @@ -235,7 +235,7 @@ def test_compile_copy_flat(mode, dict_type, benchmark): # Tests the speed of assigning entries to an empty tensordict -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) def test_compile_assign_and_add(mode, dict_type, benchmark): @@ -264,7 +264,7 @@ def test_compile_assign_and_add(mode, dict_type, benchmark): # Tests the speed of assigning entries to a lazy stacked tensordict -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.skipif( torch.cuda.is_available(), reason="max recursion depth error with cuda" ) @@ -285,7 +285,7 @@ def test_compile_assign_and_add_stack(mode, benchmark): # Tests indexing speed -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["compile", "eager"]) @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"]) diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py index 40571a31d..94110162a 100644 --- a/benchmarks/compile/tensordict_nn_test.py +++ b/benchmarks/compile/tensordict_nn_test.py @@ -49,7 +49,7 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3): ) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_mod_add(mode, 
benchmark): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -64,7 +64,7 @@ def test_mod_add(mode, benchmark): benchmark(module, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_mod_wrap(mode, benchmark): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -80,7 +80,7 @@ def test_mod_wrap(mode, benchmark): benchmark(module, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_mod_wrap_and_backward(mode, benchmark): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -104,7 +104,7 @@ def module_exec(td): benchmark(module_exec, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_seq_add(mode, benchmark): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -129,7 +129,7 @@ def delhidden(td): benchmark(module, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_seq_wrap(mode, benchmark): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -161,7 +161,7 @@ def delhidden(td): benchmark(module, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.slow @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) def test_seq_wrap_and_backward(mode, benchmark): @@ -201,7 +201,7 @@ def module_exec(td): benchmark(module_exec, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) @pytest.mark.parametrize("functional", [False, True]) def test_func_call_runtime(mode, functional, benchmark): @@ -272,7 +272,7 @@ def call(x, td): benchmark(call, x) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.slow @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) @pytest.mark.parametrize( @@ -354,7 +354,7 @@ def call(x, td): benchmark(call_vmap, x, td) -@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4") +@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4") @pytest.mark.slow @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"]) @pytest.mark.parametrize("plain_decorator", [None, False, True]) diff --git a/docs/source/reference/nn.rst b/docs/source/reference/nn.rst index b2f9c0f77..943956117 100644 --- a/docs/source/reference/nn.rst +++ b/docs/source/reference/nn.rst @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample ProbabilisticTensorDictModule TensorDictSequential TensorDictModuleWrapper + CudaGraphModule Functional ---------- diff --git 
a/tensordict/nn/__init__.py b/tensordict/nn/__init__.py index 788d6f63a..c7aa4f8f7 100644 --- a/tensordict/nn/__init__.py +++ b/tensordict/nn/__init__.py @@ -42,3 +42,5 @@ set_skip_existing, skip_existing, ) + +from .cudagraphs import CudaGraphModule diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 4e9112781..c55ba814b 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -240,15 +240,22 @@ def __init__( self.auto_batch_size = auto_batch_size def __call__(self, func: Callable) -> Callable: + + is_method = inspect.ismethod(func) or ( + inspect.isfunction(func) + and func.__code__.co_argcount > 0 + and func.__code__.co_varnames[0] == "self" + ) # sanity check for i, key in enumerate(inspect.signature(func).parameters): - if i == 0: + if (is_method or (key == "self")) and (i == 0): + is_method = True # skip self continue if key != "tensordict": raise RuntimeError( "the first argument of the wrapped function must be " - "named 'tensordict'." + f"named 'tensordict'. Got {key} instead." ) break # if the env variable was used, we can skip the wrapper altogether @@ -256,12 +263,21 @@ def __call__(self, func: Callable) -> Callable: return func @functools.wraps(func) - def wrapper(_self, *args: Any, **kwargs: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: + if is_method: + _self = args[0] + args = args[1:] + else: + _self = None if not _dispatch_td_nn_modules(): return func(_self, *args, **kwargs) source = self.source if isinstance(source, str): + if _self is None: + raise RuntimeError( + "The in keys must be passed to dispatch when func is not a method but a function." + ) source = getattr(_self, source) tensordict = None if len(args): @@ -271,6 +287,10 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any: tensordict_values = {} dest = self.dest if isinstance(dest, str): + if _self is None: + raise RuntimeError( + "The in keys must be passed to dispatch when func is not a method but a function." + ) dest = getattr(_self, dest) for key in source: expected_key = self.separator.join(_unravel_key_to_tuple(key)) @@ -293,10 +313,16 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any: tensordict_values, batch_size=torch.Size([]) if not self.auto_batch_size else None, ) - out = func(_self, tensordict, *args, **kwargs) + if _self is not None: + out = func(_self, tensordict, *args, **kwargs) + else: + out = func(tensordict, *args, **kwargs) + out = tuple(out[key] for key in dest) return out[0] if len(out) == 1 else out - return func(_self, tensordict, *args, **kwargs) + if _self is not None: + return func(_self, tensordict, *args, **kwargs) + return func(tensordict, *args, **kwargs) return wrapper diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py new file mode 100644 index 000000000..55308ab3b --- /dev/null +++ b/tensordict/nn/cudagraphs.py @@ -0,0 +1,362 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
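+#
+# Overview (informal): ``CudaGraphModule`` records a CUDA graph of the wrapped
+# callable after a warmup phase and then replays it against static input
+# buffers, removing the per-call CPU dispatch overhead. See the class
+# docstring below for the exact requirements placed on the wrapped function.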
+from __future__ import annotations + +import functools +import os +import warnings + +from textwrap import indent +from typing import Any, Callable, List + +import torch + +from tensordict._nestedkey import NestedKey +from tensordict.base import is_tensor_collection, TensorDictBase +from tensordict.nn.common import dispatch +from tensordict.nn.functional_modules import ( + _exclude_td_from_pytree, + PYTREE_REGISTERED_LAZY_TDS, + PYTREE_REGISTERED_TDS, +) +from tensordict.utils import strtobool +from torch import Tensor + +from torch.utils._pytree import SUPPORTED_NODES, tree_leaves, tree_map + + +class CudaGraphModule: + """A cudagraph wrapper for PyTorch callables. + + ``CudaGraphModule`` is a wrapper class that provides a user-friendly interface to CUDA graphs for PyTorch callables. + + .. warning:: ``CudaGraphModule`` is a prototype feature and its API restrictions are likely to change in the future. + + This class provides a user-friendly interface to cudagraphs, allowing for a fast, CPU-overhead free execution of + operations on GPU. + It runs essential checks for the inputs to the function, and gives an nn.Module-like API to run + + .. warning:: + This module requires the wrapped function to meet a few requirements. It is the user responsibility + to make sure that all of these are fullfilled. + + - The function cannot have dynamic control flow. For instance, the following code snippet will fail to be + wrapped in `CudaGraphModule`: + + >>> def func(x): + ... if x.norm() > 1: + ... return x + 1 + ... else: + ... return x - 1 + + Fortunately, PyTorch offers solutions in most cases: + + >>> def func(x): + ... return torch.where(x.norm() > 1, x + 1, x - 1) + + - The function must execute a code that can be exactly re-run using the same buffers. This means that + dynamic shapes (changing shape in the input or during the code execution) is not supported. In other words, + the input must have a constant shape. + - The output of the function must be detached. If a call to the optimizers is required, put it in the input + function. For instance, the following function is a valid operator: + + >>> def func(x, y): + ... optim.zero_grad() + ... loss_val = loss_fn(x, y) + ... loss_val.backward() + ... optim.step() + ... return loss_val.detach() + + - The input should not be differntiable. If you need to use `nn.Parameters` (or differentiable tensors in general), + just write a function that uses them as global values rather than passing them as input: + + >>> x = nn.Parameter(torch.randn(())) + >>> optim = Adam([x], lr=1) + >>> def func(): # right + ... optim.zero_grad() + ... (x+1).backward() + ... optim.step() + >>> def func(x): # wrong + ... optim.zero_grad() + ... (x+1).backward() + ... optim.step() + + - Args and kwargs that are tensors or tensordict may change (provided that device and shape match), but non-tensor + args and kwargs should not change. For instance, if the function receives a string input and the input is changed + at any point, the module will silently execute the code with the string used during the capture of the cudagraph. + The only supported keyword argument is `tensordict_out` in case the input is a tensordict. + + - If the module is a :class:`~tensordict.nn.TensorDictModuleBase` instance and the output id matches the input + id, then this identity will be preserved during a call to ``CudaGraphModule``. In all other cases, the output + will be cloned, irrespective of whether its elements match or do not match one of the inputs. + + .. 
warning:: ``CudaGraphModule`` is not an :class:`~torch.nn.Module` by design, to discourage gathering parameters + of the input module and passing them to an optimizer. + + Args: + module (Callable): a function that receives tensors (or tensordict) as input and outputs a + PyTreeable collection of tensors. If a tensordict is provided, the module can be run with keyword arguments + too (see example below). + warmup (int, optional): the number of warmup steps in case the module is compiled (compiled modules should be + run a couple of times before being captured by cudagraphs). Defaults to ``2`` for all modules. + in_keys (list of NestedKeys): the input keys, if the module takes a TensorDict as input. + Defaults to ``module.in_keys`` if this value exists, otherwise ``None``. + + .. note:: If ``in_keys`` is provided but empty, the module is assumed to receive a tensordict as input. + This is sufficient to make ``CudaGraphModule`` aware that the function should be treated as a + `TensorDictModule`, but keyword arguments will not be dispatched. See below for some examples. + + out_keys (list of NestedKeys): the output keys, if the module takes and outputs TensorDict as output. + Defaults to ``module.out_keys`` if this value exists, otherwise ``None``. + + Examples: + >>> # Wrap a simple function + >>> def func(x): + ... return x + 1 + >>> func = CudaGraphModule(func) + >>> x = torch.rand((), device='cuda') + >>> out = func(x) + >>> assert isinstance(out, torch.Tensor) + >>> assert out == x+1 + >>> # Wrap a tensordict module + >>> func = TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]) + >>> func = CudaGraphModule(func) + >>> # This can be called either with a TensorDict or regular keyword arguments alike + >>> y = func(x=x) + >>> td = TensorDict(x=x) + >>> td = func(td) + + + """ + + _REQUIRES_GRAD_ERROR = ( + "CudaGraphModule cannot be part of a graph (or leaves variables). If you need tensors to " + "require gradients, please pass them as constant inputs to the module instead (e.g. " + "`def func(a, b, c=c): ...` where c is the tensor requiring gradients)." + ) + + def __init__( + self, + module: Callable[[List[Tensor] | TensorDictBase], None], + warmup: int = 2, + in_keys: List[NestedKey] = None, + out_keys: List[NestedKey] = None, + ): + self._has_cuda = torch.cuda.is_available() + if not self._has_cuda: + warnings.warn( + "This module is instantiated on a machine without CUDA device available. " + "We permit this usage, but calls to this instance will just pass through it and execute " + "the provided function without additional performance gain." + ) + self.module = module + self.counter = 0 + if not isinstance(warmup, int) or warmup < 1: + raise ValueError("warmup must be an integer greater than 0.") + self._warmup = warmup + + if hasattr(module, "in_keys"): + self.in_keys = module.in_keys + else: + self.in_keys = in_keys + if hasattr(module, "out_keys"): + self.out_keys = module.out_keys + else: + if out_keys is None and self.in_keys is not None: + out_keys = [] + self.out_keys = out_keys + self._is_tensordict_module = self.in_keys is not None + self._out_matches_in = None + for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS: + if tdtype in SUPPORTED_NODES: + if not strtobool(os.environ.get("EXCLUDE_TD_FROM_PYTREE", "0")): + warnings.warn( + f"Tensordict is registered in PyTree. This is incompatible with {self.__class__.__name__}. " + f"Removing TDs from PyTree. 
To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() " + f"or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. " + f"This operation is irreversible." + ) + _exclude_td_from_pytree().set() + + if self._is_tensordict_module: + + @dispatch(source=self.in_keys, dest=self.out_keys, auto_batch_size=False) + def _call( + tensordict: TensorDictBase, + *args, + tensordict_out: TensorDictBase | None = None, + **kwargs: Any, + ) -> Any: + if self.counter < self._warmup: + if tensordict_out is not None: + kwargs["tensordict_out"] = tensordict_out + out = self.module(tensordict, *args, **kwargs) + if self._out_matches_in is None: + self._out_matches_in = out is tensordict + self.counter += self._has_cuda + return out + elif self.counter == self._warmup: + if tensordict.device is None: + tensordict.apply(self._check_device_and_grad, filter_empty=True) + elif tensordict.device.type != "cuda": + raise ValueError( + "The input tensordict device must be of the 'cuda' type." + ) + + tree_map(self._check_non_tensor, (args, kwargs)) + + self.graph = torch.cuda.CUDAGraph() + self._tensordict = tensordict.copy() + with torch.cuda.graph(self.graph): + if tensordict_out is not None: + kwargs["tensordict_out"] = tensordict_out + out = self.module(self._tensordict, *args, **kwargs) + self.graph.replay() + + if not is_tensor_collection(out) and out is not None: + raise RuntimeError( + "The output of the function must be a tensordict, a tensorclass or None. Got " + f"type(out)={type(out)}." + ) + self._out = out + self.counter += 1 + if self._out_matches_in: + # We need to know what keys to update in out + if self.out_keys: + self._selected_keys = self.out_keys + else: + # Gather keys that have changed during execution + self._selected_keys = [] + + def check_tensor_id(name, t0, t1): + if t0 is not t1: + self._selected_keys.append(name) + + self._out.named_apply( + check_tensor_id, + tensordict, + default=None, + filter_empty=True, + ) + return tensordict.update( + self._out, keys_to_update=self._selected_keys + ) + if tensordict_out is not None: + return tensordict_out.update(out, clone=True) + return out.clone() if self._out is not None else None + else: + self._tensordict.update_(tensordict) + self.graph.replay() + if self._out_matches_in: + return tensordict.update( + self._out, keys_to_update=self._selected_keys + ) + if tensordict_out is not None: + return tensordict_out.update(self._out, clone=True) + return self._out.clone() if self._out is not None else None + + else: + + def _call(*args: torch.Tensor, **kwargs: torch.Tensor): + if self.counter < self._warmup: + out = self.module(*args, **kwargs) + self.counter += self._has_cuda + return out + elif self.counter == self._warmup: + self.graph = torch.cuda.CUDAGraph() + + def check_device_and_clone(x): + if isinstance(x, torch.Tensor) or is_tensor_collection(x): + if x.requires_grad: + raise RuntimeError(self._REQUIRES_GRAD_ERROR) + if x.device is None: + # Check device of leaves of tensordict + x.apply(self._check_device_and_grad, filter_empty=True) + + elif x.device.type != "cuda": + raise ValueError( + f"All tensors must be stored on CUDA. Got {x.device.type}." 
+ ) + + return x.clone() + return x + + self._args, self._kwargs = tree_map( + check_device_and_clone, (args, kwargs) + ) + with torch.cuda.graph(self.graph): + out = self.module(*self._args, **self._kwargs) + self.graph.replay() + self._out = out + self.counter += 1 + # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn + # user. + tree_leaves_input = tree_leaves((args, kwargs)) + tree_leaves_output = tree_leaves(out) + if not isinstance(tree_leaves_output, tuple): + tree_leaves_output = (tree_leaves_output,) + for inp in tree_leaves_input: + if isinstance(inp, torch.Tensor) and inp in tree_leaves_output: + warnings.warn( + "An input tensor is also present in the output of the wrapped function. " + f"{type(self).__name__} will not copy tensors in place: the output will be cloned " + f"and the identity between input and output will not match anymore. " + f"Make sure you don't rely on input-output identity further in the code." + ) + if isinstance(self._out, torch.Tensor) or self._out is None: + self._return_unchanged = ( + "clone" if self._out is not None else True + ) + return ( + self._out.clone() + if self._return_unchanged == "clone" + else self._out + ) + self._return_unchanged = False + return tree_map(lambda x: x.clone(), out) + else: + tree_map( + lambda x, y: x.copy_(y), + (self._args, self._kwargs), + (args, kwargs), + ) + self.graph.replay() + if self._return_unchanged == "clone": + return self._out.clone() + elif self._return_unchanged: + return self._out + return tree_map( + lambda x: x.clone() if x is not None else x, self._out + ) + + _call_func = functools.wraps(self.module)(_call) + self._call_func = _call_func + + @classmethod + def _check_device_and_grad(cls, x): + if isinstance(x, torch.Tensor): + if x.device.type != "cuda": + raise ValueError( + f"All tensors must be stored on CUDA. Got {x.device.type}." + ) + if x.requires_grad: + raise RuntimeError(cls._REQUIRES_GRAD_ERROR) + + @staticmethod + def _check_non_tensor(arg): + if isinstance(arg, torch.Tensor): + raise ValueError( + "All tensors must be passed in the tensordict, not as arg or kwarg." 
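+                # Rationale (informal): replaying a captured graph reuses the
+                # buffers recorded at capture time; tensors passed as plain
+                # args/kwargs would not be copied into those buffers and would
+                # silently go stale between calls.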
+            )
+
+    def __call__(self, *args, **kwargs):
+        return self._call_func(*args, **kwargs)
+
+    def __repr__(self):
+        module = indent(f"module={self.module}", 4 * " ")
+        warmup = indent(f"warmup={self._warmup}", 4 * " ")
+        in_keys = indent(f"in_keys={self.in_keys}", 4 * " ")
+        out_keys = indent(f"out_keys={self.out_keys}", 4 * " ")
+        return f"{self.__class__.__name__}(\n{module}, \n{warmup}, \n{in_keys}, \n{out_keys}\n)"
diff --git a/tensordict/nn/functional_modules.py b/tensordict/nn/functional_modules.py
index 5d84e7914..4719fc9ed 100644
--- a/tensordict/nn/functional_modules.py
+++ b/tensordict/nn/functional_modules.py
@@ -6,6 +6,7 @@
 from __future__ import annotations

 import inspect
+import os
 import re
 import types
 import warnings
@@ -21,7 +22,7 @@
 from tensordict._td import TensorDict
 from tensordict.base import _is_tensor_collection, is_tensor_collection

-from tensordict.utils import implement_for
+from tensordict.utils import implement_for, strtobool
 from torch import nn
 from torch.utils._pytree import SUPPORTED_NODES

@@ -139,6 +140,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS:
             SUPPORTED_NODES[tdtype] = self.tdnodes[tdtype]

+    def set(self):
+        self.__enter__()
+
+    def unset(self):
+        self.__exit__(None, None, None)
+

 # Monkey-patch functorch, mainly for cases where a "isinstance(obj, Tensor) is invoked
 if _has_functorch:
@@ -605,3 +612,7 @@ def repopulate_module(model: nn.Module, tensordict: TensorDict) -> nn.Module:
     """Repopulates a module with its parameters, presented as a nested TensorDict."""
     _swap_state(model, tensordict, is_stateless=False)
     return model
+
+
+if strtobool(os.environ.get("EXCLUDE_TD_FROM_PYTREE", "0")):
+    _exclude_td_from_pytree().set()
diff --git a/test/test_compile.py b/test/test_compile.py
index 0ac8964d8..66b1ab667 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -31,7 +31,7 @@ def func(x, y):
     funcv_c(x, y)

-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
+@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTD:
     def test_tensor_output(self, mode):
@@ -318,7 +318,7 @@ class MyClass:
     c: Any = None

-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
+@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTC:
     def test_tc_tensor_output(self, mode):
@@ -557,7 +557,7 @@ def locked_op(tc):
     assert (tc_op == tc_op_c).all()

-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
+@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestNN:
     def test_func(self, mode):
diff --git a/test/test_nn.py b/test/test_nn.py
index 7523d0d35..63f0b7737 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -6,23 +6,27 @@
 import argparse
 import contextlib
 import copy
+import inspect
 import pickle
 import unittest
 import warnings
 import weakref
+from typing import Callable

 import pytest
 import torch
-
 from tensordict import (
     LazyStackedTensorDict,
     NonTensorData,
     NonTensorStack,
+    PYTREE_REGISTERED_LAZY_TDS,
+    PYTREE_REGISTERED_TDS,
     tensorclass,
     TensorDict,
 )
 from tensordict._C import unravel_key_list
 from tensordict.nn import (
+    CudaGraphModule,
     dispatch,
     probabilistic as nn_probabilistic,
     ProbabilisticTensorDictModule,
@@ -39,7 +43,11 @@
 )
 from tensordict.nn.distributions.composite import
CompositeDistribution from tensordict.nn.ensemble import EnsembleModule -from tensordict.nn.functional_modules import is_functional, make_functional +from tensordict.nn.functional_modules import ( + _exclude_td_from_pytree, + is_functional, + make_functional, +) from tensordict.nn.probabilistic import InteractionType, set_interaction_type from tensordict.nn.utils import ( _set_auto_make_functional, @@ -50,7 +58,7 @@ from torch import distributions, nn from torch.distributions import Normal -from torch.utils._pytree import tree_map +from torch.utils._pytree import SUPPORTED_NODES, tree_map try: import functorch # noqa @@ -71,6 +79,7 @@ except ImportError: from tensordict.utils import Buffer +TORCH_VERSION = torch.__version__ # Capture all warnings pytestmark = [ @@ -3818,6 +3827,180 @@ def __setattr__(self, key, value): assert (TensorDict.from_module(linear) == params).all() +@pytest.mark.skipif(TORCH_VERSION <= "2.4.1", reason="requires torch>=2.5") +@pytest.mark.parametrize("compiled", [False, True]) +class TestCudaGraphs: + @pytest.fixture(scope="class", autouse=True) + def _set_cuda_device(self): + device = torch.get_default_device() + do_unset = False + for tdtype in PYTREE_REGISTERED_TDS + PYTREE_REGISTERED_LAZY_TDS: + if tdtype in SUPPORTED_NODES: + do_unset = True + excluder = _exclude_td_from_pytree() + excluder.set() + break + if torch.cuda.is_available(): + torch.set_default_device("cuda:0") + yield + if do_unset: + excluder.unset() + torch.set_default_device(device) + + def test_cudagraphs_random(self, compiled): + def func(x): + return x + torch.randn_like(x) + + if compiled: + func = torch.compile(func) + + with ( + pytest.warns(UserWarning) + if not torch.cuda.is_available() + else contextlib.nullcontext() + ): + func = CudaGraphModule(func) + + x = torch.randn(10) + for _ in range(10): + func(x) + assert isinstance(func(torch.zeros(10)), torch.Tensor) + assert (func(torch.zeros(10)) != 0).any() + y0 = func(x) + y1 = func(x + 1) + with pytest.raises(AssertionError): + torch.testing.assert_close(y0, y1 + 1) + + @staticmethod + def _make_cudagraph( + func: Callable, compiled: bool, *args, **kwargs + ) -> CudaGraphModule: + if compiled: + func = torch.compile(func) + with ( + pytest.warns(UserWarning) + if not torch.cuda.is_available() + else contextlib.nullcontext() + ): + func = CudaGraphModule(func, *args, **kwargs) + return func + + @staticmethod + def check_types(func, *args, **kwargs): + signature = inspect.signature(func) + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + for param_name, param in signature.parameters.items(): + arg_value = bound_args.arguments[param_name] + if param.annotation != param.empty: + if not isinstance(arg_value, param.annotation): + raise TypeError( + f"Argument '{param_name}' should be of type {param.annotation}, but is of type {type(arg_value)}" + ) + + def test_signature(self, compiled): + if compiled: + pytest.skip() + + def func(x: torch.Tensor): + return x + torch.randn_like(x) + + with pytest.raises(TypeError): + self.check_types(func, "a string") + self.check_types(func, torch.ones(())) + + def test_backprop(self, compiled): + x = torch.nn.Parameter(torch.ones(3)) + y = torch.nn.Parameter(torch.ones(3)) + optimizer = torch.optim.SGD([x, y], lr=1) + + def func(): + optimizer.zero_grad() + z = x + y + z = z.sum() + z.backward() + optimizer.step() + + func = self._make_cudagraph(func, compiled, warmup=4) + + for i in range(1, 11): + torch.compiler.cudagraph_mark_step_begin() + func() + + assert (x == 1 - 
i).all(), i + assert (y == 1 - i).all(), i + # assert (x.grad == 1).all() + # assert (y.grad == 1).all() + + def test_tdmodule(self, compiled): + tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]) + tdmodule = self._make_cudagraph(tdmodule, compiled) + assert tdmodule._is_tensordict_module + for i in range(10): + td = TensorDict(x=torch.randn(())) + tdmodule(td) + assert td["y"] == td["x"] + 1, i + + tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]) + tdmodule = self._make_cudagraph(tdmodule, compiled) + assert tdmodule._is_tensordict_module + for _ in range(10): + x = torch.randn(()) + y = tdmodule(x=x) + assert y == x + 1 + + tdmodule = TensorDictModule(lambda x: x + 1, in_keys=["x"], out_keys=["y"]) + tdmodule = self._make_cudagraph(tdmodule, compiled) + assert tdmodule._is_tensordict_module + for _ in range(10): + td = TensorDict(x=torch.randn(())) + tdout = TensorDict() + tdmodule(td, tensordict_out=tdout) + assert tdout is not td + assert "x" not in tdout + assert tdout["y"] == td["x"] + 1 + + tdmodule = lambda td: td.set("y", td.get("x") + 1) + tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[]) + assert tdmodule._is_tensordict_module + for i in range(10): + td = TensorDict(x=torch.randn(())) + tdmodule(td) + assert tdmodule._out_matches_in + if i >= tdmodule._warmup and torch.cuda.is_available(): + assert tdmodule._selected_keys == ["y"] + assert td["y"] == td["x"] + 1 + + tdmodule = lambda td: td.set("y", td.get("x") + 1) + tdmodule = self._make_cudagraph( + tdmodule, compiled, in_keys=["x"], out_keys=["y"] + ) + assert tdmodule._is_tensordict_module + for _ in range(10): + td = TensorDict(x=torch.randn(())) + tdmodule(td) + assert td["y"] == td["x"] + 1 + + tdmodule = lambda td: td.copy().set("y", td.get("x") + 1) + tdmodule = self._make_cudagraph(tdmodule, compiled, in_keys=[], out_keys=[]) + assert tdmodule._is_tensordict_module + for _ in range(10): + td = TensorDict(x=torch.randn(())) + tdout = tdmodule(td) + assert tdout is not td + assert "y" not in td + assert tdout["y"] == td["x"] + 1 + + def test_td_input_non_tdmodule(self, compiled): + func = lambda x: x + 1 + func = self._make_cudagraph(func, compiled) + for i in range(10): + td = TensorDict(a=1) + func(td) + if i == 5: + assert not func._is_tensordict_module + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 48d52d20d149a290e50e7af81b5c09a7a36db13e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 15:45:27 -0700 Subject: [PATCH 10/11] [Feature] Propagate `existsok` in memmap* methods ghstack-source-id: 6dcab0ff5e2ae2bb9b8d3bbf18cfb524c51d144d Pull Request resolved: https://github.com/pytorch/tensordict/pull/990 --- tensordict/_lazy.py | 4 ++++ tensordict/_td.py | 20 ++++++++++++++++---- tensordict/base.py | 16 ++++++++++++++++ tensordict/memmap.py | 26 ++++++++++++++------------ tensordict/persistent.py | 4 +++- tensordict/tensorclass.py | 6 ++++++ 6 files changed, 59 insertions(+), 17 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index c5f79f337..bc12dc015 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -2458,6 +2458,7 @@ def _memmap_( inplace=True, like=False, share_non_tensor, + existsok, ) -> T: if prefix is not None: prefix = Path(prefix) @@ -2489,6 +2490,7 @@ def save_metadata(prefix=prefix, self=self): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + 
existsok=existsok, ) ) if not inplace: @@ -3526,6 +3528,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: def save_metadata(data: TensorDictBase, filepath, metadata=None): if metadata is None: @@ -3558,6 +3561,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not inplace: dest = type(self)( diff --git a/tensordict/_td.py b/tensordict/_td.py index 0d9cff420..72e3efbe3 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -2535,6 +2535,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if prefix is not None: @@ -2569,6 +2570,7 @@ def _memmap_( inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if prefix is not None: _update_metadata( @@ -2585,6 +2587,7 @@ def _memmap_( copy_existing=copy_existing, prefix=prefix, like=like, + existsok=existsok, ) else: futures.append( @@ -2596,6 +2599,7 @@ def _memmap_( copy_existing=copy_existing, prefix=prefix, like=like, + existsok=existsok, ) ) if prefix is not None: @@ -2847,7 +2851,12 @@ def make_memmap_from_storage( return memmap_tensor def make_memmap_from_tensor( - self, key: NestedKey, tensor: torch.Tensor, *, copy_data: bool = True + self, + key: NestedKey, + tensor: torch.Tensor, + *, + copy_data: bool = True, + existsok: bool = True, ) -> MemoryMappedTensor: if not self.is_memmap(): raise RuntimeError( @@ -2876,6 +2885,7 @@ def make_memmap_from_tensor( copy_existing=True, prefix=last_node._memmap_prefix, like=not copy_data, + existsok=existsok, ) _update_metadata( metadata=metadata, @@ -3906,6 +3916,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if prefix is not None: @@ -3936,6 +3947,7 @@ def save_metadata(prefix=prefix, self=self): inplace=inplace, like=like, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not inplace: result = _SubTensorDict(_source, idx=self.idx) @@ -4404,7 +4416,7 @@ def _save_metadata(data: TensorDictBase, prefix: Path, metadata=None): # user did specify location and memmap is in wrong place, so we copy -def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): +def _populate_memmap(*, dest, value, key, copy_existing, prefix, like, existsok): filename = None if prefix is None else str(prefix / f"{key}.memmap") if value.is_nested: shape = value._nested_tensor_size() @@ -4416,7 +4428,7 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): shape, filename=shape_filename, copy_existing=copy_existing, - existsok=True, + existsok=existsok, copy_data=True, ) else: @@ -4425,9 +4437,9 @@ def _populate_memmap(*, dest, value, key, copy_existing, prefix, like): value.data if value.requires_grad else value, filename=filename, copy_existing=copy_existing, - existsok=True, copy_data=not like, shape=shape, + existsok=existsok, ) dest._tensordict[key] = memmap_tensor return memmap_tensor diff --git a/tensordict/base.py b/tensordict/base.py index 51d5180a7..3b326caaa 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3211,6 +3211,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: ... def densify(self, layout: torch.layout = torch.strided): @@ -3743,6 +3744,7 @@ def memmap_( num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: """Writes all tensors onto a corresponding memory-mapped Tensor, in-place. @@ -3767,6 +3769,8 @@ def memmap_( on all other workers. 
If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -3799,6 +3803,7 @@ def memmap_( inplace=True, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -3813,6 +3818,7 @@ def memmap_( executor=None, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() @abc.abstractmethod @@ -3935,6 +3941,7 @@ def memmap( num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: """Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict. @@ -3958,6 +3965,8 @@ def memmap( on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -3992,6 +4001,7 @@ def memmap( inplace=False, like=False, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -4007,6 +4017,7 @@ def memmap( like=False, futures=None, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() def memmap_like( @@ -4014,6 +4025,7 @@ def memmap_like( prefix: str | None = None, copy_existing: bool = False, *, + existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, @@ -4040,6 +4052,8 @@ def memmap_like( on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data) this may result in OOM or similar errors. Defaults to ``False``. + existsok (bool, optional): if ``False``, an exception will be raised if a tensor already + exists in the same path. Defaults to ``True``. The TensorDict is then locked, meaning that any writing operations that isn't in-place will throw an exception (eg, rename, set or remove an @@ -4089,6 +4103,7 @@ def memmap_like( inplace=False, like=True, share_non_tensor=share_non_tensor, + existsok=existsok, ) if not return_early: concurrent.futures.wait(futures) @@ -4106,6 +4121,7 @@ def memmap_like( executor=None, futures=None, share_non_tensor=share_non_tensor, + existsok=existsok, ).lock_() @classmethod diff --git a/tensordict/memmap.py b/tensordict/memmap.py index f256aba81..9ffe14d46 100644 --- a/tensordict/memmap.py +++ b/tensordict/memmap.py @@ -134,12 +134,12 @@ def from_tensor( cls, input, *, - filename=None, - existsok=False, - copy_existing=False, - copy_data=True, - shape=None, - ): + filename: Path | str = None, + existsok: bool = False, + copy_existing: bool = False, + copy_data: bool = True, + shape: torch.Size | None = None, + ): # noqa: D417 """Creates a MemoryMappedTensor with the same content as another tensor. 
If the tensor is already a MemoryMappedTensor the original tensor is @@ -149,6 +149,8 @@ def from_tensor( Args: input (torch.Tensor): the tensor which content must be copied onto the MemoryMappedTensor. + + Keyword Args: filename (path to a file): the path to the file where the tensor should be stored. If none is provided, a file handler is used instead. @@ -280,12 +282,12 @@ def from_storage( cls, storage, *, - shape=None, - dtype=None, - device=None, - index=None, - filename=None, - handler=None, + shape: torch.Size | None = None, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + index: IndexType | None = None, + filename: Path | str = None, + handler: _handler = None, ): if getattr(storage, "filename", None) is not None: if filename is None: diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 2e2c37710..4da8c118d 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -701,6 +701,7 @@ def _memmap_( inplace, like, share_non_tensor, + existsok, ) -> T: if inplace: raise RuntimeError("Cannot call memmap inplace in a persistent tensordict.") @@ -749,6 +750,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None): futures=futures, inplace=inplace, share_non_tensor=share_non_tensor, + existsok=existsok, ), inplace=False, validated=True, @@ -776,7 +778,7 @@ def _populate( ), copy_data=not like, copy_existing=copy_existing, - existsok=True, + existsok=existsok, ) tensordict._set_str( key, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index aaccb8869..0f8e72cb5 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -953,6 +953,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ): _non_tensordict = dict(self._non_tensordict) cls = type(self) @@ -997,6 +998,7 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): like=like, copy_existing=copy_existing, share_non_tensor=share_non_tensor, + existsok=existsok, ) if new_futures: futures += new_futures @@ -2816,6 +2818,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ): # For efficiency, we can avoid doing this saving # if the data is already there. 
@@ -2842,6 +2845,7 @@ def _memmap_( like=like, memmaped=memmaped, share_non_tensor=share_non_tensor, + existsok=existsok, ) _metadata["_share_non_tensor"] = share_non_tensor out._non_tensordict["_metadata"] = _metadata @@ -2967,6 +2971,7 @@ def _memmap_( like=False, memmaped: bool = False, share_non_tensor: bool = False, + existsok: bool = True, ) -> T: memmaped_leaves = memmaped @@ -3013,6 +3018,7 @@ def save_metadata(prefix=prefix, self=self): # no memmapping should be executed memmaped=memmaped_leaves, share_non_tensor=share_non_tensor, + existsok=existsok, ) ) if not inplace: From 257c8599a03186881cdf8380c8bfb1394700d035 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Sep 2024 15:56:14 -0700 Subject: [PATCH 11/11] [Feature] sorted keys, values and items ghstack-source-id: af5b112bc5a3dbabf5c490026f883bac89a6a052 Pull Request resolved: https://github.com/pytorch/tensordict/pull/965 --- tensordict/_lazy.py | 40 +++++++-- tensordict/_td.py | 149 +++++++++++++++++--------------- tensordict/base.py | 177 +++++++++++++++++++++++++-------------- tensordict/nn/params.py | 8 +- tensordict/persistent.py | 3 + test/test_tensordict.py | 31 +++++++ 6 files changed, 271 insertions(+), 137 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index bc12dc015..a4969ce03 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -1550,22 +1550,35 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _LazyStackedTensorDictKeysView: keys = _LazyStackedTensorDictKeysView( self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) return keys - def values(self, include_nested=False, leaves_only=False, is_leaf=None): + def values( + self, + include_nested=False, + leaves_only=False, + is_leaf=None, + *, + sort: bool = False, + ): if is_leaf not in ( _NESTED_TENSORS_AS_LISTS, _NESTED_TENSORS_AS_LISTS_NONTENSOR, ): yield from super().values( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) else: for td in self.tensordicts: @@ -1573,15 +1586,26 @@ def values(self, include_nested=False, leaves_only=False, is_leaf=None): include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) - def items(self, include_nested=False, leaves_only=False, is_leaf=None): + def items( + self, + include_nested=False, + leaves_only=False, + is_leaf=None, + *, + sort: bool = False, + ): if is_leaf not in ( _NESTED_TENSORS_AS_LISTS, _NESTED_TENSORS_AS_LISTS_NONTENSOR, ): yield from super().items( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) else: for i, td in enumerate(self.tensordicts): @@ -1589,6 +1613,7 @@ def items(self, include_nested=False, leaves_only=False, is_leaf=None): include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ): if isinstance(key, str): key = (str(i), key) @@ -3383,9 +3408,14 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: return self._source.keys( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) def _select( diff --git 
a/tensordict/_td.py b/tensordict/_td.py index 72e3efbe3..376c93b97 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -3154,12 +3154,23 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: if not include_nested and not leaves_only and is_leaf is None: - return _StringKeys(self._tensordict.keys()) + if not sort: + return _StringKeys(self._tensordict.keys()) + else: + return sorted( + _StringKeys(self._tensordict.keys()), + key=lambda x: ".".join(x) if isinstance(x, tuple) else x, + ) else: return self._nested_keys( - include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + sort=sort, ) @cache # noqa: B019 @@ -3168,12 +3179,15 @@ def _nested_keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _TensorDictKeysView: return _TensorDictKeysView( self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) # some custom methods for efficiency @@ -3182,65 +3196,44 @@ def items( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[tuple[str, CompatibleType]]: if not include_nested and not leaves_only: - return self._tensordict.items() - elif include_nested and leaves_only: + if not sort: + return self._tensordict.items() + return sorted(self._tensordict.items(), key=lambda x: x[0]) + elif include_nested and leaves_only and not sort: is_leaf = _default_is_leaf if is_leaf is None else is_leaf result = [] - if is_dynamo_compiling(): - - def fast_iter(): - for key, val in self._tensordict.items(): - if not is_leaf(type(val)): - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ): - result.append( - ( - ( - key, - *( - (_key,) - if isinstance(_key, str) - else _key - ), - ), - _val, - ) - ) - else: - result.append((key, val)) - return result - else: - # dynamo doesn't like generators - def fast_iter(): - for key, val in self._tensordict.items(): - if not is_leaf(type(val)): - yield from ( - ( - ( - key, - *((_key,) if isinstance(_key, str) else _key), - ), - _val, - ) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, - ) - ) - else: - yield (key, val) + def fast_iter(): + for key, val in self._tensordict.items(): + # We could easily make this faster, here we're iterating twice over the keys, + # but we could iterate just once. + # Ideally we should make a "dirty" list of items then call unravel_key on all of them. 
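+                    # Recurse into sub-tensordicts and prefix each nested key with its
+                    # parent key so that callers always receive fully-qualified key tuples.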
+                    if not is_leaf(type(val)):
+                        for _key, _val in val.items(
+                            include_nested=include_nested,
+                            leaves_only=leaves_only,
+                            is_leaf=is_leaf,
+                        ):
+                            if isinstance(_key, str):
+                                _key = (key, _key)
+                            else:
+                                _key = (key, *_key)
+                            result.append((_key, _val))
+                    else:
+                        result.append((key, val))
+                return result

             return fast_iter()
         else:
             return super().items(
-                include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
+                include_nested=include_nested,
+                leaves_only=leaves_only,
+                is_leaf=is_leaf,
+                sort=sort,
             )

     def values(
@@ -3248,15 +3241,23 @@ def values(
         include_nested: bool = False,
         leaves_only: bool = False,
         is_leaf: Callable[[Type], bool] | None = None,
+        *,
+        sort: bool = False,
    ) -> Iterator[tuple[str, CompatibleType]]:
         if not include_nested and not leaves_only:
-            return self._tensordict.values()
+            if not sort:
+                return self._tensordict.values()
+            else:
+                # sort items by key and keep the values only; a comprehension is used
+                # so that an empty tensordict yields an empty list instead of raising
+                # an IndexError
+                return [
+                    value
+                    for _, value in sorted(
+                        self._tensordict.items(), key=lambda x: x[0]
+                    )
+                ]
         else:
             return TensorDictBase.values(
                 self,
                 include_nested=include_nested,
                 leaves_only=leaves_only,
                 is_leaf=is_leaf,
+                sort=sort,
             )

@@ -3545,9 +3546,14 @@ def keys(
         include_nested: bool = False,
         leaves_only: bool = False,
         is_leaf: Callable[[Type], bool] | None = None,
+        *,
+        sort: bool = False,
     ) -> _TensorDictKeysView:
         return self._source.keys(
-            include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
+            include_nested=include_nested,
+            leaves_only=leaves_only,
+            is_leaf=is_leaf,
+            sort=sort,
         )

     def entry_class(self, key: NestedKey) -> type:
@@ -4184,6 +4190,7 @@ def __init__(
         include_nested: bool,
         leaves_only: bool,
         is_leaf: Callable[[Type], bool] = None,
+        sort: bool = False,
     ) -> None:
         self.tensordict = tensordict
         self.include_nested = include_nested
@@ -4191,22 +4198,32 @@ def __init__(
         if is_leaf is None:
             is_leaf = _default_is_leaf
         self.is_leaf = is_leaf
+        self.sort = sort

     def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
-        if not self.include_nested:
-            if self.leaves_only:
-                for key in self._keys():
-                    target_class = self.tensordict.entry_class(key)
-                    if _is_tensor_collection(target_class):
-                        continue
-                    yield key
+        def _iter():
+            if not self.include_nested:
+                if self.leaves_only:
+                    for key in self._keys():
+                        target_class = self.tensordict.entry_class(key)
+                        if _is_tensor_collection(target_class):
+                            continue
+                        yield key
+                else:
+                    yield from self._keys()
             else:
-                yield from self._keys()
-        else:
-            yield from (
-                key if len(key) > 1 else key[0]
-                for key in self._iter_helper(self.tensordict)
+                yield from (
+                    key if len(key) > 1 else key[0]
+                    for key in self._iter_helper(self.tensordict)
+                )
+
+        if self.sort:
+            yield from sorted(
+                _iter(),
+                key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
             )
+        else:
+            yield from _iter()

     def _iter_helper(
         self, tensordict: T, prefix: str | None = None
diff --git a/tensordict/base.py b/tensordict/base.py
index 3b326caaa..918878227 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -5151,8 +5151,13 @@ def setdefault(
         return self.get(key)

     def items(
-        self, include_nested: bool = False, leaves_only: bool = False, is_leaf=None
-    ) -> Iterator[tuple[str, CompatibleType]]:
+        self,
+        include_nested: bool = False,
+        leaves_only: bool = False,
+        is_leaf=None,
+        *,
+        sort: bool = False,
+    ) -> Iterator[tuple[str, CompatibleType]]:  # noqa: D417
         """Returns a generator of key-value pairs for the tensordict.

         Args:
@@ -5163,46 +5168,64 @@ def items(
             is_leaf: an optional callable that indicates if a class is to be considered a
                 leaf or not.
+ Keyword Args: + sort (bool, optional): whether the keys should be sorted. For nested keys, + the keys are sorted according to their joined name (ie, ``("a", "key")`` will + be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur + significant overhead when dealing with large tensordicts. + Defaults to ``False``. + """ if is_leaf is None: is_leaf = _default_is_leaf - # check the conditions once only - if include_nested and leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, + def _items(): + if include_nested and leaves_only: + # check the conditions once only + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if not is_leaf(type(val)): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - else: + else: + yield k, val + elif include_nested: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) yield k, val - elif include_nested: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - yield k, val - if not is_leaf(type(val)): - yield from ( - (_unravel_key_to_tuple((k, _key)), _val) - for _key, _val in val.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=is_leaf, + if not is_leaf(type(val)): + yield from ( + (_unravel_key_to_tuple((k, _key)), _val) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - elif leaves_only: - for k in self.keys(): - val = self._get_str(k, NO_DEFAULT) - if is_leaf(type(val)): - yield k, val + elif leaves_only: + for k in self.keys(): + val = self._get_str(k, NO_DEFAULT) + if is_leaf(type(val)): + yield k, val + else: + for k in self.keys(): + yield k, self._get_str(k, NO_DEFAULT) + + if sort: + yield from sorted( + _items(), + key=lambda item: ( + item[0] if isinstance(item[0], str) else ".".join(item[0]) + ), + ) else: - for k in self.keys(): - yield k, self._get_str(k, NO_DEFAULT) + yield from _items() def non_tensor_items(self, include_nested: bool = False): """Returns all non-tensor leaves, maybe recursively.""" @@ -5219,7 +5242,9 @@ def values( include_nested: bool = False, leaves_only: bool = False, is_leaf=None, - ) -> Iterator[CompatibleType]: + *, + sort: bool = False, + ) -> Iterator[CompatibleType]: # noqa: D417 """Returns a generator representing the values for the tensordict. Args: @@ -5230,39 +5255,54 @@ def values( is_leaf: an optional callable that indicates if a class is to be considered a leaf or not. + Keyword Args: + sort (bool, optional): whether the keys should be sorted. For nested keys, + the keys are sorted according to their joined name (ie, ``("a", "key")`` will + be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur + significant overhead when dealing with large tensordicts. + Defaults to ``False``. 
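+
+        Examples:
+            >>> # nested keys are compared by their joined names ("a.c" < "a.d" < "b")
+            >>> td = TensorDict({"b": 0, "a": {"d": 1, "c": 2}})
+            >>> list(td.values(include_nested=True, leaves_only=True, sort=True))
+            [tensor(2), tensor(1), tensor(0)]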
+
         """
         if is_leaf is None:
             is_leaf = _default_is_leaf
-        # check the conditions once only
-        if include_nested and leaves_only:
-            for k in self.keys():
-                val = self._get_str(k, NO_DEFAULT)
-                if not is_leaf(type(val)):
-                    yield from val.values(
-                        include_nested=include_nested,
-                        leaves_only=leaves_only,
-                        is_leaf=is_leaf,
-                    )
-                else:
-                    yield val
-        elif include_nested:
-            for k in self.keys():
-                val = self._get_str(k, NO_DEFAULT)
-                yield val
-                if not is_leaf(type(val)):
-                    yield from val.values(
-                        include_nested=include_nested,
-                        leaves_only=leaves_only,
-                        is_leaf=is_leaf,
-                    )
-        elif leaves_only:
-            for k in self.keys():
-                val = self._get_str(k, NO_DEFAULT)
-                if is_leaf(type(val)):
+
+        def _values():
+            # check the conditions once only
+            if include_nested and leaves_only:
+                for k in self.keys():
+                    val = self._get_str(k, NO_DEFAULT)
+                    if not is_leaf(type(val)):
+                        yield from val.values(
+                            include_nested=include_nested,
+                            leaves_only=leaves_only,
+                            is_leaf=is_leaf,
+                        )
+                    else:
+                        yield val
+            elif include_nested:
+                for k in self.keys():
+                    val = self._get_str(k, NO_DEFAULT)
                     yield val
+                    if not is_leaf(type(val)):
+                        yield from val.values(
+                            include_nested=include_nested,
+                            leaves_only=leaves_only,
+                            is_leaf=is_leaf,
+                        )
+            elif leaves_only:
+                for k in self.keys(sort=sort):
+                    val = self._get_str(k, NO_DEFAULT)
+                    if is_leaf(type(val)):
+                        yield val
+            else:
+                for k in self.keys(sort=sort):
+                    yield self._get_str(k, NO_DEFAULT)
+
+        if not sort or not include_nested:
+            yield from _values()
         else:
-            for k in self.keys():
-                yield self._get_str(k, NO_DEFAULT)
+            for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
+                yield value

     @cache  # noqa: B019
     def _values_list(
@@ -5360,7 +5400,9 @@ def keys(
         self,
         include_nested: bool = False,
         leaves_only: bool = False,
-        is_leaf: Callable[[Type], bool] = None,
+        is_leaf: Callable[[Type], bool] | None = None,
+        *,
+        sort: bool = False,
     ):
         """Returns a generator of tensordict keys.

@@ -5372,6 +5414,13 @@ def keys(
             is_leaf: an optional callable that indicates if a class is to be considered a
                 leaf or not.

+        Keyword Args:
+            sort (bool, optional): whether the keys should be sorted. For nested keys,
+                the keys are sorted according to their joined name (ie, ``("a", "key")`` will
+                be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur
+                significant overhead when dealing with large tensordicts.
+                Defaults to ``False``.
+ Examples: >>> from tensordict import TensorDict >>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[]) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 66e7153a9..2b778815f 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1014,10 +1014,12 @@ def values( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[CompatibleType]: if is_leaf is None: is_leaf = _default_is_leaf - for v in self._param_td.values(include_nested, leaves_only): + for v in self._param_td.values(include_nested, leaves_only, sort=sort): if not is_leaf(type(v)): yield v continue @@ -1082,10 +1084,12 @@ def items( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> Iterator[CompatibleType]: if is_leaf is None: is_leaf = _default_is_leaf - for k, v in self._param_td.items(include_nested, leaves_only): + for k, v in self._param_td.items(include_nested, leaves_only, sort=sort): if not is_leaf(type(v)): yield k, v continue diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 4da8c118d..54b0375a5 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -471,6 +471,8 @@ def keys( include_nested: bool = False, leaves_only: bool = False, is_leaf: Callable[[Type], bool] | None = None, + *, + sort: bool = False, ) -> _PersistentTDKeysView: if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor): raise ValueError( @@ -481,6 +483,7 @@ def keys( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + sort=sort, ) def _items_metadata(self, include_nested=False, leaves_only=False): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 583c592cf..5ff085150 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2239,6 +2239,37 @@ def assert_shared(td0): td0 = td0.squeeze(0) assert_shared(td0) + def test_sorted_keys(self): + td = TensorDict( + { + "a": {"b": 0, "c": 1}, + "d": 2, + "e": {"f": 3, "g": {"h": 4, "i": 5}, "j": 6}, + } + ) + tdflat = td.flatten_keys() + tdflat["d"] = tdflat.pop("d") + tdflat["a.b"] = tdflat.pop("a.b") + for key1, key2 in zip( + td.keys(True, True, sort=True), tdflat.keys(True, True, sort=True) + ): + if isinstance(key1, str): + assert key1 == key2 + else: + assert ".".join(key1) == key2 + for v1, v2 in zip( + td.values(True, True, sort=True), tdflat.values(True, True, sort=True) + ): + assert v1 == v2 + for (k1, v1), (k2, v2) in zip( + td.items(True, True, sort=True), tdflat.items(True, True, sort=True) + ): + if isinstance(k1, str): + assert k1 == k2 + else: + assert ".".join(k1) == k2 + assert v1 == v2 + def test_split_keys(self): td = TensorDict( {