
Commit

Update
[ghstack-poisoned]
vmoens committed Jan 13, 2025
2 parents 399b618 + 1ee71e3 commit 6dc021b
Showing 2 changed files with 36 additions and 20 deletions.
53 changes: 35 additions & 18 deletions test/test_cost.py
@@ -7918,14 +7918,13 @@ def _create_mock_actor(
action_spec = Bounded(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
if composite_action_dist:
action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
if composite_action_dist:
if action_key is None:
action_key = ("action", "action1")
else:
action_key = (action_key, "action1")
action_spec = Composite({action_key: {"action1": action_spec}})
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
@@ -8380,7 +8379,10 @@ def test_ppo_composite_no_aggregate(
loss_critic_type="l2",
functional=functional,
)
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
if advantage is not None:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)
@@ -8495,7 +8497,10 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
advantage(td)

if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
@@ -8607,8 +8612,14 @@ def test_ppo_shared_seq(
advantage(td)

if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
loss_fn2.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)

loss = loss_fn(td).exclude("entropy")

@@ -8701,7 +8712,10 @@ def zero_param(p):
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)
if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
@@ -8791,29 +8805,24 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
@pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss))
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo_tensordict_keys_run(
self, loss_class, advantage, td_est, composite_action_dist
):
def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
"""Test PPO loss module with non-default tensordict keys."""
torch.manual_seed(self.seed)
gradient_mode = True
tensor_keys = {
"advantage": "advantage_test",
"value_target": "value_target_test",
"value": "state_value_test",
"sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test",
"action": ("action_test", "action") if composite_action_dist else "action_test",
"sample_log_prob": "sample_log_prob_test",
"action": "action_test",
}

td = self._create_seq_mock_data_ppo(
sample_log_prob_key=tensor_keys["sample_log_prob"],
action_key=tensor_keys["action"],
composite_action_dist=composite_action_dist,
)
actor = self._create_mock_actor(
sample_log_prob_key=tensor_keys["sample_log_prob"],
composite_action_dist=composite_action_dist,
action_key=tensor_keys["action"],
)
value = self._create_mock_value(out_keys=[tensor_keys["value"]])
@@ -8851,8 +8860,6 @@ def test_ppo_tensordict_keys_run(
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
if composite_action_dist:
tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
loss_fn.set_keys(**tensor_keys)
if advantage is not None:
# collect tensordict key names for the advantage module
@@ -9030,11 +9037,16 @@ def test_ppo_reduction(self, reduction, loss_class, composite_action_dist):
reduction=reduction,
)
advantage(td)
if composite_action_dist:
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
loss = loss_fn(td)
if reduction == "none":
for key in loss.keys():
if key.startswith("loss_"):
assert loss[key].shape == td.shape
assert loss[key].shape == td.shape, key
else:
for key in loss.keys():
if not key.startswith("loss_"):
@@ -9082,6 +9094,11 @@ def test_ppo_value_clipping(
clip_value=clip_value,
)
advantage(td)
if composite_action_dist:
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)

value = td.pop(loss_fn.tensor_keys.value)

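For context on the recurring set_keys pattern in the hunks above: with a composite action distribution the action and its log-probability live under nested tensordict keys, so both the loss module and the advantage module have to be pointed at them. A minimal sketch of that wiring, assuming actor and critic are built as in the test helpers above (the key names simply mirror the tests):

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

loss_fn = ClipPPOLoss(actor, critic, loss_critic_type="l2")
advantage = GAE(gamma=0.9, lmbda=0.9, value_network=critic)

# Nested keys produced by the CompositeDistribution used in these tests.
loss_fn.set_keys(
    action=("action", "action1"),
    sample_log_prob=[("action", "action1_log_prob")],
)
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])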
3 changes: 1 addition & 2 deletions torchrl/objectives/ppo.py
@@ -569,7 +569,6 @@ def _log_weight(
log_prob = _sum_td_features(log_prob)
log_prob.view_as(prev_log_prob)

print(log_prob , prev_log_prob)
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
if is_tensor_collection(kl_approx):
@@ -946,7 +945,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
ratio = log_weight_clip.exp()
gain2 = ratio * advantage

gain = torch.stack([gain1, gain2], -1).min(dim=-1)[0]
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values
if is_tensor_collection(gain):
gain = _sum_td_features(gain)
td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
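The last hunk only trades [0] indexing for the named .values field of torch.min's return value; both read the element-wise minimum of the unclipped and clipped surrogates. A standalone illustration (tensor names and clip bounds are schematic, not the module's internals):

import torch

log_weight = torch.randn(4, 1)   # log importance weights (illustrative)
advantage = torch.randn(4, 1)    # advantage estimates (illustrative)
clip_epsilon = 0.2

gain1 = log_weight.exp() * advantage                                     # unclipped surrogate
gain2 = log_weight.clamp(-clip_epsilon, clip_epsilon).exp() * advantage  # clipped surrogate
gain = torch.stack([gain1, gain2], -1).min(dim=-1).values

# .values and [0] pick out the same tensor from the (values, indices) result.
assert torch.equal(gain, torch.stack([gain1, gain2], -1).min(dim=-1)[0])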
