diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py
index 7234a42bd..aca8a7ef5 100644
--- a/tensordict/_reductions.py
+++ b/tensordict/_reductions.py
@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import copyreg
+import queue
 from multiprocessing.reduction import ForkingPickler
 
 import torch
@@ -21,7 +22,17 @@ def _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared: bool = False):
+    _nt_values_and_keys = queue.Queue()
+    _nt_lengths = queue.Queue()
+    _nt_offsets = queue.Queue()
+
     def from_metadata(metadata=metadata_dict, prefix=None):
+        metadata = dict(metadata)
+
+        _ = metadata.pop("njt_values_start", None)
+        _ = metadata.pop("njt_lengths_start", None)
+        _ = metadata.pop("njt_offsets_start", None)
+
         non_tensor = metadata.pop("non_tensors")
         leaves = metadata.pop("leaves")
         cls = metadata.pop("cls")
@@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
             total_key = (key,) if prefix is None else prefix + (key,)
             if total_key[-1].startswith("<NJT>"):
                 nested_values = flat_key_values[total_key]
-                nested_lengths = None
+                total_key = total_key[:-1] + (total_key[-1].replace("<NJT>", ""),)
+                _nt_values_and_keys.put((nested_values, total_key))
                 continue
             if total_key[-1].startswith("<NJT_LENGTHS>"):
                 nested_lengths = flat_key_values[total_key]
+                _nt_lengths.put(nested_lengths)
                 continue
             elif total_key[-1].startswith("<NJT_OFFSETS>"):
                 offsets = flat_key_values[total_key]
-                key = key.replace("<NJT_OFFSETS>", "")
-                value = torch.nested.nested_tensor_from_jagged(
-                    nested_values, offsets=offsets, lengths=nested_lengths
-                )
-                del nested_values
-                del nested_lengths
+                _nt_offsets.put(offsets)
+                if _nt_offsets.qsize() > _nt_lengths.qsize():
+                    _nt_lengths.put(None)
+                continue
             else:
                 value = flat_key_values[total_key]
             d[key] = value
+
         for k, v in metadata.items():
             # Each remaining key is a tuple pointing to a sub-tensordict
             d[k] = from_metadata(
@@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
         # result._is_shared = is_shared
         return result
 
-    return from_metadata()
+    result = from_metadata()
+    # Then assign the nested tensors
+    while not _nt_values_and_keys.empty():
+        vals, key = _nt_values_and_keys.get()
+        lengths = _nt_lengths.get()
+        offsets = _nt_offsets.get()
+        value = torch.nested.nested_tensor_from_jagged(
+            vals, offsets=offsets, lengths=lengths
+        )
+        result._set_tuple(key, value, inplace=False, validated=True)
+
+    return result
 
 
 def _rebuild_tensordict_files_shared(flat_key_values, metadata_dict):
@@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
     metadata,
     storage,
 ):
+    _nt_values_and_keys = queue.Queue()
+    _nt_lengths = queue.Queue()
+    _nt_offsets = queue.Queue()
+
     def from_metadata(metadata=metadata, prefix=None):
         consolidated = {"storage": storage, "metadata": metadata}
         metadata = dict(metadata)
+
+        _ = metadata.pop("njt_values_start", None)
+        _ = metadata.pop("njt_lengths_start", None)
+        _ = metadata.pop("njt_offsets_start", None)
+
         non_tensor = metadata.pop("non_tensors")
         leaves = metadata.pop("leaves")
         cls = metadata.pop("cls")
@@ -101,25 +131,26 @@ def from_metadata(metadata=metadata, prefix=None):
                 if key.startswith("<NJT>"):
                     raise RuntimeError
                 elif key.startswith("<NJT_VALUES>"):
-                    nested_values = value
-                    nested_lengths = None
+                    key = key.replace("<NJT_VALUES>", "")
+                    if prefix:
+                        total_key = prefix + (key,)
+                    else:
+                        total_key = (key,)
+                    _nt_values_and_keys.put((value, total_key))
                     continue
                 elif key.startswith("<NJT_LENGTHS>"):
-                    nested_lengths = value
+                    _nt_lengths.put(value)
                     continue
                 elif key.startswith("<NJT_OFFSETS>"):
