diff --git a/tensordict/_td.py b/tensordict/_td.py
index f56e25052..2843c07b6 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None and self.device.type == "cuda":
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
diff --git a/tensordict/base.py b/tensordict/base.py
index f74fd5891..bb8eefe30 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -8156,6 +8156,23 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        vals = torch._foreach_add(vals, 0)
+        items = dict(zip(keys, vals))
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=True,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
diff --git a/tensordict/nn/cudagraphs.py b/tensordict/nn/cudagraphs.py
index e99236b48..e915953b0 100644
--- a/tensordict/nn/cudagraphs.py
+++ b/tensordict/nn/cudagraphs.py
@@ -267,6 +267,8 @@ def _call(
                     "The output of the function must be a tensordict, a tensorclass or None. Got "
                     f"type(out)={type(out)}."
                 )
+            if is_tensor_collection(out):
+                out.lock_()
             self._out = out
             self.counter += 1
             if self._out_matches_in:
@@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 torch._foreach_copy_(dests, srcs)
                 torch.cuda.synchronize()
                 self.graph.replay()
-                if self._return_unchanged == "clone":
-                    result = self._out.clone()
-                elif self._return_unchanged:
+                if self._return_unchanged:
                     result = self._out
                 else:
-                    result = tree_map(
-                        lambda x: x.detach().clone() if x is not None else x,
-                        self._out,
+                    result = tree_unflatten(
+                        [out.clone() for out in self._out], self._out_struct
                     )
                 return result
 
@@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                 self.graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.graph):
                     out = self.module(*self._args, **self._kwargs)
-                self._out = out
+                self._out, self._out_struct = tree_flatten(out)
                 self.counter += 1
                 # Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
                 # user.
@@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                         f"and the identity between input and output will not match anymore. "
                        f"Make sure you don't rely on input-output identity further in the code."
                    )
-                if isinstance(self._out, torch.Tensor) or self._out is None:
-                    self._return_unchanged = (
-                        "clone" if self._out is not None else True
-                    )
+                if not self._out:
+                    self._return_unchanged = True
                 else:
+                    self._out = [out.lock_() if is_tensor_collection(out) else out for out in self._out]
                     self._return_unchanged = False
             return this_out
 
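
Context on the `_clone_recurse` addition in `tensordict/base.py`: the speedup comes from `torch._foreach_add(vals, 0)`, which copies every leaf into fresh storage through a few fused foreach kernel launches instead of one `clone()` launch per tensor. A minimal standalone sketch of that trick (variable names are illustrative; float tensors assumed):

```python
import torch

# Sketch of the foreach-based clone behind _clone_recurse: adding the
# scalar 0 out-of-place writes each value into newly allocated storage,
# so the result matches [v.clone() for v in vals] while batching the
# kernel launches across the whole list.
vals = [torch.randn(3), torch.arange(4, dtype=torch.float32)]
cloned = torch._foreach_add(vals, 0)

assert all(c.equal(v) for c, v in zip(cloned, vals))  # same values
assert all(c.data_ptr() != v.data_ptr() for c, v in zip(cloned, vals))  # new storage
```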
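
Similarly, for the `cudagraphs.py` change: flattening the captured output once with `torch.utils._pytree.tree_flatten` lets every replay clone a plain list of leaves and rebuild the structure with `tree_unflatten`, rather than re-traversing the pytree with `tree_map` on each call. A hedged sketch of that pattern (the dict below is a stand-in for a real module output):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# Stand-in for the output captured during graph warmup.
out = {"a": torch.zeros(2), "b": (torch.ones(3), torch.full((1,), 2.0))}

# Done once at capture time (mirrors `self._out, self._out_struct = tree_flatten(out)`).
leaves, struct = tree_flatten(out)

# Done on every replay: clone the flat leaves, then restore the structure.
result = tree_unflatten([leaf.clone() for leaf in leaves], struct)
assert result["a"].data_ptr() != out["a"].data_ptr()  # fresh storage, same layout
```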