Skip to content

Commit

Permalink
[Refactor] Make all leaves in tensorclass part of _tensordict, except for NonTensorData (#841)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 3, 2024
1 parent 4838205 commit 5c34868
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 140 deletions.
54 changes: 32 additions & 22 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,15 +794,16 @@ def __setitem__(
) from err

keys = set(self.keys())
if any(key not in keys for key in value.keys()):
subtd = self._get_sub_tensordict(index)
for key, item in value.items():
if key in keys:
subtd = None
for value_key, item in value.items():
if value_key in keys:
self._set_at_str(
key, item, index, validated=False, non_blocking=False
value_key, item, index, validated=False, non_blocking=False
)
else:
subtd.set(key, item, inplace=True, non_blocking=False)
if subtd is None:
subtd = self._get_sub_tensordict(index)
subtd.set(value_key, item, inplace=True, non_blocking=False)
else:
for key in self.keys():
self.set_at_(key, value, index)
Expand Down Expand Up @@ -2083,12 +2084,35 @@ def _set_tuple(
)
return self

_SHARED_INPLACE_ERROR = (
"You're attempting to update a leaf in-place with a shared "
"tensordict, but the new value does not match the previous. "
"If you're using NonTensorData, see the class documentation "
"to see how to properly pre-allocate memory in shared contexts."
)

def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
tensor_in = self._get_str(key, NO_DEFAULT)

if is_non_tensor(value) and not (self._is_shared or self._is_memmap):
dest = self._get_str(key, NO_DEFAULT)
is_diff = dest[idx].tolist() != value.tolist()
if is_diff:
dest_val = dest.maybe_to_stack()
dest_val[idx] = value
if dest_val is not dest:
self._set_str(
key,
dest_val,
validated=True,
inplace=False,
ignore_lock=True,
)
return

if isinstance(idx, tuple) and len(idx) and isinstance(idx[0], tuple):
warn(
"Multiple indexing can lead to unexpected behaviours when "
Expand All @@ -2103,12 +2127,7 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
)
if tensor_in is not tensor_out:
if self._is_shared or self._is_memmap:
raise RuntimeError(
"You're attempting to update a leaf in-place with a shared "
"tensordict, but the new value does not match the previous. "
"If you're using NonTensorData, see the class documentation "
"to see how to properly pre-allocate memory in shared contexts."
)
raise RuntimeError(self._SHARED_INPLACE_ERROR)
# this happens only when a NonTensorData becomes a NonTensorStack
# so it is legitimate (there is no in-place modification of a tensor
# that was expected to happen but didn't).
Expand Down Expand Up @@ -2301,16 +2320,7 @@ def _memmap_(
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
dest = (
self
if inplace
else TensorDict(
{},
batch_size=self.batch_size,
names=self.names if self._has_names() else None,
device=torch.device("cpu"),
)
)
dest = self if inplace else self.empty(device=torch.device("cpu"))

# We must set these attributes before memmapping because we need the metadata
# to match the tensordict content.
Expand Down
Loading

2 comments on commit 5c34868

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 5c34868 Previous: 4838205 Ratio
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_first_layer_tensor 120326.58019806597 iter/sec (stddev: 4.4986387165725007e-7) 1373832.594290356 iter/sec (stddev: 4.8530099191791996e-8) 11.42
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_first_layer_nontensor 120210.26347896754 iter/sec (stddev: 4.5650937546949285e-7) 1361545.4841384036 iter/sec (stddev: 5.995081721941169e-8) 11.33
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_second_layer_nontensor 108153.82005496862 iter/sec (stddev: 5.150928083308207e-7) 581376.6788713622 iter/sec (stddev: 1.8062761970691373e-7) 5.38
benchmarks/tensorclass/test_torch_functions.py::test_unbind 68.29561176745676 iter/sec (stddev: 0.020064073685849282) 187.90133432174042 iter/sec (stddev: 0.00008771959415108646) 2.75
benchmarks/tensorclass/test_torch_functions.py::test_split 3618.8562636464767 iter/sec (stddev: 0.000020585293066090808) 8582.488348172195 iter/sec (stddev: 0.000006168920741504216) 2.37

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 5c34868 Previous: 4838205 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 103333.64591645026 iter/sec (stddev: 6.220006522992922e-7) 328181.3367510252 iter/sec (stddev: 3.1744912231902153e-7) 3.18
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 103609.51088710976 iter/sec (stddev: 5.99423427460808e-7) 324154.23471058556 iter/sec (stddev: 3.2223949422480034e-7) 3.13
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_init 18003.98459803785 iter/sec (stddev: 0.00000301726293045441) 38710.2270438557 iter/sec (stddev: 0.000004482695150123839) 2.15
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_first_layer_tensor 271011.4141052725 iter/sec (stddev: 3.6895030942722313e-7) 2771555.8518061205 iter/sec (stddev: 6.661000960912629e-8) 10.23
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_first_layer_nontensor 268910.2192190714 iter/sec (stddev: 7.191341820439208e-7) 2572563.2223312114 iter/sec (stddev: 8.882903972538797e-8) 9.57
benchmarks/tensorclass/test_tensorclass_speed.py::test_tc_second_layer_nontensor 233347.68865283718 iter/sec (stddev: 3.575173074145315e-7) 1182043.1014238065 iter/sec (stddev: 2.7427378493623303e-7) 5.07
benchmarks/tensorclass/test_torch_functions.py::test_split 323.51706261163264 iter/sec (stddev: 0.00005106988903723365) 10524.689363859243 iter/sec (stddev: 0.000004529742492512983) 32.53

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.