Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] RB compability with compile #2426

Open
wants to merge 13 commits into
base: gh/vmoens/25/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ class ReplayBuffer:
.. warning:: As of now, the generator has no effect on the transforms.
shared (bool, optional): whether the buffer will be shared using multiprocessing or not.
Defaults to ``False``.
compilable (bool, optional): whether the writer is compilable.
If ``True``, the writer cannot be shared between multiple processes.
Defaults to ``False``.
Examples:
>>> import torch
Expand Down Expand Up @@ -216,11 +219,20 @@ def __init__(
checkpointer: "StorageCheckpointerBase" | None = None, # noqa: F821
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = None,
) -> None:
self._storage = storage if storage is not None else ListStorage(max_size=1_000)
self._storage.attach(self)
if compilable is not None:
self._storage._compilable = compilable
self._storage._len = self._storage._len

self._sampler = sampler if sampler is not None else RandomSampler()
self._writer = writer if writer is not None else RoundRobinWriter()
self._writer = (
writer
if writer is not None
else RoundRobinWriter(compilable=bool(compilable))
)
self._writer.register_storage(self._storage)

self._get_collate_fn(collate_fn)
Expand Down Expand Up @@ -601,7 +613,9 @@ def _add(self, data):
return index

def _extend(self, data: Sequence) -> torch.Tensor:
with self._replay_lock, self._write_lock:
is_compiling = torch.compiler.is_dynamo_compiling()
nc = contextlib.nullcontext()
with self._replay_lock if not is_compiling else nc, self._write_lock if not is_compiling else nc:
if self.dim_extend > 0:
data = self._transpose(data)
index = self._writer.extend(data)
Expand Down Expand Up @@ -654,7 +668,7 @@ def update_priority(

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
with self._replay_lock if not torch.compiler.is_dynamo_compiling() else contextlib.nullcontext():
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
Expand Down Expand Up @@ -1753,6 +1767,7 @@ def __init__(
num_buffer_sampled: int | None = None,
generator: torch.Generator | None = None,
shared: bool = False,
compilable: bool = False,
**kwargs,
):

Expand Down
47 changes: 32 additions & 15 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ class Storage:
_rng: torch.Generator | None = None

def __init__(
self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
self,
max_size: int,
checkpointer: StorageCheckpointerBase | None = None,
compilable: bool = False,
) -> None:
self.max_size = int(max_size)
self.checkpointer = checkpointer
self._compilable = compilable
self._attached_entities_set = set()

@property
def checkpointer(self):
Expand All @@ -80,11 +85,11 @@ def _is_full(self):
def _attached_entities(self):
# RBs that use a given instance of Storage should add
# themselves to this set.
_attached_entities = self.__dict__.get("_attached_entities_set", None)
if _attached_entities is None:
_attached_entities = set()
self.__dict__["_attached_entities_set"] = _attached_entities
return _attached_entities
return getattr(self, "_attached_entities_set", None)

@torch._dynamo.assume_constant_result
def _attached_entities_iter(self):
return list(self._attached_entities)

@abc.abstractmethod
def set(self, cursor: int, data: Any, *, set_cursor: bool = True):
Expand Down Expand Up @@ -140,6 +145,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def _empty(self):
...

@torch._dynamo.disable()
def _rand_given_ndim(self, batch_size):
# a method to return random indices given the storage ndim
if self.ndim == 1:
Expand Down Expand Up @@ -330,6 +336,9 @@ class TensorStorage(Storage):
measuring the storage size. For instance, a storage of shape ``[3, 4]``
has capacity ``3`` if ``ndim=1`` and ``12`` if ``ndim=2``.
Defaults to ``1``.
compilable (bool, optional): whether the storage is compilable.
If ``True``, the writer cannot be shared between multiple processes.
Defaults to ``False``.
Examples:
>>> data = TensorDict({
Expand Down Expand Up @@ -389,6 +398,7 @@ def __init__(
*,
device: torch.device = "cpu",
ndim: int = 1,
compilable: bool = False,
):
if not ((storage is None) ^ (max_size is None)):
if storage is None:
Expand All @@ -404,7 +414,7 @@ def __init__(
else:
max_size = tree_flatten(storage)[0][0].shape[0]
self.ndim = ndim
super().__init__(max_size)
super().__init__(max_size, compilable=compilable)
self.initialized = storage is not None
if self.initialized:
self._len = max_size
Expand All @@ -423,16 +433,23 @@ def __init__(
@property
def _len(self):
_len_value = self.__dict__.get("_len_value", None)
if not self._compilable or not isinstance(self._len_value, int):
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
return _len_value.value
_len_value = self._len_value = 0
return _len_value

@_len.setter
def _len(self, value):
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
if not self._compilable:
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
_len_value.value = value
else:
self._len_value = value

@property
def _total_shape(self):
Expand Down Expand Up @@ -1184,9 +1201,9 @@ def _rng(self, value):
for storage in self._storages:
storage._rng = value

@property
def _attached_entities(self):
return set()
# @property
# def _attached_entities(self):
# return set()

def extend(self, value):
raise RuntimeError
Expand Down
58 changes: 40 additions & 18 deletions torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class Writer(ABC):
_storage: Storage
_rng: torch.Generator | None = None

def __init__(self) -> None:
def __init__(self, compilable: bool = False) -> None:
self._storage = None
self._compilable = compilable

def register_storage(self, storage: Storage) -> None:
self._storage = storage
Expand Down Expand Up @@ -138,10 +139,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:


class RoundRobinWriter(Writer):
"""A RoundRobin Writer class for composable replay buffers."""
"""A RoundRobin Writer class for composable replay buffers.
def __init__(self, **kw) -> None:
super().__init__(**kw)
Args:
compilable (bool, optional): whether the writer is compilable.
If ``True``, the writer cannot be shared between multiple processes.
Defaults to ``False``.
"""

def __init__(self, compilable: bool = False) -> None:
super().__init__(compilable=compilable)
self._cursor = 0

def dumps(self, path):
Expand Down Expand Up @@ -197,7 +205,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
index = self._replicate_index(index)
for ent in self._storage._attached_entities:
for ent in self._storage._attached_entities_iter():
ent.mark_update(index)
return index

Expand All @@ -213,30 +221,44 @@ def _empty(self):
@property
def _cursor(self):
_cursor_value = self.__dict__.get("_cursor_value", None)
if not self._compilable or not isinstance(_cursor_value, int):
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
return _cursor_value.value
_cursor_value = self._cursor_value = 0
return _cursor_value

@_cursor.setter
def _cursor(self, value):
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
if not self._compilable:
_cursor_value = self.__dict__.get("_cursor_value", None)
if _cursor_value is None:
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value
else:
self._cursor_value = value

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if not self._compilable or not isinstance(_write_count, int):
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value
_write_count = self._write_count_value = 0
return _write_count

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
if not self._compilable:
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value
else:
self._write_count_value = value

def __getstate__(self):
state = super().__getstate__()
Expand All @@ -248,7 +270,7 @@ def __getstate__(self):

def __setstate__(self, state):
cursor = state.pop("cursor__context", None)
if cursor is not None:
if not state["_compilable"] and cursor is not None:
_cursor_value = mp.Value("i", cursor)
state["_cursor_value"] = _cursor_value
self.__dict__.update(state)
Expand Down
Loading