[Feature] Compile - nn compatibility
ghstack-source-id: 6fb01c93298d57ad12f4799488ac52691f722d5f
Pull Request resolved: #881
vmoens committed Jul 12, 2024
1 parent df9c196 commit 7768a53
Showing 12 changed files with 447 additions and 101 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/run_test.sh
@@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib"
 # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
 export MKL_THREADING_LAYER=GNU
+export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1

 coverage run -m pytest test/smoke_test.py -v --durations 20
 coverage run -m pytest --instafail -v --durations 20 --timeout 120

1 change: 1 addition & 0 deletions .github/unittest/linux_torchrec/scripts/run_test.sh
@@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib"
 # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
 export MKL_THREADING_LAYER=GNU
+export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1

 coverage run -m pytest test/smoke_test.py -v --durations 20
 coverage run -m pytest --instafail -v --durations 20

1 change: 1 addition & 0 deletions .github/unittest/rl_linux_optdeps/scripts/run_test.sh
@@ -16,6 +16,7 @@ git config --global --add safe.directory '*'
 root_dir="$(git rev-parse --show-toplevel)"
 export MKL_THREADING_LAYER=GNU
 export CKPT_BACKEND=torch
+export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1

 #MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20
 MUJOCO_GL=egl python -m pytest rl/test --instafail -v --durations 20 --ignore rl/test/test_distributed.py
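
All three CI scripts export the same flag: TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 tells Dynamo to trace into built-in nn.Module instances rather than treating them as opaque objects, which is what the module/TensorDict interaction below relies on. A minimal in-process sketch of the same toggle, assuming a PyTorch build that exposes torch._dynamo.config.inline_inbuilt_nn_modules (the config attribute name is an assumption; on other builds only the environment variable applies):

# Sketch: enable nn.Module inlining before compiling (assumed config name).
import torch
import torch._dynamo

torch._dynamo.config.inline_inbuilt_nn_modules = True  # same effect as the env var

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())

@torch.compile
def forward(x):
    return model(x)

print(forward(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
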
23 changes: 12 additions & 11 deletions tensordict/_lazy.py
@@ -60,6 +60,7 @@
     TensorDictBase,
 )
 from tensordict.utils import (
+    _as_context_manager,
     _broadcast_tensors,
     _get_shape_from_args,
     _getitem_batch_size,
@@ -69,7 +70,6 @@
     _shape,
     _td_fields,
     _unravel_key_to_tuple,
-    as_decorator,
     cache,
     convert_ellipsis_to_idx,
     DeviceType,
@@ -2651,16 +2651,17 @@ def _lock_parents_weakrefs(self):
         ]
         return _lock_parents_weakrefs

-    def _propagate_lock(self, lock_parents_weakrefs=None):
+    def _propagate_lock(self, lock_parents_weakrefs=None, *, is_compiling):
         """Registers the parent tensordict that handles the lock."""
         self._is_locked = True
-        is_root = lock_parents_weakrefs is None
-        if is_root:
-            lock_parents_weakrefs = []
+        if not is_compiling:
+            is_root = lock_parents_weakrefs is None
+            if is_root:
+                lock_parents_weakrefs = []
+
+            lock_parents_weakrefs = copy(lock_parents_weakrefs) + [weakref.ref(self)]
-        lock_parents_weakrefs = copy(lock_parents_weakrefs) + [weakref.ref(self)]
         for dest in self.tensordicts:
-            dest._propagate_lock(lock_parents_weakrefs)
+            dest._propagate_lock(lock_parents_weakrefs, is_compiling=is_compiling)

     @erase_cache
     def _propagate_unlock(self):
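
The new keyword-only is_compiling flag lets lock propagation skip the weakref bookkeeping while Dynamo is tracing, presumably because creating and storing weakrefs is hard for the compiler to trace; eager-mode behaviour is unchanged. A small standalone sketch of the same pattern (the Node class is hypothetical; only torch.compiler.is_dynamo_compiling() comes from PyTorch):

# Sketch: keep weakref bookkeeping in eager mode, skip it under compilation.
import weakref
import torch

class Node:
    def __init__(self, children=()):
        self.children = list(children)
        self._parent_refs = []
        self.locked = False

    def lock_(self, parent_refs=None):
        is_compiling = torch.compiler.is_dynamo_compiling()
        self.locked = True
        if not is_compiling:
            parent_refs = list(parent_refs or []) + [weakref.ref(self)]
            self._parent_refs = parent_refs
        for child in self.children:
            child.lock_(parent_refs)
        return self

root = Node([Node(), Node()])
root.lock_()
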
@@ -3380,13 +3381,13 @@ def is_locked(self, value) -> bool:
         else:
             self.unlock_()

-    @as_decorator("is_locked")
+    @_as_context_manager("is_locked")
     def lock_(self) -> T:
         self._source.lock_()
         return self

     @erase_cache
-    @as_decorator("is_locked")
+    @_as_context_manager("is_locked")
     def unlock_(self) -> T:
         self._source.unlock_()
         return self
@@ -3395,8 +3396,8 @@ def _remove_lock(self, lock_id):
         return self._source._remove_lock(lock_id)

     @erase_cache
-    def _propagate_lock(self, lock_ids):
-        return self._source._propagate_lock(lock_ids)
+    def _propagate_lock(self, lock_ids, *, is_compiling):
+        return self._source._propagate_lock(lock_ids, is_compiling=is_compiling)

     @erase_cache
     def _propagate_unlock(self):
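
as_decorator is renamed to _as_context_manager here and in _td.py below; the wrapped lock_/unlock_ methods still return self, and when used as context managers they restore the previous is_locked state on exit. A short usage sketch against the public TensorDict API, assuming it behaves the same way as the private classes touched in this diff:

# Sketch: lock a TensorDict for the duration of a block, then restore its state.
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])

with td.lock_():           # is_locked is True inside the block
    assert td.is_locked
assert not td.is_locked    # previous state restored on exit
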
71 changes: 38 additions & 33 deletions tensordict/_td.py
@@ -7,6 +7,7 @@

 import numbers
 import os
+import weakref
 from collections import defaultdict
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from copy import copy
@@ -17,10 +18,8 @@
 from warnings import warn

 import numpy as np
-
 import orjson as json
 import torch
-
 from tensordict.base import (
     _ACCEPTED_CLASSES,
     _default_is_leaf,
@@ -35,12 +34,13 @@
     T,
     TensorDictBase,
 )
-
 from tensordict.memmap import MemoryMappedTensor
 from tensordict.utils import (
     _add_batch_dim_pre_hook,
+    _as_context_manager,
     _BatchedUninitializedBuffer,
     _BatchedUninitializedParameter,
+    _check_inbuild,
     _clone_value,
     _expand_to_match_shape,
     _get_item,
@@ -64,7 +64,6 @@
     _StringOnlyDict,
     _sub_index,
     _unravel_key_to_tuple,
-    as_decorator,
     Buffer,
     cache,
     convert_ellipsis_to_idx,
@@ -80,7 +79,6 @@
     unravel_key_list,
 )
 from torch import Tensor
-
 from torch._dynamo import graph_break
 from torch.jit._shape_functions import infer_size_impl
 from torch.utils._pytree import tree_map
@@ -410,6 +408,9 @@ def _to_module(
         use_state_dict: bool = False,
         non_blocking: bool = False,
     ):
+        is_dynamo = torch.compiler.is_dynamo_compiling()
+        if is_dynamo:
+            _check_inbuild()

         if not use_state_dict and isinstance(module, TensorDictBase):
             if return_swap:
@@ -420,13 +421,11 @@
                 module.update(self)
                 return

-        # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
-        __dict__ = module.__dict__
-
         hooks = memo["hooks"]
         if return_swap:
             _swap = {}
-            memo[id(module)] = _swap
+            if not is_dynamo:
+                memo[weakref.ref(module)] = _swap

         if use_state_dict:
             if inplace is not None:
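
The swap memo is now keyed by weakref.ref(module) instead of id(module), and is skipped entirely under Dynamo. Live weakrefs hash and compare like their referents, so a fresh reference to the same module finds the cached entry; unlike integer ids, the key cannot silently collide with a later object once the module is garbage collected. A plain-Python illustration of that dictionary behaviour (no tensordict involved):

# Sketch: weakref.ref keys act like identity keys while the referent is alive.
import weakref

class Module:  # stand-in for an nn.Module
    pass

m = Module()
memo = {weakref.ref(m): {"swapped": True}}

assert memo[weakref.ref(m)] == {"swapped": True}  # fresh ref, same entry
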
@@ -466,9 +465,18 @@ def convert_type(x, y):
             input = self
             inplace = bool(inplace)

+        # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
+        if module.__class__.__setattr__ is __base__setattr__:
+            __dict__ = module.__dict__
+        else:
+            __dict__ = None
+
         for key, value in input.items():
             if isinstance(value, (Tensor, ftdim.Tensor)):
-                if module.__class__.__setattr__ is __base__setattr__:
+                # For Dynamo, we use regular set/delattr as we're not
+                # much afraid by overhead (and dynamo doesn't like those
+                # hacks we're doing).
+                if __dict__ is not None:
                     # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict
                     local_out = _set_tensor_dict(
                         __dict__, hooks, module, key, value, inplace
@@ -490,19 +498,15 @@
                     # if there is at least one key, we must populate the module.
                     # Otherwise, we just go to the next key
                     continue
-                child = __dict__["_modules"][key]
-                local_out = memo.get(id(child), NO_DEFAULT)
-
-                if local_out is NO_DEFAULT:
-                    # if isinstance(child, TensorDictBase):
-                    #     # then child is a TensorDictParams
-                    #     from tensordict.nn import TensorDictParams
-                    #
-                    #     local_out = child
-                    #     if not isinstance(value, TensorDictParams):
-                    #         value = TensorDictParams(value, no_convert=True)
-                    #     __dict__["_modules"][key] = value
-                    # else:
+                if __dict__ is not None:
+                    child = __dict__["_modules"][key]
+                else:
+                    child = module._modules.get(key)
+
+                if not is_dynamo:
+                    local_out = memo.get(weakref.ref(child), NO_DEFAULT)
+
+                if is_dynamo or local_out is NO_DEFAULT:
                     local_out = value._to_module(
                         child,
                         inplace=inplace,
@@ -515,6 +519,7 @@

                 if return_swap:
                     _swap[key] = local_out
+
         if return_swap:
             if isinstance(swap_dest, dict):
                 return _swap
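
Taken together, the _to_module changes make the usual functional-call pattern traceable: when Dynamo is compiling, parameters held in a TensorDict are written into the module through regular setattr, and through the __dict__ fast path otherwise. A hedged end-to-end sketch using the public from_module/to_module API (whether the compiled path is graph-break free depends on the installed PyTorch and tensordict versions):

# Sketch: functional call with TensorDict-held parameters under torch.compile.
import torch
from torch import nn
from tensordict import TensorDict

module = nn.Linear(3, 4)
params = TensorDict.from_module(module)  # snapshot of the module's params/buffers

@torch.compile
def functional_forward(x, params):
    # to_module swaps the parameters in for the duration of the block.
    with params.to_module(module):
        return module(x)

out = functional_forward(torch.randn(2, 3), params)
print(out.shape)  # torch.Size([2, 4])
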
@@ -3770,7 +3775,7 @@ def is_locked(self, value) -> bool:
         else:
             self.unlock_()

-    @as_decorator("is_locked")
+    @_as_context_manager("is_locked")
     def lock_(self) -> T:
         # we can't lock sub-tensordicts because that would mean that the
         # parent tensordict cannot be modified either.
@@ -3780,7 +3785,7 @@ def lock_(self) -> T:
         )
         return self

-    @as_decorator("is_locked")
+    @_as_context_manager("is_locked")
     def unlock_(self) -> T:
         if self.is_locked:
             raise RuntimeError(
@@ -3793,7 +3798,7 @@ def _remove_lock(self, lock_id):
             "Cannot unlock a _SubTensorDict. Unlock the parent tensordict instead."
         )

-    def _propagate_lock(self, lock_ids=None):
+    def _propagate_lock(self, lock_ids=None, *, is_compiling):
         raise RuntimeError(
             "Cannot lock a _SubTensorDict. Lock the parent tensordict instead."
         )
@@ -4075,7 +4080,7 @@ def __repr__(self):


 def _set_tensor_dict( # noqa: F811
-    module_dict,
+    __dict__,
     hooks,
     module: torch.nn.Module,
     name: str,
@@ -4084,12 +4089,12 @@
 ) -> None:
     """Simplified version of torch.nn.utils._named_member_accessor."""
     was_buffer = False
-    out = module_dict["_parameters"].pop(name, None) # type: ignore[assignment]
+    out = __dict__["_parameters"].pop(name, None) # type: ignore[assignment]
     if out is None:
-        out = module_dict["_buffers"].pop(name, None)
+        out = __dict__["_buffers"].pop(name, None)
         was_buffer = out is not None
         if out is None:
-            out = module_dict.pop(name)
+            out = __dict__.pop(name)
     if inplace:
         # swap tensor and out after updating out
         out_tmp = out.clone()
@@ -4102,7 +4107,7 @@
         output = hook(module, name, tensor)
         if output is not None:
             tensor = output
-    module_dict["_parameters"][name] = tensor
+    __dict__["_parameters"][name] = tensor

     if isinstance(
         tensor, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer)
@@ -4112,9 +4117,9 @@
         )

     elif was_buffer and isinstance(tensor, torch.Tensor):
-        module_dict["_buffers"][name] = tensor
+        __dict__["_buffers"][name] = tensor
     else:
-        module_dict[name] = tensor
+        __dict__[name] = tensor
     return out

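
The module_dict parameter is renamed to __dict__, which makes explicit that _set_tensor_dict edits the module's __dict__ sub-dictionaries directly and bypasses nn.Module.__setattr__. A small sketch of what that fast path amounts to, using only standard nn.Module internals (nothing tensordict-specific):

# Sketch: swap a parameter by editing the module's __dict__ directly,
# the way _set_tensor_dict does, instead of going through __setattr__.
import torch
from torch import nn

m = nn.Linear(2, 2)
old = m.__dict__["_parameters"].pop("weight")         # remove the current weight entry
m.__dict__["_parameters"]["weight"] = nn.Parameter(   # re-register a new one
    torch.zeros_like(old)
)
assert torch.all(m.weight == 0)
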
(Diffs for the remaining changed files are not shown.)
