Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 14, 2024
1 parent 441d5e8 commit 5d7f757
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import collections
import concurrent.futures
import contextlib
import gc
import importlib
import json
import numbers
Expand Down Expand Up @@ -7287,14 +7288,22 @@ def is_locked(self, value: bool) -> None:

def _propagate_lock(self, lock_parents_weakrefs=None):
"""Registers the parent tensordict that handles the lock."""
if self._is_locked and lock_parents_weakrefs is not None:
lock_parents_weakrefs = [
ref
for ref in lock_parents_weakrefs
if not any(refref is ref for refref in self._lock_parents_weakrefs)
]
self._is_locked = True
is_root = lock_parents_weakrefs is None
if is_root:
lock_parents_weakrefs = []
self._lock_parents_weakrefs = (
self._lock_parents_weakrefs + lock_parents_weakrefs
)
lock_parents_weakrefs = copy(lock_parents_weakrefs) + [weakref.ref(self)]
else:
self._lock_parents_weakrefs = (
self._lock_parents_weakrefs + lock_parents_weakrefs
)
lock_parents_weakrefs = copy(lock_parents_weakrefs)
lock_parents_weakrefs.append(weakref.ref(self))
for value in self.values():
if _is_tensor_collection(type(value)):
value._propagate_lock(lock_parents_weakrefs)
Expand Down Expand Up @@ -7364,24 +7373,35 @@ def _propagate_unlock(self):
sub_tds.append(value)
return sub_tds

def _check_unlock(self):
def _check_unlock(self, first_attempt=True):
if not first_attempt:
gc.collect()
obj = None
for ref in self._lock_parents_weakrefs:
obj = ref()
# check if the locked parent exists and if it's locked
# we check _is_locked because it can be False or None in the case of Lazy stacks,
# but if we check obj.is_locked it will be True for this class.
if obj is not None and obj._is_locked:
raise RuntimeError(
"Cannot unlock a tensordict that is part of a locked graph. "
"Unlock the root tensordict first. If the tensordict is part of multiple graphs, "
"group the graphs under a common tensordict an unlock this root. "
f"self: {self}, obj: {obj}"
)
try:
self._lock_parents_weakrefs = []
except AttributeError:
# Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref
pass
break

else:
try:
self._lock_parents_weakrefs = []
except AttributeError:
# Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref
pass
return

if first_attempt:
del obj
return self._check_unlock(False)
raise RuntimeError(
"Cannot unlock a tensordict that is part of a locked graph. "
"Unlock the root tensordict first. If the tensordict is part of multiple graphs, "
"group the graphs under a common tensordict an unlock this root. "
f"self: {self}, obj: {obj}"
)

@as_decorator("is_locked")
def unlock_(self) -> T:
Expand All @@ -7395,6 +7415,7 @@ def unlock_(self) -> T:
sub_tds = self._propagate_unlock()
for sub_td in sub_tds:
sub_td._check_unlock()

self._check_unlock()
except RuntimeError as err:
self.lock_()
Expand Down

0 comments on commit 5d7f757

Please sign in to comment.