diff --git a/tensordict/base.py b/tensordict/base.py index cd5299184..5494c9037 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9,6 +9,7 @@ import collections import concurrent.futures import contextlib +import gc import importlib import json import numbers @@ -7287,14 +7288,22 @@ def is_locked(self, value: bool) -> None: def _propagate_lock(self, lock_parents_weakrefs=None): """Registers the parent tensordict that handles the lock.""" + if self._is_locked and lock_parents_weakrefs is not None: + lock_parents_weakrefs = [ + ref + for ref in lock_parents_weakrefs + if not any(refref is ref for refref in self._lock_parents_weakrefs) + ] self._is_locked = True is_root = lock_parents_weakrefs is None if is_root: lock_parents_weakrefs = [] - self._lock_parents_weakrefs = ( - self._lock_parents_weakrefs + lock_parents_weakrefs - ) - lock_parents_weakrefs = copy(lock_parents_weakrefs) + [weakref.ref(self)] + else: + self._lock_parents_weakrefs = ( + self._lock_parents_weakrefs + lock_parents_weakrefs + ) + lock_parents_weakrefs = copy(lock_parents_weakrefs) + lock_parents_weakrefs.append(weakref.ref(self)) for value in self.values(): if _is_tensor_collection(type(value)): value._propagate_lock(lock_parents_weakrefs) @@ -7364,24 +7373,35 @@ def _propagate_unlock(self): sub_tds.append(value) return sub_tds - def _check_unlock(self): + def _check_unlock(self, first_attempt=True): + if not first_attempt: + gc.collect() + obj = None for ref in self._lock_parents_weakrefs: obj = ref() # check if the locked parent exists and if it's locked # we check _is_locked because it can be False or None in the case of Lazy stacks, # but if we check obj.is_locked it will be True for this class. if obj is not None and obj._is_locked: - raise RuntimeError( - "Cannot unlock a tensordict that is part of a locked graph. " - "Unlock the root tensordict first. If the tensordict is part of multiple graphs, " - "group the graphs under a common tensordict an unlock this root. " - f"self: {self}, obj: {obj}" - ) - try: - self._lock_parents_weakrefs = [] - except AttributeError: - # Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref - pass + break + + else: + try: + self._lock_parents_weakrefs = [] + except AttributeError: + # Some tds (eg, LazyStack) have an automated way of creating the _lock_parents_weakref + pass + return + + if first_attempt: + del obj + return self._check_unlock(False) + raise RuntimeError( + "Cannot unlock a tensordict that is part of a locked graph. " + "Unlock the root tensordict first. If the tensordict is part of multiple graphs, " + "group the graphs under a common tensordict an unlock this root. " + f"self: {self}, obj: {obj}" + ) @as_decorator("is_locked") def unlock_(self) -> T: @@ -7395,6 +7415,7 @@ def unlock_(self) -> T: sub_tds = self._propagate_unlock() for sub_td in sub_tds: sub_td._check_unlock() + self._check_unlock() except RuntimeError as err: self.lock_()