From b0b0ca9ed7664eef6b2376c49d78740e187538cf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 14:01:49 -0800 Subject: [PATCH] [Quality] IMPALA auto-device ghstack-source-id: a439301cf4cac58474cbb910de71a1f190c316ed Pull Request resolved: https://github.com/pytorch/rl/pull/2654 --- sota-implementations/impala/config_multi_node_ray.yaml | 2 +- .../impala/config_multi_node_submitit.yaml | 2 +- sota-implementations/impala/config_single_node.yaml | 2 +- sota-implementations/impala/impala_multi_node_ray.py | 6 +++++- sota-implementations/impala/impala_multi_node_submitit.py | 6 +++++- sota-implementations/impala/impala_single_node.py | 7 +++++-- 6 files changed, 18 insertions(+), 7 deletions(-) diff --git a/sota-implementations/impala/config_multi_node_ray.yaml b/sota-implementations/impala/config_multi_node_ray.yaml index c67b5ed52da..549428a4725 100644 --- a/sota-implementations/impala/config_multi_node_ray.yaml +++ b/sota-implementations/impala/config_multi_node_ray.yaml @@ -24,7 +24,7 @@ ray_init_config: storage: null # Device for the forward and backward passes -local_device: "cuda:0" +local_device: # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: diff --git a/sota-implementations/impala/config_multi_node_submitit.yaml b/sota-implementations/impala/config_multi_node_submitit.yaml index 59973e46b40..4d4332722aa 100644 --- a/sota-implementations/impala/config_multi_node_submitit.yaml +++ b/sota-implementations/impala/config_multi_node_submitit.yaml @@ -3,7 +3,7 @@ env: env_name: PongNoFrameskip-v4 # Device for the forward and backward passes -local_device: "cuda:0" +local_device: # SLURM config slurm_config: diff --git a/sota-implementations/impala/config_single_node.yaml b/sota-implementations/impala/config_single_node.yaml index b93c3802a33..655edaddc4e 100644 --- a/sota-implementations/impala/config_single_node.yaml +++ b/sota-implementations/impala/config_single_node.yaml @@ -3,7 +3,7 @@ env: env_name: PongNoFrameskip-v4 # Device for the forward and backward passes -device: "cuda:0" +device: # collector collector: diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index ba40de1acde..b2b724f6a6d 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -32,7 +32,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.local_device) + device = cfg.local_device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index 5f77008a12b..07d38604391 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -34,7 +34,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.local_device) + device = cfg.local_device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index 130d0d30dd7..cd11ae467c3 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -31,7 +31,11 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = torch.device(cfg.device) + device = cfg.device + if not device: + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0") + else: + device = torch.device(device) # Correct for frame_skip frame_skip = 4 @@ -55,7 +59,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create models (check utils.py) actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = actor.to(device), critic.to(device) # Create collector collector = MultiaSyncDataCollector(