Skip to content

Commit

Permalink
[BugFix] Fix map (#862)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 9, 2024
1 parent 612fbbc commit ceb9d05
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 63 deletions.
111 changes: 49 additions & 62 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6420,76 +6420,63 @@ def _map(
pbar: bool = False,
iterable: bool,
):
def _get_imap(
pool=pool,
dim=dim,
out=out,
fn=fn,
chunksize=chunksize,
num_chunks=num_chunks,
num_workers = pool._processes
dim_orig = dim
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(f"Got incompatible dimension {dim_orig}")

self_split = _split_tensordict(
self,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
index_with_generator=index_with_generator,
pbar=pbar,
):
num_workers = pool._processes
dim_orig = dim
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(f"Got incompatible dimension {dim_orig}")

self_split = _split_tensordict(
self,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
if not index_with_generator:
length = len(self_split)
else:
length = None
call_chunksize = 1

if out is not None and (out.is_shared() or out.is_memmap()):

def wrap_fn_with_out(fn, out):
@wraps(fn)
def newfn(item_and_out):
item, out = item_and_out
result = fn(item)
out.update_(result)
return

out_split = _split_tensordict(
out,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
return _CloudpickleWrapper(newfn), zip(self_split, out_split)
use_generator=index_with_generator,
)
if not index_with_generator:
length = len(self_split)
else:
length = None
call_chunksize = 1

if out is not None and (out.is_shared() or out.is_memmap()):

def wrap_fn_with_out(fn, out):
@wraps(fn)
def newfn(item_and_out):
item, out = item_and_out
result = fn(item)
out.update_(result)
return

out_split = _split_tensordict(
out,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
return _CloudpickleWrapper(newfn), zip(self_split, out_split)

fn, self_split = wrap_fn_with_out(fn, out)
out = None
fn, self_split = wrap_fn_with_out(fn, out)
out = None

imap_fn = pool.imap if not shuffle else pool.imap_unordered
imap = imap_fn(fn, self_split, call_chunksize)
imap_fn = pool.imap if not shuffle else pool.imap_unordered
imap = imap_fn(fn, self_split, call_chunksize)

if pbar and importlib.util.find_spec("tqdm", None) is not None:
import tqdm
if pbar and importlib.util.find_spec("tqdm", None) is not None:
import tqdm

imap = tqdm.tqdm(imap, total=length)
return imap, out
imap = tqdm.tqdm(imap, total=length)

if iterable:
return _get_imap()[0]
return imap
else:
imap, out = _get_imap()
imaplist = []
start = 0
base_index = (slice(None),) * dim
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6791,7 +6791,7 @@ def test_mp(self, td_type):
elif td_type == "memmap":
tensordict = tensordict.memmap_()
elif td_type == "memmap_stack":
tensordict = stack_td(
tensordict = TensorDict.lazy_stack(
[
tensordict[0].clone().memmap_(),
tensordict[1].clone().memmap_(),
Expand Down

0 comments on commit ceb9d05

Please sign in to comment.