Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jul 5, 2024
1 parent 8db175e commit 4383e5f
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7177,7 +7177,7 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None):
val = self._values_list(True, True)
other_val = other
if alpha is not None:
vals = torch._foreach_sub(vals, other_val, alpha=alpha)
vals = torch._foreach_sub(val, other_val, alpha=alpha)
else:
vals = torch._foreach_sub(vals, other_val)
items = dict(zip(keys, vals))
Expand Down Expand Up @@ -7213,13 +7213,12 @@ def mul_(self, other: TensorDictBase | float) -> T:
return self

def mul(self, other: TensorDictBase | float) -> T:
keys, val = self._items_list(True, True)
if _is_tensor_collection(type(other)):
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
vals = torch._foreach_mul(vals, other_val)
vals = torch._foreach_mul(val, other_val)
items = dict(zip(keys, vals))
return self._fast_apply(
lambda name, val: items.get(name, val),
Expand Down

0 comments on commit 4383e5f

Please sign in to comment.