diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 621388f2c..a1a24543d 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -354,6 +354,16 @@ def __init__( if self.return_log_prob and self.log_prob_key not in self.out_keys: self.out_keys.append(self.log_prob_key) + @property + def dist_params_keys(self) -> List[NestedKey]: + """Returns all the keys pointing at the distribution params.""" + return list(self.in_keys) + + @property + def dist_sample_keys(self) -> List[NestedKey]: + """Returns all the keys pointing at the distribution samples.""" + return [key for key in self.out_keys if key is not self.log_prob_key] + def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: """Creates a :class:`torch.distribution.Distribution` instance with the parameters provided in the input tensordict. @@ -389,6 +399,8 @@ def get_dist(self, tensordict: TensorDictBase) -> D.Distribution: raise err return dist + build_dist_from_params = get_dist + def log_prob( self, tensordict, @@ -944,6 +956,22 @@ def get_dist_params( with set_interaction_type(type): return tds(tensordict, tensordict_out, **kwargs) + @property + def dist_params_keys(self) -> List[NestedKey]: + """Returns all the keys pointing at the distribution params.""" + result = [] + for m in reversed(list(self._module_iter())): + result.extend(getattr(m, "dist_params_keys", [])) + return result + + @property + def dist_sample_keys(self) -> List[NestedKey]: + """Returns all the keys pointing at the distribution samples.""" + result = [] + for m in reversed(list(self._module_iter())): + result.extend(getattr(m, "dist_sample_keys", [])) + return result + @property def num_samples(self): num_samples = () diff --git a/test/test_nn.py b/test/test_nn.py index af087463b..404271ba0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1992,6 +1992,8 @@ def test_probabilistic_n_samples(self, return_log_prob): assert td["sample"].shape == (2, 3, 4) if return_log_prob: assert "sample_log_prob" in td + assert prob.dist_sample_keys == ["sample"] + assert prob.dist_params_keys == ["loc"] @pytest.mark.parametrize("return_log_prob", [True, False]) @pytest.mark.parametrize("inplace", [True, False]) @@ -2042,6 +2044,8 @@ def test_probabilistic_seq_n_samples( assert isinstance(log_prob, torch.Tensor) else: assert isinstance(log_prob, TensorDict) + assert seq.dist_sample_keys == ["sample"] + assert seq.dist_params_keys == ["loc"] @pytest.mark.parametrize("return_log_prob", [True, False]) @pytest.mark.parametrize("inplace", [True, False]) @@ -2093,6 +2097,8 @@ def test_intermediate_probabilistic_seq_n_samples( assert isinstance(log_prob, torch.Tensor) else: assert isinstance(log_prob, TensorDict) + assert seq.dist_sample_keys == ["sample"] + assert seq.dist_params_keys == ["loc"] @pytest.mark.parametrize( "log_prob_key", @@ -2118,6 +2124,8 @@ def test_nested_keys_probabilistic_delta(self, log_prob_key): return_log_prob=True, log_prob_key=log_prob_key, ) + assert module.dist_sample_keys == [("data", "action")] + assert module.dist_params_keys == [("data", "param")] td_out = module(policy_module(td)) assert td_out["data", "action"].shape == (3, 4, 1) if log_prob_key: @@ -2132,6 +2140,8 @@ def test_nested_keys_probabilistic_delta(self, log_prob_key): return_log_prob=True, log_prob_key=log_prob_key, ) + assert module.dist_sample_keys == [("data", "action")] + assert module.dist_params_keys == [("data", "param")] td_out = module(policy_module(td)) assert td_out["data", "action"].shape == (3, 4, 1) if log_prob_key: @@ -2912,6 +2922,8 @@ def test_prob_module(self, interaction, return_log_prob, map_names): default_interaction_type=interaction, return_log_prob=return_log_prob, ) + assert module.dist_sample_keys == out_keys + assert module.dist_params_keys == in_keys if not return_log_prob: assert module.out_keys[-2:] == out_keys else: @@ -3010,6 +3022,8 @@ def test_prob_module_nested(self, interaction, map_names): ) # loosely checks that the log-prob keys have been added assert module.out_keys[-2:] != out_keys + assert module.dist_sample_keys == out_keys + assert module.dist_params_keys == in_keys sample = module(params) key_logprob0 = ( @@ -3077,6 +3091,8 @@ def test_prob_module_seq(self, interaction, return_log_prob, ordereddict): ) ] module = ProbabilisticTensorDictSequential(*args) + assert module.dist_sample_keys == out_keys + assert module.dist_params_keys == in_keys sample = module(params) if return_log_prob: assert "cont_log_prob" in sample.keys() @@ -3156,6 +3172,8 @@ def test_prob_module_seq_nested(self, interaction): log_prob_key=log_prob_key, ), ) + assert module.dist_sample_keys == out_keys + assert module.dist_params_keys == in_keys sample = module(params) assert "cont_log_prob" in sample.keys() assert ("nested", "cont_log_prob") in sample.keys(True)