From 86ab9b7d7bbe8f24759642798acc6b45c25cd335 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 10 Jan 2025 10:02:49 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/objectives/ppo.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 079a1efa92c..53a7bfae5df 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -27,6 +27,7 @@ from tensordict.utils import NestedKey from torch import distributions as d +from torchrl._utils import _replace_last from torchrl.objectives.common import LossModule from torchrl.objectives.utils import ( @@ -1267,3 +1268,30 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: def reset(self) -> None: self.beta = self._beta_init + + +def _make_lp_get_error(tensor_keys, log_prob, err): + result = ( + f"The sample log probability key (tensor_keys.sample_log_prob={tensor_keys.sample_log_prob}) does " + f"not appear in the log-prob tensordict with keys {list(log_prob.keys(True, True))}. " + ) + # now check if we can substitute the actions with action_log_prob and retrieve the log-probs + action_keys = tensor_keys.action + if isinstance(action_keys, list): + has_all_log_probs = True + log_prob_keys = [] + for action_key in action_keys: + log_prob_key = _replace_last(action_key, "action_log_prob") + log_prob_keys.append(log_prob_key) + if log_prob_key not in log_prob: + has_all_log_probs = False + break + if has_all_log_probs: + result += ( + f"The action keys are {action_keys} and all log_prob keys {log_prob_keys} are present in the " + f"log-prob tensordict. Calling `loss.set_keys(sample_log_prob={log_prob_keys})` should resolve " + f"this error." + ) + return KeyError(result) + result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=)." + return KeyError(result)