Skip to content

Commit

Permalink
[Performance] Faster clone
Browse files Browse the repository at this point in the history
ghstack-source-id: 5df15306395f6986e77caed2cfa87b3516a1b134
Pull Request resolved: #1040
  • Loading branch information
vmoens committed Oct 16, 2024
1 parent cfa618a commit db6a861
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
3 changes: 3 additions & 0 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
return all([value.is_contiguous() for _, value in self.items()])

def _clone(self, recurse: bool = True) -> T:
if recurse and self.device is not None and self.device.type == "cuda":
return self._clone_recurse()

result = TensorDict._new_unsafe(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
Expand Down
17 changes: 17 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8164,6 +8164,23 @@ def cosh_(self) -> T:
torch._foreach_cosh_(self._values_list(True, True))
return self

def _clone_recurse(self) -> TensorDictBase: # noqa: D417
keys, vals = self._items_list(True, True)
vals = torch._foreach_add(vals, 0)
items = dict(zip(keys, vals))
result = self._fast_apply(
lambda name, val: items.pop(name, None),
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

def add(
self,
other: TensorDictBase | torch.Tensor,
Expand Down
20 changes: 9 additions & 11 deletions tensordict/nn/cudagraphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ def _call(
"The output of the function must be a tensordict, a tensorclass or None. Got "
f"type(out)={type(out)}."
)
if is_tensor_collection(out):
out.lock_()
self._out = out
self.counter += 1
if self._out_matches_in:
Expand Down Expand Up @@ -302,14 +304,11 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
torch._foreach_copy_(dests, srcs)
torch.cuda.synchronize()
self.graph.replay()
if self._return_unchanged == "clone":
result = self._out.clone()
elif self._return_unchanged:
if self._return_unchanged:
result = self._out
else:
result = tree_map(
lambda x: x.detach().clone() if x is not None else x,
self._out,
result = tree_unflatten(
[out.clone() for out in self._out], self._out_struct
)
return result

Expand Down Expand Up @@ -340,7 +339,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
out = self.module(*self._args, **self._kwargs)
self._out = out
self._out, self._out_struct = tree_flatten(out)
self.counter += 1
# Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
# user.
Expand All @@ -356,11 +355,10 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
f"and the identity between input and output will not match anymore. "
f"Make sure you don't rely on input-output identity further in the code."
)
if isinstance(self._out, torch.Tensor) or self._out is None:
self._return_unchanged = (
"clone" if self._out is not None else True
)
if not self._out:
self._return_unchanged = True
else:
self._out = [out.lock_() if is_tensor_collection(out) else out for out in self._out]
self._return_unchanged = False
return this_out

Expand Down

0 comments on commit db6a861

Please sign in to comment.