Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 10, 2025
1 parent 5c03f9f commit 86ab9b7
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl._utils import _replace_last
from torchrl.objectives.common import LossModule

from torchrl.objectives.utils import (
Expand Down Expand Up @@ -1267,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:

def reset(self) -> None:
self.beta = self._beta_init


def _make_lp_get_error(tensor_keys, log_prob, err):
result = (
f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does "
f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. "
)
# now check if we can substitute the actions with action_log_prob and retrieve the log-probs
action_keys = tensor_keys.action
if isinstance(action_keys, list):
has_all_log_probs = True
log_prob_keys = []
for action_key in action_keys:
log_prob_key = _replace_last(action_key, "action_log_prob")
log_prob_keys.append(log_prob_key)
if log_prob_key not in log_prob:
has_all_log_probs = False
break
if has_all_log_probs:
result += (
f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the "
f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve "
f"this error."
)
return KeyError(result)
result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
return KeyError(result)

0 comments on commit 86ab9b7

Please sign in to comment.