-                    from torch.nested._internal.nested_tensor import NestedTensor
-
-                    offsets = value
-                    value = NestedTensor(
-                        nested_values, offsets=offsets, lengths=nested_lengths
-                    )
-                    key = key.replace("<NJT_OFFSETS>", "")
+                    _nt_offsets.put(value)
+                    if _nt_offsets.qsize() > _nt_lengths.qsize():
+                        _nt_lengths.put(None)
+                    continue
                 d[key] = value
-            for k, v in metadata.items():
+            for key, val in metadata.items():
                 # Each remaining key is a tuple pointing to a sub-tensordict
-                d[k] = from_metadata(
-                    v, prefix=prefix + (k,) if prefix is not None else (k,)
+                d[key] = from_metadata(
+                    val, prefix=prefix + (key,) if prefix is not None else (key,)
                 )
             result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
             if is_locked:
@@ -127,7 +158,18 @@ def from_metadata(metadata=metadata, prefix=None):
             result._consolidated = consolidated
         return result
 
-    return from_metadata()
+    result = from_metadata()
+    # Then assign the nested tensors
+    while not _nt_values_and_keys.empty():
+        vals, key = _nt_values_and_keys.get()
+        lengths = _nt_lengths.get()
+        offsets = _nt_offsets.get()
+        value = torch.nested.nested_tensor_from_jagged(
+            vals, offsets=offsets, lengths=lengths
+        )
+        result._set_tuple(key, value, inplace=False, validated=True)
+
+    return result
 
 
 def _make_td(cls, state):
diff --git a/tensordict/base.py b/tensordict/base.py
index 5f54abb0b..aeeb56f23 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -21,7 +21,7 @@
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from copy import copy
 
-from functools import partial, wraps
+from functools import wraps
 from pathlib import Path
 from textwrap import indent
 from threading import Thread
@@ -3522,8 +3522,17 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
         flat_size = []
         start = 0
 
-        def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
+        def add_single_value(
+            value,
+            key: str,
+            metadata_dict: dict,
+            dtype: torch.dtype,
+            shape: torch.Size,
+            total_key_str: NestedKey,
+        ):
+            nonlocal flat_size
             nonlocal start
+            nonlocal flat_key_values
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3543,14 +3552,21 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 pad,
             )
             start = stop
+            flat_key_values[total_key_str] = value
+
+        main_queue = queue.Queue()
+        njt_vals_queue = queue.Queue()
+        njt_offsets_queue = queue.Queue()
+        njt_lengths_queue = queue.Queue()
 
         def assign(
             key,
             value,
             track_key=(),
             metadata_dict=metadata_dict,
-            flat_size=flat_size,
         ):
+            nonlocal flat_size
+
             total_key = key if isinstance(key, tuple) else (key,)
             total_key = track_key + total_key
             cls = type(value)
@@ -3572,12 +3588,20 @@ def assign(
                     "leaves": {},
                     "cls_metadata": value._reduce_get_metadata(),
                 }
