From 4786d546edd0feef56ab33124cfeb42259f1bcda Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 14 Oct 2024 16:55:36 +0100
Subject: [PATCH] [Performance] Faster clone

ghstack-source-id: 6eecbac5e946a4d93d3d6e148e8c18aaa2501b00
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1040
---
 tensordict/_td.py           |  3 +++
 tensordict/base.py          | 17 +++++++++++++++++
 tensordict/nn/cudagraphs.py | 17 ++++++-----------
 3 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index f56e25052..2843c07b6 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None and self.device.type == "cuda":
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
diff --git a/tensordict/base.py b/tensordict/base.py
index f74fd5891..bb8eefe30 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -8156,6 +8156,23 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        vals = torch._foreach_add(vals, 0)
+        items = dict(zip(keys, vals))
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index e99236b48..f1980a57a 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -302,14 +302,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         torch._foreach_copy_(dests, srcs)
                     torch.cuda.synchronize()
                     self.graph.replay()
-                    if self._return_unchanged == "clone":
-                        result = self._out.clone()
-                    elif self._return_unchanged:
+                    if self._return_unchanged:
                         result = self._out
                     else:
-                        result = tree_map(
-                            lambda x: x.detach().clone() if x is not None else x,
-                            self._out,
+                        result = tree_unflatten(
+                            torch._foreach_add(self._out, 0.0), self._out_struct
                         )
                     return result
 
@@ -340,7 +337,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                     self.graph = torch.cuda.CUDAGraph()
                     with torch.cuda.graph(self.graph):
                         out = self.module(*self._args, **self._kwargs)
-                    self._out = out
+                    self._out, self._out_struct = tree_flatten(out)
                     self.counter += 1
                     # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
                     # user.
@@ -356,10 +353,8 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                             f"and the identity between input and output will not match anymore. "
                             f"Make sure you don't rely on input-output identity further in the code."
                         )
-                    if isinstance(self._out, torch.Tensor) or self._out is None:
-                        self._return_unchanged = (
-                            "clone" if self._out is not None else True
-                        )
+                    if not self._out:
+                        self._return_unchanged = True
                     else:
                         self._return_unchanged = False
                     return this_out
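
The sketch below is not part of the patch; it illustrates, assuming a CUDA device is available, the batched-copy idea behind _clone_recurse: all leaves are copied with a single torch._foreach_add(values, 0) call (which allocates fresh output tensors, so it acts as an out-of-place copy) instead of one clone() per leaf. The key names, shapes, and tensordict layout are invented for the example.

# Illustrative sketch only -- not part of the patch above. Key names and shapes
# are made up; the comparison mirrors the idea behind _clone_recurse.
import torch
from tensordict import TensorDict

if torch.cuda.is_available():
    td = TensorDict(
        {f"key_{i}": torch.randn(64, 64, device="cuda") for i in range(256)},
        batch_size=[64],
    )

    # Per-leaf copy: one clone() kernel launch per tensor (the generic path).
    per_leaf = {key: value.clone() for key, value in td.items()}

    # Batched copy: a single foreach op over all leaves, as _clone_recurse does.
    keys, values = zip(*td.items())
    batched = dict(zip(keys, torch._foreach_add(list(values), 0)))

    # With the patch, clone() takes the batched path automatically on CUDA tensordicts.
    td_clone = td.clone()
    assert (td_clone["key_0"] == td["key_0"]).all()

Adding 0 via _foreach_add is used purely as a batched out-of-place copy; trading a per-leaf Python loop of clone() calls for a single foreach dispatch appears to be where the CUDA speedup advertised in the title comes from.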