From 1ee71e3458c0f3a4f14501855c1188c7c4e21181 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 13 Jan 2025 14:42:49 +0000
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 test/test_cost.py         | 53 ++++++++++++++++++++++++++-------------
 torchrl/objectives/ppo.py |  3 +--
 2 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/test/test_cost.py b/test/test_cost.py
index 7ee72543ecf..a538a8d3418 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -7918,14 +7918,13 @@ def _create_mock_actor(
         action_spec = Bounded(
             -torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
         )
-        if composite_action_dist:
-            action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
             if action_key is None:
                 action_key = ("action", "action1")
             else:
                 action_key = (action_key, "action1")
+            action_spec = Composite({action_key: {"action1": action_spec}})
             distribution_class = functools.partial(
                 CompositeDistribution,
                 distribution_map={
@@ -8380,7 +8379,10 @@ def test_ppo_composite_no_aggregate(
             loss_critic_type="l2",
             functional=functional,
         )
-        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+        loss_fn.set_keys(
+            action=("action", "action1"),
+            sample_log_prob=[("action", "action1_log_prob")],
+        )
         if advantage is not None:
             advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
@@ -8495,7 +8497,10 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
             advantage(td)
 
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
 
         loss = loss_fn(td)
         loss_critic = loss["loss_critic"]
@@ -8607,8 +8612,14 @@ def test_ppo_shared_seq(
             advantage(td)
 
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
-            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            loss_fn2.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
 
         loss = loss_fn(td).exclude("entropy")
 
@@ -8701,7 +8712,10 @@ def zero_param(p):
                 advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
         if composite_action_dist:
-            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
 
         loss_critic = loss["loss_critic"]
@@ -8791,10 +8805,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
     @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
     @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    @pytest.mark.parametrize("composite_action_dist", [True, False])
-    def test_ppo_tensordict_keys_run(
-        self, loss_class, advantage, td_est, composite_action_dist
-    ):
+    def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
         """Test PPO loss module with non-default tensordict keys."""
         torch.manual_seed(self.seed)
         gradient_mode = True
@@ -8802,18 +8813,16 @@ def test_ppo_tensordict_keys_run(
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test",
-            "action": ("action_test", "action") if composite_action_dist else "action_test",
+            "sample_log_prob": "sample_log_prob_test",
+            "action": "action_test",
         }
 
         td = self._create_seq_mock_data_ppo(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
             action_key=tensor_keys["action"],
-            composite_action_dist=composite_action_dist,
         )
         actor = self._create_mock_actor(
             sample_log_prob_key=tensor_keys["sample_log_prob"],
-            composite_action_dist=composite_action_dist,
             action_key=tensor_keys["action"],
         )
         value = self._create_mock_value(out_keys=[tensor_keys["value"]])
@@ -8851,8 +8860,6 @@ def test_ppo_tensordict_keys_run(
             raise NotImplementedError
 
         loss_fn = loss_class(actor, value, loss_critic_type="l2")
-        if composite_action_dist:
-            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
         loss_fn.set_keys(**tensor_keys)
         if advantage is not None:
             # collect tensordict key names for the advantage module
@@ -9030,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist):
             reduction=reduction,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
         loss = loss_fn(td)
         if reduction == "none":
             for key in loss.keys():
                 if key.startswith("loss_"):
-                    assert loss[key].shape == td.shape
+                    assert loss[key].shape == td.shape, key
         else:
             for key in loss.keys():
                 if not key.startswith("loss_"):
@@ -9082,6 +9094,11 @@ def test_ppo_value_clipping(
             clip_value=clip_value,
         )
         advantage(td)
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
 
         value = td.pop(loss_fn.tensor_keys.value)
 
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 18d95bcdd7a..2412ea62180 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -569,7 +569,6 @@ def _log_weight(
                 log_prob = _sum_td_features(log_prob)
             log_prob.view_as(prev_log_prob)
 
-        print(log_prob , prev_log_prob)
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
@@ -946,7 +945,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         ratio = log_weight_clip.exp()
         gain2 = ratio * advantage
 
-        gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
+        gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
         if is_tensor_collection(gain):
             gain = _sum_td_features(gain)
         td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
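
For reference, a minimal self-contained sketch (not part of the patch; the tensor values are made up) of the `torchrl/objectives/ppo.py` change above: swapping `min(dim=-1)[0]` for `min(dim=-1).values` is behavior-preserving for plain tensors, since `torch.min(dim=...)` returns a named tuple whose first element is `values`.

```python
import torch

# Hypothetical per-sample surrogate terms standing in for gain1/gain2 in ClipPPOLoss.forward.
gain1 = torch.tensor([0.5, -1.0, 2.0])
gain2 = torch.tensor([0.4, -0.5, 3.0])

stacked = torch.stack([gain1, gain2], -1)

# Indexing the named tuple with [0] and reading .values yield the same tensor.
assert torch.equal(stacked.min(dim=-1)[0], stacked.min(dim=-1).values)
```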