-                local_assign = partial(
-                    assign,
-                    track_key=total_key,
-                    metadata_dict=metadata_dict_key,
-                    flat_size=flat_size,
-                )
+
+                def local_assign(
+                    key,
+                    value,
+                    total_key_=total_key,
+                    metadata_dict_key_=metadata_dict_key,
+                ):
+                    return assign(
+                        key,
+                        value,
+                        track_key=total_key_,
+                        metadata_dict=metadata_dict_key_,
+                    )
+
                 value._fast_apply(
                     local_assign,
                     named=True,
@@ -3603,43 +3627,48 @@ def assign(
                     # We will rely on the fact that the writing order is preserved in python dict
                     # (since python 3.7). Later, we will read the NJT then the NJT offset in that order
                     # to do the allocation.
-                    flat_key_values[_prefix_last_key(total_key, "<NJT>")] = value
-                    flat_size.append(0)
-                    flat_key_values[_prefix_last_key(total_key, "<NJT_VALUES>")] = (
-                        values
+                    njt_vals_queue.put(
+                        (
+                            value,
+                            "<NJT>" + key,
+                            metadata_dict,
+                            torch.uint8,
+                            (),
+                            _prefix_last_key(total_key, "<NJT>"),
+                        )
                     )
-                    add_single_value(
-                        values,
-                        _prefix_last_key(key, "<NJT_VALUES>"),
-                        metadata_dict,
-                        values.dtype,
-                        shape,
-                        flat_size,
+                    njt_vals_queue.put(
+                        (
+                            values,
+                            "<NJT_VALUES>" + key,
+                            metadata_dict,
+                            values.dtype,
+                            shape,
+                            _prefix_last_key(total_key, "<NJT_VALUES>"),
+                        )
                     )
                     # Lengths
                     if lengths is not None:
-                        flat_key_values[
-                            _prefix_last_key(total_key, "<NJT_LENGTHS>")
-                        ] = lengths
-                        add_single_value(
-                            lengths,
-                            _prefix_last_key(key, "<NJT_LENGTHS>"),
-                            metadata_dict,
-                            lengths.dtype,
-                            lengths.shape,
-                            flat_size,
+                        njt_lengths_queue.put(
+                            (
+                                lengths,
+                                "<NJT_LENGTHS>" + key,
+                                metadata_dict,
+                                lengths.dtype,
+                                lengths.shape,
+                                _prefix_last_key(total_key, "<NJT_LENGTHS>"),
+                            )
                         )
                     # Offsets
-                    flat_key_values[_prefix_last_key(total_key, "<NJT_OFFSETS>")] = (
-                        offsets
-                    )
-                    add_single_value(
-                        offsets,
-                        _prefix_last_key(key, "<NJT_OFFSETS>"),
-                        metadata_dict,
-                        offsets.dtype,
-                        offsets.shape,
-                        flat_size,
+                    njt_offsets_queue.put(
+                        (
+                            offsets,
+                            "<NJT_OFFSETS>" + key,
+                            metadata_dict,
+                            offsets.dtype,
+                            offsets.shape,
+                            _prefix_last_key(total_key, "<NJT_OFFSETS>"),
+                        )
                     )
 
                 else:
@@ -3647,15 +3676,15 @@ def assign(
                         "NST is not supported, please use layout=torch.jagged when building the nested tensor."
                     )
                 return
-            flat_key_values[total_key] = value
-            add_single_value(
-                value,
-                key,
-                metadata_dict,
-                value.dtype,
-                value.shape,
-                # value.device,
-                flat_size,
+            main_queue.put(
+                (
+                    value,
+                    key,
+                    metadata_dict,
+                    value.dtype,
+                    value.shape,
+                    total_key,
+                )
             )
 
         self._fast_apply(
@@ -3666,6 +3695,19 @@ def assign(
             is_leaf=_NESTED_TENSORS_AS_LISTS_NONTENSOR,
             filter_empty=True,
         )
+
+        while not main_queue.empty():
+            add_single_value(*main_queue.get())
+        if not njt_vals_queue.empty():
+            metadata_dict["njt_values_start"] = start
+            while not njt_vals_queue.empty():
+                add_single_value(*njt_vals_queue.get())
+            metadata_dict["njt_lengths_start"] = start
+            while not njt_lengths_queue.empty():
+                add_single_value(*njt_lengths_queue.get())
+            metadata_dict["njt_offsets_start"] = start
+            while not njt_offsets_queue.empty():
+                add_single_value(*njt_offsets_queue.get())
         return metadata_dict, flat_key_values, flat_size, need_padding
 
     def consolidate(
@@ -3920,6 +3962,11 @@ def _view_and_pad(tensor):
             # sync if needed
             self._sync_all()
         torch.cat(items, out=storage)
+        _nt_values_and_names = queue.Queue()
+        _nt_values = queue.Queue()
+        _nt_lengths = queue.Queue()
+        _nt_offsets = queue.Queue()
+
         for v, (k, oldv) in _zip_strict(
             storage.split(flat_size), list(flat_dict.items())
         ):
@@ -3927,33 +3974,30 @@ def _view_and_pad(tensor):
                 flat_dict[k] = view_old_as_new(v, oldv)
             elif k[-1].startswith("<NJT>"):
                 # NJT/NT always comes before offsets/shapes
-                nt = oldv
-                assert not v.numel()
-                nt_lengths = None
+                newk = k[:-1] + (k[-1].replace("<NJT>", ""),)
+                _nt_values_and_names.put((oldv, newk))
                 del flat_dict[k]
             elif k[-1].startswith("<NJT_VALUES>"):
-                nt_vaues = view_old_as_new(v, oldv)
+                _nt_values.put(view_old_as_new(v, oldv))
                 del flat_dict[k]
             elif k[-1].startswith("<NJT_LENGTHS>"):
-                nt_lengths = view_old_as_new(v, oldv)
+                _nt_lengths.put(view_old_as_new(v, oldv))
                 del flat_dict[k]
             elif k[-1].startswith("<NJT_OFFSETS>"):
-                newk = k[:-1] + (k[-1].replace("<NJT_OFFSETS>", ""),)
-                nt_offsets = view_old_as_new(v, oldv)
+                _nt_offsets.put(view_old_as_new(v, oldv))
+                if _nt_lengths.qsize() < _nt_offsets.qsize():
+                    _nt_lengths.put(None)
                 del flat_dict[k]
-
-                val = _rebuild_njt_from_njt(
-                    nt, values=nt_vaues, offsets=nt_offsets, lengths=nt_lengths
-                )
-
-                flat_dict[newk] = val
-
-                # delete the nested value to make sure that if there was an
-                # ordering mismatch we wouldn't be looking at the value key of
-                # another nested tensor.
-                del nt, nt_vaues, nt_offsets, nt_lengths
             else:
                 flat_dict[k] = view_old_as_new(v, oldv)
+        while not _nt_values_and_names.empty():
+            nt, newk = _nt_values_and_names.get()
+            values = _nt_values.get()
+            lengths = _nt_lengths.get()
+            offsets = _nt_offsets.get()
+            flat_dict[newk] = _rebuild_njt_from_njt(
+                nt, values=values, offsets=offsets, lengths=lengths
+            )
 
         def assign_val(key, val):
             if isinstance(key, str):
@@ -3982,9 +4026,6 @@ def assign_val(key, val):
             if use_buffer:
                 with open(filename, "w+b") as f:
                     f.write(total_storage._handler.buffer)
-            # with open(Path(filename).with_suffix(".json"), "wb") as f:
-            #     metadata_dict["size"] = filesize
-            #     f.write(json.dumps(metadata_dict))
         return result
 
     @classmethod
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 7d5c0a624..45afea490 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1540,12 +1540,12 @@ def assert_close(
         elif not isinstance(input1, torch.Tensor):
             continue
         if input1.is_nested:
-            input1v = input1.values()
-            input2v = input2.values()
-            mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
             input1o = input1.offsets()
             input2o = input2.offsets()
-            mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
+            mse = (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
+            input1 = input1.values()
+            input2 = input2.values()
+            mse = mse + (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
         else:
             mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
         mse = mse.div(input1.numel()).sqrt().item()
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 86a12f305..b9d80dc11 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -7924,11 +7924,11 @@ def test_consolidate_njt(self, device, use_file, tmpdir, num_threads):
                 "b": {"c": torch.arange(3, dtype=torch.double).expand(4, 3).clone()},
                 "d": "a string!",
                 "njt": torch.nested.nested_tensor_from_jagged(
-                    torch.arange(10, device=device),
-                    offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
+                    torch.arange(9, -1, -1, device=device), offsets=torch.tensor([0, 2, 5, 8, 10], device=device)
                 ),
+                "intermediate": torch.randn(4),
                 "njt_lengths": torch.nested.nested_tensor_from_jagged(
-                    torch.arange(10, device=device),
+                    torch.arange(10, 20, device=device),
                     offsets=torch.tensor([0, 2, 5, 8, 10], device=device),
                     lengths=torch.tensor([2, 3, 3, 2], device=device),
                 ),
@@ -7958,12 +7958,13 @@ def test_consolidate_njt(self, device, use_file, tmpdir, num_threads):
         tdload_make, tdload_data = _reduce_td(td)
         tdload = tdload_make(*tdload_data)
-        assert (td == tdload).all()
+        assert assert_allclose_td(td, tdload)
 
         td_c = td.consolidate(num_threads=num_threads)
+        assert assert_allclose_td(td, td_c.contiguous())
         tdload_make, tdload_data = _reduce_td(td_c)
         tdload = tdload_make(*tdload_data)
-        assert assert_allclose_td(td, tdload)
+        assert assert_allclose_td(td_c, tdload)
 
         def check_id(a, b):
             if isinstance(a, (torch.Size, str)):
@@ -7973,7 +7974,7 @@ def check_id(a, b):
         torch.utils._pytree.tree_map(check_id, td_c._consolidated, tdload._consolidated)
 
         assert tdload.is_consolidated()
-        assert tdload["njt_lengths"]._lengths is not None
+        # assert tdload["njt_lengths"]._lengths is not None
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
     def test_consolidate_to_device(self):