From a47b32c073ebb74878b8e7329ef2e79997c3e783 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Nov 2024 11:49:44 +0000 Subject: [PATCH] [BugFix] make buffers zero-dim in exploration modules ghstack-source-id: fd2705eb9132169da4871b27b354f7895c644061 Pull Request resolved: https://github.com/pytorch/rl/pull/2591 --- test/_utils_internal.py | 4 +- .../modules/tensordict_module/exploration.py | 38 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index dea0d136844..a3476b31110 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -167,8 +167,8 @@ def get_available_devices(): def get_default_devices(): num_cuda = torch.cuda.device_count() if num_cuda == 0: - if torch.mps.is_available(): - return [torch.device("mps:0")] + # if torch.mps.is_available(): + # return [torch.device("mps:0")] return [torch.device("cpu")] elif num_cuda == 1: return [torch.device("cuda:0")] diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 9acfae1aa21..a1879519271 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -112,10 +112,10 @@ def __init__( super().__init__() - self.register_buffer("eps_init", torch.as_tensor([eps_init])) - self.register_buffer("eps_end", torch.as_tensor([eps_end])) + self.register_buffer("eps_init", torch.as_tensor(eps_init)) + self.register_buffer("eps_end", torch.as_tensor(eps_end)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("eps", torch.as_tensor([eps_init], dtype=torch.float32)) + self.register_buffer("eps", torch.as_tensor(eps_init, dtype=torch.float32)) if spec is not None: if not isinstance(spec, Composite) and len(self.out_keys) >= 1: @@ -275,13 +275,13 @@ def __init__( super().__init__(policy) if sigma_end > sigma_init: raise RuntimeError("sigma should decrease over time or be constant") - self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) - self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) + self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device)) + self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean], device=device)) - self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer("mean", torch.tensor(mean, device=device)) + self.register_buffer("std", torch.tensor(std, device=device)) self.register_buffer( - "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + "sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device) ) self.action_key = action_key self.out_keys = list(self.td_module.out_keys) @@ -423,13 +423,13 @@ def __init__( super().__init__() - self.register_buffer("sigma_init", torch.tensor([sigma_init], device=device)) - self.register_buffer("sigma_end", torch.tensor([sigma_end], device=device)) + self.register_buffer("sigma_init", torch.tensor(sigma_init, device=device)) + self.register_buffer("sigma_end", torch.tensor(sigma_end, device=device)) self.annealing_num_steps = annealing_num_steps - self.register_buffer("mean", torch.tensor([mean], device=device)) - self.register_buffer("std", torch.tensor([std], device=device)) + self.register_buffer("mean", torch.tensor(mean, device=device)) + self.register_buffer("std", torch.tensor(std, device=device)) self.register_buffer( - "sigma", torch.tensor([sigma_init], dtype=torch.float32, device=device) + "sigma", torch.tensor(sigma_init, dtype=torch.float32, device=device) ) if spec is not None: @@ -628,8 +628,8 @@ def __init__( key=action_key, device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) - self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) + self.register_buffer("eps_init", torch.tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.tensor(eps_end, device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " @@ -637,7 +637,7 @@ def __init__( ) self.annealing_num_steps = annealing_num_steps self.register_buffer( - "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + "eps", torch.tensor(eps_init, dtype=torch.float32, device=device) ) self.out_keys = list(self.td_module.out_keys) + self.ou.out_keys self.is_init_key = is_init_key @@ -840,8 +840,8 @@ def __init__( device=device, ) - self.register_buffer("eps_init", torch.tensor([eps_init], device=device)) - self.register_buffer("eps_end", torch.tensor([eps_end], device=device)) + self.register_buffer("eps_init", torch.tensor(eps_init, device=device)) + self.register_buffer("eps_end", torch.tensor(eps_end, device=device)) if self.eps_end > self.eps_init: raise ValueError( "eps should decrease over time or be constant, " @@ -849,7 +849,7 @@ def __init__( ) self.annealing_num_steps = annealing_num_steps self.register_buffer( - "eps", torch.tensor([eps_init], dtype=torch.float32, device=device) + "eps", torch.tensor(eps_init, dtype=torch.float32, device=device) ) self.in_keys = [self.ou.key]