Skip to content

Commit

Permalink
[Feature] Change default interaction types to DETERMINISTIC (#825)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 21, 2024
1 parent d14db1c commit ab1abac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ class set_interaction_type(_DecoratorContextManager):
type (InteractionType): sampling type to use when the policy is being called.
"""

def __init__(self, type: InteractionType | None = InteractionType.MODE) -> None:
def __init__(
self, type: InteractionType | None = InteractionType.DETERMINISTIC
) -> None:
super().__init__()
self.type = type

Expand Down Expand Up @@ -309,7 +311,7 @@ def __init__(
out_keys: NestedKey | List[NestedKey] | None = None,
*,
default_interaction_mode: str | None = None,
default_interaction_type: InteractionType = InteractionType.MODE,
default_interaction_type: InteractionType = InteractionType.DETERMINISTIC,
distribution_class: type = Delta,
distribution_kwargs: dict | None = None,
return_log_prob: bool = False,
Expand Down
1 change: 1 addition & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class TestInteractionType:
"str_and_expected_type",
[
("mode", InteractionType.MODE),
("deterministic", InteractionType.DETERMINISTIC),
("MEDIAN", InteractionType.MEDIAN),
("Mean", InteractionType.MEAN),
("RanDom", InteractionType.RANDOM),
Expand Down

1 comment on commit ab1abac

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: ab1abac Previous: d14db1c Ratio
benchmarks/common/common_ops_test.py::test_values 544704.8480417925 iter/sec (stddev: 1.2670919363653362e-7) 1196234.6917385357 iter/sec (stddev: 1.5348675389859357e-7) 2.20
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 101270.99512509268 iter/sec (stddev: 6.125165425177699e-7) 317599.90866157325 iter/sec (stddev: 3.1063694304065627e-7) 3.14
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 101374.4118852215 iter/sec (stddev: 6.398978788556923e-7) 321125.7562701977 iter/sec (stddev: 2.8982279738182617e-7) 3.17

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.