
Commit: amend
vmoens committed Jul 3, 2024
1 parent 78c236e commit 0464836
Showing 3 changed files with 149 additions and 116 deletions.
253 changes: 144 additions & 109 deletions tensordict/base.py
@@ -6150,22 +6150,55 @@ def map(
at little cost.
"""
return self._map(
fn,
dim=dim,
num_workers=num_workers,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
generator=generator,
max_tasks_per_child=max_tasks_per_child,
worker_threads=worker_threads,
index_with_generator=index_with_generator,
pbar=pbar,
mp_start_method=mp_start_method,
out=out,
iterable=False,
)
from torch import multiprocessing as mp

if pool is None:
if num_workers is None:
num_workers = mp.cpu_count() # Get the number of CPU cores
if generator is None:
generator = torch.Generator()
seed = (
torch.empty((), dtype=torch.int64).random_(generator=generator).item()
)
if mp_start_method is not None:
ctx = mp.get_context(mp_start_method)
else:
ctx = mp.get_context()

queue = ctx.Queue(maxsize=num_workers)
for i in range(num_workers):
queue.put(i)
with ctx.Pool(
processes=num_workers,
initializer=_proc_init,
initargs=(seed, queue, worker_threads),
maxtasksperchild=max_tasks_per_child,
) as pool:
return self._map(
fn=fn,
dim=dim,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
pbar=pbar,
out=out,
index_with_generator=index_with_generator,
iterable=False,
shuffle=False,
)
else:
return self._map(
fn=fn,
dim=dim,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
pbar=pbar,
out=out,
index_with_generator=index_with_generator,
iterable=False,
shuffle=False,
)

def map_iter(
self,
@@ -6303,42 +6336,6 @@ def map_iter(
in a dataloader-like fashion.
"""
yield from self._map(
fn,
dim=dim,
num_workers=num_workers,
shuffle=shuffle,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
generator=generator,
max_tasks_per_child=max_tasks_per_child,
worker_threads=worker_threads,
index_with_generator=index_with_generator,
pbar=pbar,
mp_start_method=mp_start_method,
iterable=True,
)

def _map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
dim: int = 0,
num_workers: int | None = None,
*,
shuffle: bool = False,
out: TensorDictBase | None = None,
chunksize: int | None = None,
num_chunks: int | None = None,
pool: mp.Pool | None = None,
generator: torch.Generator | None = None,
max_tasks_per_child: int | None = None,
worker_threads: int = 1,
index_with_generator: bool = False,
pbar: bool = False,
mp_start_method: str | None = None,
iterable: bool,
):
from torch import multiprocessing as mp

if pool is None:
@@ -6365,85 +6362,123 @@ def _map(
)
try:
yield from self._map(
fn,
fn=fn,
dim=dim,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
pbar=pbar,
out=out,
out=None,
index_with_generator=index_with_generator,
iterable=iterable,
iterable=True,
shuffle=shuffle,
)
return
finally:
try:
pool.terminate()
finally:
pool.join()
else:
yield from self._map(
fn=fn,
dim=dim,
chunksize=chunksize,
num_chunks=num_chunks,
pool=pool,
pbar=pbar,
out=None,
index_with_generator=index_with_generator,
iterable=True,
shuffle=shuffle,
)

num_workers = pool._processes
dim_orig = dim
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(f"Got incompatible dimension {dim_orig}")

self_split = _split_tensordict(
self,
chunksize,
num_chunks,
num_workers,
dim,
def _map(
self,
fn: Callable[[TensorDictBase], TensorDictBase | None],
dim: int = 0,
*,
shuffle: bool = False,
out: TensorDictBase | None = None,
chunksize: int | None = None,
num_chunks: int | None = None,
pool: mp.Pool | None = None,
index_with_generator: bool = False,
pbar: bool = False,
iterable: bool,
):
def _get_imap(
pool=pool,
dim=dim,
out=out,
fn=fn,
chunksize=chunksize,
num_chunks=num_chunks,
shuffle=shuffle,
use_generator=index_with_generator,
)
if not index_with_generator:
length = len(self_split)
else:
length = None
call_chunksize = 1

if out is not None and (out.is_shared() or out.is_memmap()):

def wrap_fn_with_out(fn, out):
@wraps(fn)
def newfn(item_and_out):
item, out = item_and_out
result = fn(item)
out.update_(result)
return

out_split = _split_tensordict(
out,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
return _CloudpickleWrapper(newfn), zip(self_split, out_split)
index_with_generator=index_with_generator,
pbar=pbar,
):
num_workers = pool._processes
dim_orig = dim
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim >= self.ndim:
raise ValueError(f"Got incompatible dimension {dim_orig}")

self_split = _split_tensordict(
self,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
if not index_with_generator:
length = len(self_split)
else:
length = None
call_chunksize = 1

if out is not None and (out.is_shared() or out.is_memmap()):

def wrap_fn_with_out(fn, out):
@wraps(fn)
def newfn(item_and_out):
item, out = item_and_out
result = fn(item)
out.update_(result)
return

out_split = _split_tensordict(
out,
chunksize,
num_chunks,
num_workers,
dim,
shuffle=shuffle,
use_generator=index_with_generator,
)
return _CloudpickleWrapper(newfn), zip(self_split, out_split)

fn, self_split = wrap_fn_with_out(fn, out)
out = None
fn, self_split = wrap_fn_with_out(fn, out)
out = None

imap_fn = pool.imap if not shuffle else pool.imap_unordered
imap = imap_fn(fn, self_split, call_chunksize)
imap_fn = pool.imap if not shuffle else pool.imap_unordered
imap = imap_fn(fn, self_split, call_chunksize)

if pbar and importlib.util.find_spec("tqdm", None) is not None:
import tqdm
if pbar and importlib.util.find_spec("tqdm", None) is not None:
import tqdm

imap = tqdm.tqdm(imap, total=length)
imap = tqdm.tqdm(imap, total=length)
return imap, out

imaplist = []
start = 0
base_index = (slice(None),) * dim
if iterable:
yield from imap
return
return _get_imap()[0]
else:
imap, out = _get_imap()
imaplist = []
start = 0
base_index = (slice(None),) * dim
for item in imap:
if item is not None:
if out is not None:
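
Note on the base.py change: map() no longer forwards its pool-construction arguments to _map(); it now builds the multiprocessing pool itself — deriving a base seed from a torch.Generator, handing each worker an index through a queue-fed initializer, and bounding tasks per child — and then calls a slimmer _map() that only splits the tensordict and drives pool.imap. The sketch below reproduces that pool pattern in isolation. parallel_map and _worker_init are hypothetical stand-ins (the real helpers _proc_init and _split_tensordict are not shown in this diff), so read it as an illustration of the structure rather than the library's code.

import torch
from torch import multiprocessing as mp


def _worker_init(seed_base, queue, num_threads):
    # Hypothetical stand-in for _proc_init: each worker pops a unique index,
    # derives its own seed from the shared base seed, and caps intra-op threads.
    worker_id = queue.get(timeout=120)
    torch.manual_seed(seed_base + worker_id)
    torch.set_num_threads(num_threads)


def parallel_map(fn, items, num_workers=None, worker_threads=1,
                 max_tasks_per_child=None, mp_start_method=None):
    # Mirrors the structure map() now uses: build the pool locally, then iterate.
    if num_workers is None:
        num_workers = mp.cpu_count()
    generator = torch.Generator()
    seed = torch.empty((), dtype=torch.int64).random_(generator=generator).item()
    ctx = mp.get_context(mp_start_method) if mp_start_method is not None else mp.get_context()

    # Queue of worker indices consumed by the initializer, one per process.
    queue = ctx.Queue(maxsize=num_workers)
    for i in range(num_workers):
        queue.put(i)

    with ctx.Pool(
        processes=num_workers,
        initializer=_worker_init,
        initargs=(seed, queue, worker_threads),
        maxtasksperchild=max_tasks_per_child,
    ) as pool:
        # pool.imap keeps input order; imap_unordered would correspond to shuffle=True.
        return list(pool.imap(fn, items, 1))


if __name__ == "__main__":
    print(parallel_map(abs, [-3, -2, 1], num_workers=2))
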
11 changes: 4 additions & 7 deletions tensordict/tensorclass.py
@@ -127,22 +127,18 @@ def __subclasscheck__(self, subclass),
"__truediv__",
"_add_batch_dim",
"_apply_nest",
"_multithread_apply_flat",
"_check_unlock",
"_erase_names", # TODO: must be specialized
"_exclude", # TODO: must be specialized
"_fast_apply",
"_get_names_idx", # no wrap output
"_has_names",
"_multithread_apply_flat",
"_propagate_lock",
"_propagate_unlock",
"_remove_batch_dim",
"_select", # TODO: must be specialized
# "_set_at_str",
"_set_at_tuple",
"_set_at_tuple",
# _set_str needs a special treatment to catch keys that are already in
# non tensor data
# "_set_at_tuple",
"_set_str",
"_set_tuple",
"abs",
@@ -219,6 +215,8 @@ def __subclasscheck__(self, subclass),
"log2",
"log2_",
"log_",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
"maximum",
@@ -268,7 +266,6 @@ def __subclasscheck__(self, subclass),
"unsqueeze",
"values", # no wrap output
"view",
"_get_names_idx", # no wrap output
"where",
"zero_",
]
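
Note on the tensorclass.py change: adding "map" and "map_iter" to the wrapped-method list lets @tensorclass instances forward those calls to the TensorDict machinery changed above. A hedged sketch of what that enables is below; whether the worker function receives tensorclass chunks or plain TensorDicts, and what type the collected result has, is not visible in this diff, so the attribute access inside double() is an assumption.

import torch
from tensordict import tensorclass


@tensorclass
class Data:
    x: torch.Tensor


def double(chunk):
    # Module-level so it can be pickled and shipped to worker processes.
    # Assumes the chunk exposes the field as an attribute, as a tensorclass would.
    chunk.x = chunk.x * 2
    return chunk


if __name__ == "__main__":
    data = Data(x=torch.arange(8, dtype=torch.float32), batch_size=[8])
    out = data.map(double, dim=0, num_workers=2, chunksize=2)
    print(out)
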
1 change: 1 addition & 0 deletions test/test_tensordict.py
@@ -9089,6 +9089,7 @@ def test_map_iter_interrupt_early(self, chunksize, num_chunks, shuffle):
):
return


# class TestNonTensorData:
class TestNonTensorData:
@pytest.fixture
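
The test touched above, test_map_iter_interrupt_early, stops consuming map_iter before it is exhausted. A minimal sketch of that dataloader-style usage follows; num_workers and chunksize are taken from the signatures visible in the diff, while the tensors and the add_one helper are made up for illustration.

import torch
from tensordict import TensorDict


def add_one(td):
    # Module-level so worker processes can unpickle it.
    return td.set("x", td.get("x") + 1)


if __name__ == "__main__":
    td = TensorDict({"x": torch.zeros(100)}, batch_size=[100])
    # map_iter yields processed chunks as workers finish them; breaking early
    # exercises the pool-teardown path that the interrupted-iteration test covers.
    for i, chunk in enumerate(td.map_iter(add_one, num_workers=2, chunksize=10)):
        print(i, chunk["x"].shape)
        if i == 2:
            break
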
