From 098677a2bc5b55d3ad653c3fed6c343ac2576536 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 25 Oct 2024 18:51:22 -0700 Subject: [PATCH] [Feature] MCTS Scoring functions ghstack-source-id: 5fdfbeab44f579aa01e333f6900cf0c9297d58aa Pull Request resolved: https://github.com/pytorch/rl/pull/2358 --- torchrl/modules/mcts/__init__.py | 5 ++ torchrl/modules/mcts/scores.py | 100 +++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 torchrl/modules/mcts/__init__.py create mode 100644 torchrl/modules/mcts/scores.py diff --git a/torchrl/modules/mcts/__init__.py b/torchrl/modules/mcts/__init__.py new file mode 100644 index 00000000000..b983d492454 --- /dev/null +++ b/torchrl/modules/mcts/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .scores import PUCTScore, UCBScore diff --git a/torchrl/modules/mcts/scores.py b/torchrl/modules/mcts/scores.py new file mode 100644 index 00000000000..99b8772fc14 --- /dev/null +++ b/torchrl/modules/mcts/scores.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import functools +import math +from abc import abstractmethod +from enum import Enum + +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torch import nn + + +class MCTSScore(TensorDictModuleBase): + @abstractmethod + def forward(self, node): + pass + + +class PUCTScore(MCTSScore): + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + prior_prob_key: NestedKey = "prior_prob", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.prior_prob_key = prior_prob_key + self.score_key = score_key + self.in_keys = [ + self.win_count_key, + self.prior_prob_key, + self.total_visits_key, + self.visits_key, + ] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + prior_prob = node.get(self.prior_prob_key) + node.set( + self.score_key, + (win_count / visits) + self.c * prior_prob * n_total.sqrt() / (1 + visits), + ) + return node + + +class UCBScore(MCTSScore): + c: float + + def __init__( + self, + *, + c: float, + win_count_key: NestedKey = "win_count", + visits_key: NestedKey = "visits", + total_visits_key: NestedKey = "total_visits", + score_key: NestedKey = "score", + ): + super().__init__() + self.c = c + self.win_count_key = win_count_key + self.visits_key = visits_key + self.total_visits_key = total_visits_key + self.score_key = score_key + self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key] + self.out_keys = [self.score_key] + + def forward(self, node: TensorDictBase) -> TensorDictBase: + win_count = node.get(self.win_count_key) + visits = node.get(self.visits_key) + n_total = node.get(self.total_visits_key) + node.set( + self.score_key, + (win_count / visits) + self.c * n_total.sqrt() / (1 + visits), + ) + return node + + +class MCTSScores(Enum): + PUCT = functools.partial(PUCTScore, c=5) # AlphaGo default value + UCB = functools.partial(UCBScore, c=math.sqrt(2)) # default from Auer et al. 2002 + UCB1_TUNED = "UCB1-Tuned" + EXP3 = "EXP3" + PUCT_VARIANT = "PUCT-Variant"