From ceb9d05f5e335d44e25d2c8c1f390479c5ff4d24 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 14:47:32 +0100 Subject: [PATCH] [BugFix] Fix map (#862) --- tensordict/base.py | 111 ++++++++++++++++++---------------------- test/test_tensordict.py | 2 +- 2 files changed, 50 insertions(+), 63 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index ccea68481..62e72b26c 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6420,76 +6420,63 @@ def _map( pbar: bool = False, iterable: bool, ): - def _get_imap( - pool=pool, - dim=dim, - out=out, - fn=fn, - chunksize=chunksize, - num_chunks=num_chunks, + num_workers = pool._processes + dim_orig = dim + if dim < 0: + dim = self.ndim + dim + if dim < 0 or dim >= self.ndim: + raise ValueError(f"Got incompatible dimension {dim_orig}") + + self_split = _split_tensordict( + self, + chunksize, + num_chunks, + num_workers, + dim, shuffle=shuffle, - index_with_generator=index_with_generator, - pbar=pbar, - ): - num_workers = pool._processes - dim_orig = dim - if dim < 0: - dim = self.ndim + dim - if dim < 0 or dim >= self.ndim: - raise ValueError(f"Got incompatible dimension {dim_orig}") - - self_split = _split_tensordict( - self, - chunksize, - num_chunks, - num_workers, - dim, - shuffle=shuffle, - use_generator=index_with_generator, - ) - if not index_with_generator: - length = len(self_split) - else: - length = None - call_chunksize = 1 - - if out is not None and (out.is_shared() or out.is_memmap()): - - def wrap_fn_with_out(fn, out): - @wraps(fn) - def newfn(item_and_out): - item, out = item_and_out - result = fn(item) - out.update_(result) - return - - out_split = _split_tensordict( - out, - chunksize, - num_chunks, - num_workers, - dim, - shuffle=shuffle, - use_generator=index_with_generator, - ) - return _CloudpickleWrapper(newfn), zip(self_split, out_split) + use_generator=index_with_generator, + ) + if not index_with_generator: + length = len(self_split) + else: + length = None + call_chunksize = 1 + + if out is not None and (out.is_shared() or out.is_memmap()): + + def wrap_fn_with_out(fn, out): + @wraps(fn) + def newfn(item_and_out): + item, out = item_and_out + result = fn(item) + out.update_(result) + return + + out_split = _split_tensordict( + out, + chunksize, + num_chunks, + num_workers, + dim, + shuffle=shuffle, + use_generator=index_with_generator, + ) + return _CloudpickleWrapper(newfn), zip(self_split, out_split) - fn, self_split = wrap_fn_with_out(fn, out) - out = None + fn, self_split = wrap_fn_with_out(fn, out) + out = None - imap_fn = pool.imap if not shuffle else pool.imap_unordered - imap = imap_fn(fn, self_split, call_chunksize) + imap_fn = pool.imap if not shuffle else pool.imap_unordered + imap = imap_fn(fn, self_split, call_chunksize) - if pbar and importlib.util.find_spec("tqdm", None) is not None: - import tqdm + if pbar and importlib.util.find_spec("tqdm", None) is not None: + import tqdm - imap = tqdm.tqdm(imap, total=length) - return imap, out + imap = tqdm.tqdm(imap, total=length) if iterable: - return _get_imap()[0] + return imap else: - imap, out = _get_imap() imaplist = [] start = 0 base_index = (slice(None),) * dim diff --git a/test/test_tensordict.py b/test/test_tensordict.py index f4ff0ce77..93b01e331 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -6791,7 +6791,7 @@ def test_mp(self, td_type): elif td_type == "memmap": tensordict = tensordict.memmap_() elif td_type == "memmap_stack": - tensordict = stack_td( + tensordict = TensorDict.lazy_stack( [ tensordict[0].clone().memmap_(), tensordict[1].clone().memmap_(),