Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 10, 2025
1 parent cf0f6cf commit 96b4fb3
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,16 @@ def __init__(
if self.return_log_prob and self.log_prob_key not in self.out_keys:
self.out_keys.append(self.log_prob_key)

@property
def dist_params_keys(self) -> List[NestedKey]:
"""Returns all the keys pointing at the distribution params."""
return list(self.in_keys)

@property
def dist_sample_keys(self) -> List[NestedKey]:
"""Returns all the keys pointing at the distribution samples."""
return [key for key in self.out_keys if key is not self.log_prob_key]

def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
"""Creates a :class:`torch.distribution.Distribution` instance with the parameters provided in the input tensordict.
Expand Down Expand Up @@ -389,6 +399,8 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution:
raise err
return dist

build_dist_from_params = get_dist

def log_prob(
self,
tensordict,
Expand Down Expand Up @@ -944,6 +956,22 @@ def get_dist_params(
with set_interaction_type(type):
return tds(tensordict, tensordict_out, **kwargs)

@property
def dist_params_keys(self) -> List[NestedKey]:
"""Returns all the keys pointing at the distribution params."""
result = []
for m in reversed(list(self._module_iter())):
result.extend(getattr(m, "dist_params_keys", []))
return result

@property
def dist_sample_keys(self) -> List[NestedKey]:
"""Returns all the keys pointing at the distribution samples."""
result = []
for m in reversed(list(self._module_iter())):
result.extend(getattr(m, "dist_sample_keys", []))
return result

@property
def num_samples(self):
num_samples = ()
Expand Down
18 changes: 18 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,8 @@ def test_probabilistic_n_samples(self, return_log_prob):
assert td["sample"].shape == (2, 3, 4)
if return_log_prob:
assert "sample_log_prob" in td
assert prob.dist_sample_keys == ["sample"]
assert prob.dist_params_keys == ["loc"]

@pytest.mark.parametrize("return_log_prob", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
Expand Down Expand Up @@ -2042,6 +2044,8 @@ def test_probabilistic_seq_n_samples(
assert isinstance(log_prob, torch.Tensor)
else:
assert isinstance(log_prob, TensorDict)
assert seq.dist_sample_keys == ["sample"]
assert seq.dist_params_keys == ["loc"]

@pytest.mark.parametrize("return_log_prob", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
Expand Down Expand Up @@ -2093,6 +2097,8 @@ def test_intermediate_probabilistic_seq_n_samples(
assert isinstance(log_prob, torch.Tensor)
else:
assert isinstance(log_prob, TensorDict)
assert seq.dist_sample_keys == ["sample"]
assert seq.dist_params_keys == ["loc"]

@pytest.mark.parametrize(
"log_prob_key",
Expand All @@ -2118,6 +2124,8 @@ def test_nested_keys_probabilistic_delta(self, log_prob_key):
return_log_prob=True,
log_prob_key=log_prob_key,
)
assert module.dist_sample_keys == [("data", "action")]
assert module.dist_params_keys == [("data", "param")]
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
Expand All @@ -2132,6 +2140,8 @@ def test_nested_keys_probabilistic_delta(self, log_prob_key):
return_log_prob=True,
log_prob_key=log_prob_key,
)
assert module.dist_sample_keys == [("data", "action")]
assert module.dist_params_keys == [("data", "param")]
td_out = module(policy_module(td))
assert td_out["data", "action"].shape == (3, 4, 1)
if log_prob_key:
Expand Down Expand Up @@ -2912,6 +2922,8 @@ def test_prob_module(self, interaction, return_log_prob, map_names):
default_interaction_type=interaction,
return_log_prob=return_log_prob,
)
assert module.dist_sample_keys == out_keys
assert module.dist_params_keys == in_keys
if not return_log_prob:
assert module.out_keys[-2:] == out_keys
else:
Expand Down Expand Up @@ -3010,6 +3022,8 @@ def test_prob_module_nested(self, interaction, map_names):
)
# loosely checks that the log-prob keys have been added
assert module.out_keys[-2:] != out_keys
assert module.dist_sample_keys == out_keys
assert module.dist_params_keys == in_keys

sample = module(params)
key_logprob0 = (
Expand Down Expand Up @@ -3077,6 +3091,8 @@ def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
)
]
module = ProbabilisticTensorDictSequential(*args)
assert module.dist_sample_keys == out_keys
assert module.dist_params_keys == in_keys
sample = module(params)
if return_log_prob:
assert "cont_log_prob" in sample.keys()
Expand Down Expand Up @@ -3156,6 +3172,8 @@ def test_prob_module_seq_nested(self, interaction):
log_prob_key=log_prob_key,
),
)
assert module.dist_sample_keys == out_keys
assert module.dist_params_keys == in_keys
sample = module(params)
assert "cont_log_prob" in sample.keys()
assert ("nested", "cont_log_prob") in sample.keys(True)
Expand Down

0 comments on commit 96b4fb3

Please sign in to comment.