From 4383e5f6d7865c4a5a5d25959ea66cc020132d36 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 12:14:30 +0100 Subject: [PATCH] amend --- tensordict/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 8813133ac..5a225ff3a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7177,7 +7177,7 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None): val = self._values_list(True, True) other_val = other if alpha is not None: - vals = torch._foreach_sub(vals, other_val, alpha=alpha) + vals = torch._foreach_sub(val, other_val, alpha=alpha) else: vals = torch._foreach_sub(vals, other_val) items = dict(zip(keys, vals)) @@ -7213,13 +7213,12 @@ def mul_(self, other: TensorDictBase | float) -> T: return self def mul(self, other: TensorDictBase | float) -> T: + keys, val = self._items_list(True, True) if _is_tensor_collection(type(other)): - keys, val = self._items_list(True, True) other_val = other._values_list(True, True, sorting_keys=keys) else: - val = self._values_list(True, True) other_val = other - vals = torch._foreach_mul(vals, other_val) + vals = torch._foreach_mul(val, other_val) items = dict(zip(keys, vals)) return self._fast_apply( lambda name, val: items.get(name, val),