diff --git a/tensordict/base.py b/tensordict/base.py index d591bd8b6..124815a95 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6174,94 +6174,336 @@ def map( initargs=(seed, queue, worker_threads), maxtasksperchild=max_tasks_per_child, ) as pool: - return self.map( - fn, + return self._map( + fn=fn, dim=dim, chunksize=chunksize, num_chunks=num_chunks, pool=pool, pbar=pbar, out=out, + index_with_generator=index_with_generator, + iterable=False, + shuffle=False, ) - num_workers = pool._processes - dim_orig = dim - if dim < 0: - dim = self.ndim + dim - if dim < 0 or dim >= self.ndim: - raise ValueError(f"Got incompatible dimension {dim_orig}") - - self_split = _split_tensordict( - self, - chunksize, - num_chunks, - num_workers, - dim, - use_generator=index_with_generator, - ) - if not index_with_generator: - length = len(self_split) else: - length = None - call_chunksize = 1 - - if out is not None and (out.is_shared() or out.is_memmap()): - - def wrap_fn_with_out(fn, out): - @wraps(fn) - def newfn(item_and_out): - item, out = item_and_out - result = fn(item) - out.update_(result) - return - - out_split = _split_tensordict( - out, - chunksize, - num_chunks, - num_workers, - dim, - use_generator=index_with_generator, + return self._map( + fn=fn, + dim=dim, + chunksize=chunksize, + num_chunks=num_chunks, + pool=pool, + pbar=pbar, + out=out, + index_with_generator=index_with_generator, + iterable=False, + shuffle=False, + ) + + def map_iter( + self, + fn: Callable[[TensorDictBase], TensorDictBase | None], + dim: int = 0, + num_workers: int | None = None, + *, + shuffle: bool = False, + chunksize: int | None = None, + num_chunks: int | None = None, + pool: mp.Pool | None = None, + generator: torch.Generator | None = None, + max_tasks_per_child: int | None = None, + worker_threads: int = 1, + index_with_generator: bool = True, + pbar: bool = False, + mp_start_method: str | None = None, + ): + """Maps a function to splits of the tensordict across one dimension iteratively. + + This is the iterable version of :meth:`~TensorDictBase.map`. + + This method will apply a function to a tensordict instance by chunking + it into tensordicts of equal size and dispatching the operations over the + desired number of workers. It will yield the results one at a time. + + The function signature should be ``Callable[[TensorDict], Union[TensorDict, Tensor]]``. + The function must be serializable. + + Args: + fn (callable): function to apply to the tensordict. + Signatures similar to ``Callable[[TensorDict], Union[TensorDict, Tensor]]`` + are supported. + dim (int, optional): the dim along which the tensordict will be chunked. + num_workers (int, optional): the number of workers. Exclusive with ``pool``. + If none is provided, the number of workers will be set to the + number of CPUs available. + + Keyword Args: + shuffle (bool, optional): whether the indices should be globally shuffled. + If ``True``, each batch will contain non-contiguous samples. + If ``index_with_generator=False`` and ``shuffle=True``, an error will be raised. + Defaults to ``False``. + chunksize (int, optional): The size of each chunk of data. + A ``chunksize`` of 0 will unbind the tensordict along the + desired dimension and restack it after the function is applied, + whereas ``chunksize>0`` will split the tensordict and call + :func:`torch.cat` on the resulting list of tensordicts. + If none is provided, the number of chunks will equal the number + of workers.
For very large tensordicts, such large chunks + may not fit in memory for the operation to be done and + more chunks may be needed to make the operation practically + doable. This argument is exclusive with ``num_chunks``. + num_chunks (int, optional): the number of chunks to split the tensordict + into. If none is provided, the number of chunks will equal the number + of workers. For very large tensordicts, such large chunks + may not fit in memory for the operation to be done and + more chunks may be needed to make the operation practically + doable. This argument is exclusive with ``chunksize``. + pool (mp.Pool, optional): a multiprocess Pool instance to use + to execute the job. If none is provided, a pool will be created + within the ``map_iter`` method. + generator (torch.Generator, optional): a generator to use for seeding. + A base seed will be generated from it, and each worker + of the pool will be seeded with the provided seed incremented + by a unique integer from ``0`` to ``num_workers``. If no generator + is provided, a random integer will be used as seed. + To work with unseeded workers, a pool should be created separately + and passed to :meth:`map_iter` directly. + .. note:: + Caution should be taken when providing a low-valued seed as + this can cause autocorrelation between experiments, for example: + if 8 workers are used and the seed is 4, the workers' seeds will + range from 4 to 11. If the seed is 5, the workers' seeds will range + from 5 to 12. These two experiments will have an overlap of 7 + seeds, which can have unexpected effects on the results. + + .. note:: + The goal of seeding the workers is to have an independent seed on + each worker, and NOT to have reproducible results across calls + of the `map` method. In other words, two experiments may and + probably will return different results as it is impossible to + know which worker will pick which job. However, we can make sure + that each worker has a different seed and that the pseudo-random + operations on each will be uncorrelated. + max_tasks_per_child (int, optional): the maximum number of jobs picked + by every child process. Defaults to ``None``, i.e., no restriction + on the number of jobs. + worker_threads (int, optional): the number of threads for the workers. + Defaults to ``1``. + index_with_generator (bool, optional): if ``True``, the splitting / chunking + of the tensordict will be done during the query, sparing init time. + Note that :meth:`~.chunk` and :meth:`~.split` are much more + efficient than indexing (which is used within the generator) + so a gain of processing time at init time may have a negative + impact on the total runtime. Defaults to ``True``. + + .. note:: The default value of ``index_with_generator`` differs for ``map_iter`` + and ``map`` and the former assumes that it is prohibitively expensive to + store a split version of the TensorDict in memory. + + pbar (bool, optional): if ``True``, a progress bar will be displayed. + Requires tqdm to be available. Defaults to ``False``. + mp_start_method (str, optional): the start method for multiprocessing. + If not provided, the default start method will be used. + Accepted strings are ``"fork"`` and ``"spawn"``. Keep in mind that + ``"cuda"`` tensors cannot be shared between processes with the + ``"fork"`` start method. This has no effect if a ``pool`` + is passed to the ``map_iter`` method. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> + >>> def process_data(data): + ... data.unlock_() + ... 
data.set("y", data.get("x") + 1) + ... return data + >>> if __name__ == "__main__": + ... data = TensorDict({"x": torch.zeros(1, 1_000_000)}, [1, 1_000_000]).memmap_() + ... for sample in data.map_iter(process_data, dim=1, chunksize=5): + ... print(sample["y"]) + ... break + ... + tensor([[1., 1., 1., 1., 1.]]) + + .. note:: This method is particularily useful when working with large + datasets stored on disk (e.g. memory-mapped tensordicts) where + chunks will be zero-copied slices of the original data which can + be passed to the processes with virtually zero-cost. This allows + to tread very large datasets (eg. over a Tb big) to be processed + at little cost. + + .. note:: This function be used to represent a dataset and load from it, + in a dataloader-like fashion. + + """ + from torch import multiprocessing as mp + + if pool is None: + if num_workers is None: + num_workers = mp.cpu_count() # Get the number of CPU cores + if generator is None: + generator = torch.Generator() + seed = ( + torch.empty((), dtype=torch.int64).random_(generator=generator).item() + ) + if mp_start_method is not None: + ctx = mp.get_context(mp_start_method) + else: + ctx = mp.get_context() + + queue = ctx.Queue(maxsize=num_workers) + for i in range(num_workers): + queue.put(i) + pool = ctx.Pool( + processes=num_workers, + initializer=_proc_init, + initargs=(seed, queue, worker_threads), + maxtasksperchild=max_tasks_per_child, + ) + try: + yield from self._map( + fn=fn, + dim=dim, + chunksize=chunksize, + num_chunks=num_chunks, + pool=pool, + pbar=pbar, + out=None, + index_with_generator=index_with_generator, + iterable=True, + shuffle=shuffle, ) - return _CloudpickleWrapper(newfn), zip(self_split, out_split) + finally: + try: + pool.terminate() + finally: + pool.join() + else: + yield from self._map( + fn=fn, + dim=dim, + chunksize=chunksize, + num_chunks=num_chunks, + pool=pool, + pbar=pbar, + out=None, + index_with_generator=index_with_generator, + iterable=True, + shuffle=shuffle, + ) - fn, self_split = wrap_fn_with_out(fn, out) - out = None + def _map( + self, + fn: Callable[[TensorDictBase], TensorDictBase | None], + dim: int = 0, + *, + shuffle: bool = False, + out: TensorDictBase | None = None, + chunksize: int | None = None, + num_chunks: int | None = None, + pool: mp.Pool | None = None, + index_with_generator: bool = False, + pbar: bool = False, + iterable: bool, + ): + def _get_imap( + pool=pool, + dim=dim, + out=out, + fn=fn, + chunksize=chunksize, + num_chunks=num_chunks, + shuffle=shuffle, + index_with_generator=index_with_generator, + pbar=pbar, + ): + num_workers = pool._processes + dim_orig = dim + if dim < 0: + dim = self.ndim + dim + if dim < 0 or dim >= self.ndim: + raise ValueError(f"Got incompatible dimension {dim_orig}") + + self_split = _split_tensordict( + self, + chunksize, + num_chunks, + num_workers, + dim, + shuffle=shuffle, + use_generator=index_with_generator, + ) + if not index_with_generator: + length = len(self_split) + else: + length = None + call_chunksize = 1 + + if out is not None and (out.is_shared() or out.is_memmap()): + + def wrap_fn_with_out(fn, out): + @wraps(fn) + def newfn(item_and_out): + item, out = item_and_out + result = fn(item) + out.update_(result) + return + + out_split = _split_tensordict( + out, + chunksize, + num_chunks, + num_workers, + dim, + shuffle=shuffle, + use_generator=index_with_generator, + ) + return _CloudpickleWrapper(newfn), zip(self_split, out_split) - imap = pool.imap(fn, self_split, call_chunksize) + fn, self_split = wrap_fn_with_out(fn, 
out) + out = None - if pbar and importlib.util.find_spec("tqdm", None) is not None: - import tqdm + imap_fn = pool.imap if not shuffle else pool.imap_unordered + imap = imap_fn(fn, self_split, call_chunksize) - imap = tqdm.tqdm(imap, total=length) + if pbar and importlib.util.find_spec("tqdm", None) is not None: + import tqdm - imaplist = [] - start = 0 - base_index = (slice(None),) * dim - for item in imap: - if item is not None: - if out is not None: - if chunksize == 0: - out[base_index + (start,)].update_(item) - start += 1 + imap = tqdm.tqdm(imap, total=length) + return imap, out + + if iterable: + return _get_imap()[0] + else: + imap, out = _get_imap() + imaplist = [] + start = 0 + base_index = (slice(None),) * dim + for item in imap: + if item is not None: + if out is not None: + if chunksize == 0: + out[base_index + (start,)].update_(item) + start += 1 + else: + end = start + item.shape[dim] + chunk = base_index + (slice(start, end),) + out[chunk].update_(item) + start = end else: - end = start + item.shape[dim] - chunk = base_index + (slice(start, end),) - out[chunk].update_(item) - start = end - else: - imaplist.append(item) - del imap + imaplist.append(item) + del imap - # support inplace modif - if imaplist: - if chunksize == 0: - from tensordict._lazy import LazyStackedTensorDict + # support inplace modif + if imaplist: + if chunksize == 0: + from tensordict._lazy import LazyStackedTensorDict - # We want to be able to return whichever data structure - out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim) - else: - out = torch.cat(imaplist, dim) - return out + # We want to be able to return whichever data structure + out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim) + else: + out = torch.cat(imaplist, dim) + return out # point-wise arithmetic ops def __add__(self, other: TensorDictBase | float) -> T: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index a656b201e..8b6140a23 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -127,22 +127,18 @@ def __subclasscheck__(self, subclass): "__truediv__", "_add_batch_dim", "_apply_nest", - "_multithread_apply_flat", "_check_unlock", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", + "_get_names_idx", # no wrap output "_has_names", + "_multithread_apply_flat", "_propagate_lock", "_propagate_unlock", "_remove_batch_dim", "_select", # TODO: must be specialized - # "_set_at_str", "_set_at_tuple", - "_set_at_tuple", - # _set_str needs a special treatment to catch keys that are already in - # non tensor data - # "_set_at_tuple", "_set_str", "_set_tuple", "abs", @@ -219,6 +215,8 @@ def __subclasscheck__(self, subclass): "log2", "log2_", "log_", + "map", + "map_iter", "masked_fill", "masked_fill_", "maximum", @@ -268,7 +266,6 @@ def __subclasscheck__(self, subclass): "unsqueeze", "values", # no wrap output "view", - "_get_names_idx", # no wrap output "where", "zero_", ] diff --git a/tensordict/utils.py b/tensordict/utils.py index 8b7a4f2b0..b5c64b211 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1253,6 +1253,29 @@ def new_func(_self, *args, **kwargs): return new_func +def _find_smallest_uint(N): + if not hasattr(torch, "uint32"): + # Fallback + return torch.int64 + if N < 0: + raise ValueError("N must be a non-negative integer") + + int8_max = 127 + int16_max = 32767 + int32_max = 2147483647 + int64_max = 9223372036854775807 + if N <= int8_max: + return torch.int8 + elif N <= int16_max: + return torch.int16 + elif N <= 
int32_max: + return torch.int32 + elif N <= int64_max: + return torch.int64 + else: + return "uint is too large to be represented by uint64" + + def _split_tensordict( td, chunksize, @@ -1261,7 +1284,12 @@ def _split_tensordict( dim, use_generator=False, to_tensordict=False, + shuffle=False, ): + if shuffle and not use_generator: + raise RuntimeError( + "Shuffling is not permitted unless use_generator is set to ``True`` for efficiency purposes." + ) if chunksize is None and num_chunks is None: num_chunks = num_workers if chunksize is not None and num_chunks is not None: @@ -1272,50 +1300,63 @@ def _split_tensordict( num_chunks = min(td.shape[dim], num_chunks) if use_generator: - def _chunk_generator(): - chunksize = -(td.shape[dim] // -num_chunks) + def next_index(td=td, dim=dim, num_chunks=num_chunks): idx_start = 0 - base = (slice(None),) * dim - for _ in range(num_chunks): - idx_end = idx_start + chunksize - out = td[base + (slice(idx_start, idx_end),)] - if to_tensordict: - out = out.to_tensordict() - yield out + n = td.shape[dim] + chunksize = -(n // -num_chunks) + idx_end = chunksize + while idx_start < n: + yield slice(idx_start, idx_end) idx_start = idx_end + idx_end += chunksize - return _chunk_generator() - return td.chunk(num_chunks, dim=dim) + else: + return td.chunk(num_chunks, dim=dim) else: if chunksize == 0: if use_generator: - def _unbind_generator(): - base = (slice(None),) * dim - for i in range(td.shape[dim]): - out = td[base + (i,)] - if to_tensordict: - out = out.to_tensordict() - yield out + def next_index(td=td, dim=dim): + yield from range(td.shape[dim]) - return _unbind_generator() - return td.unbind(dim=dim) - if use_generator: + else: + return td.unbind(dim=dim) + else: + if use_generator: - def _split_generator(): - idx_start = 0 - base = (slice(None),) * dim - for _ in range(num_chunks): - idx_end = idx_start + chunksize - out = td[base + (slice(idx_start, idx_end),)] - if to_tensordict: - out = out.to_tensordict() - yield out - idx_start = idx_end + def next_index(td=td, dim=dim, chunksize=chunksize): + idx_start = 0 + idx_end = chunksize + n = td.shape[dim] + while idx_start < n: + yield slice(idx_start, idx_end) + idx_start = idx_end + idx_end += chunksize - return _split_generator() - chunksize = min(td.shape[dim], chunksize) - return td.split(chunksize, dim=dim) + else: + chunksize = min(td.shape[dim], chunksize) + return td.split(chunksize, dim=dim) + # end up here only when use_generator = True + if shuffle: + + def next_index_shuffle(next_index=next_index): + n = td.shape[dim] + device = td.device + rp = torch.randperm(n, dtype=_find_smallest_uint(n), device=device) + for idx in next_index(): + yield rp[idx].long() + + next_index = next_index_shuffle + + def _split_generator(): + base = (slice(None),) * dim + for idx in next_index(): + out = td[base + (idx,)] + if to_tensordict: + out = out.to_tensordict() + yield out + + return _split_generator() def _parse_to(*args, **kwargs): diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 303583d6c..b12517abd 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -8798,7 +8798,7 @@ def test_modules(self, as_module): assert y._tensor.shape[0] == param_batch -@pytest.mark.skipif(_IS_OSX, reason="Pool execution in osx can hang forever.") +# @pytest.mark.skipif(_IS_OSX, reason="Pool execution in osx can hang forever.") class TestMap: """Tests for TensorDict.map that are independent from tensordict's type.""" @@ -9008,6 +9008,83 @@ def test_non_tensor(self): td = 
td.map(self.nontensor_check, chunksize=0) assert td["check"].all() + @staticmethod + def _return_identical(td): + return td.clone() + + @pytest.mark.parametrize("shuffle", [False, True]) + @pytest.mark.parametrize( + "chunksize,num_chunks", [[0, None], [11, None], [None, 11]] + ) + def test_map_iter(self, chunksize, num_chunks, shuffle): + torch.manual_seed(0) + td = TensorDict( + { + "a": torch.arange(100), + "b": { + "c": torch.arange(100, 200), + "d": NonTensorStack(*[NonTensorData(str(i)) for i in range(100)]), + }, + }, + batch_size=[100], + ) + strings = set() + a_elts = set() + c_elts = set() + data_prev = None + for data in td.map_iter( + self._return_identical, + shuffle=shuffle, + num_chunks=num_chunks, + chunksize=chunksize, + ): + if data_prev is not None and data.shape == data_prev.shape: + if shuffle: + assert (data != data_prev + data.numel()).any() + else: + assert ( + data.filter_non_tensor_data() + == data_prev.filter_non_tensor_data() + data.numel() + ).all() + d = data["b", "d"] + if not isinstance(d, str): + strings.update(d) + a_elts.update(data["a"].tolist()) + c_elts.update(data["b", "c"].tolist()) + else: + strings.add(d) + a_elts.add(data["a"].item()) + c_elts.add(data["b", "c"].item()) + data_prev = data + + assert a_elts == set(range(100)) + assert c_elts == set(range(100, 200)) + assert strings == {str(i) for i in range(100)} + + @pytest.mark.parametrize("shuffle", [False, True]) + @pytest.mark.parametrize( + "chunksize,num_chunks", [[0, None], [11, None], [None, 11]] + ) + def test_map_iter_interrupt_early(self, chunksize, num_chunks, shuffle): + torch.manual_seed(0) + td = TensorDict( + { + "a": torch.arange(100), + "b": { + "c": torch.arange(100, 200), + "d": NonTensorStack(*[NonTensorData(str(i)) for i in range(100)]), + }, + }, + batch_size=[100], + ) + for _ in td.map_iter( + self._return_identical, + shuffle=shuffle, + num_chunks=num_chunks, + chunksize=chunksize, + ): + return + # class TestNonTensorData: class TestNonTensorData:
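
Usage sketch for the new iterable API (an editorial illustration, not part of the patch): it assumes `TensorDict.map_iter` behaves as documented in the docstring above; `add_one` is a hypothetical worker function.

    import torch
    from tensordict import TensorDict

    def add_one(data):
        # Runs inside a worker process, so it must be picklable (module-level).
        data.unlock_()
        data.set("y", data.get("x") + 1)
        return data

    if __name__ == "__main__":
        td = TensorDict({"x": torch.arange(100)}, batch_size=[100]).memmap_()
        seen = 0
        for chunk in td.map_iter(add_one, dim=0, num_chunks=10, shuffle=True):
            # With shuffle=True chunks arrive unordered and hold non-contiguous
            # rows, but each row still satisfies y == x + 1.
            assert (chunk["y"] == chunk["x"] + 1).all()
            seen += chunk.numel()
        assert seen == 100

Because `shuffle=True` switches `_map` from `pool.imap` to `pool.imap_unordered`, results are yielded as workers finish rather than in submission order.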
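
The shuffled, generator-based splitting added to `_split_tensordict` can be pictured with the standalone sketch below (an assumption-level illustration, not the library's internal code): one global permutation is drawn, consumed in equal-sized pieces, and each piece indexes the tensordict lazily. The patch additionally draws the permutation in the smallest integer dtype that can hold `n` (`_find_smallest_uint`) and upcasts each slice with `.long()` before indexing.

    import torch
    from tensordict import TensorDict

    def iter_shuffled_chunks(td, num_chunks, dim=0):
        n = td.shape[dim]
        chunksize = -(n // -num_chunks)        # ceil division, as in the patch
        rp = torch.randperm(n)                 # one global shuffle of the indices
        base = (slice(None),) * dim
        for start in range(0, n, chunksize):
            idx = rp[start:start + chunksize]  # non-contiguous positions along dim
            yield td[base + (idx,)]

    td = TensorDict({"a": torch.arange(10)}, batch_size=[10])
    print([chunk["a"].tolist() for chunk in iter_shuffled_chunks(td, num_chunks=3)])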