From 9e1a0884a5ae7f36dee659caca1e72e576f7a895 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sun, 15 Dec 2024 13:50:04 -0800
Subject: [PATCH] Update (base update)

[ghstack-poisoned]
---
 sota-implementations/ppo/ppo_atari.py  | 2 ++
 sota-implementations/ppo/ppo_mujoco.py | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py
index 2b0b7ec5e98..426d8012437 100644
--- a/sota-implementations/ppo/ppo_atari.py
+++ b/sota-implementations/ppo/ppo_atari.py
@@ -237,6 +237,8 @@ def update(batch, num_network_updates):
             with torch.no_grad(), timeit("adv"):
                 torch.compiler.cudagraph_mark_step_begin()
                 data = adv_module(data)
+                if compile_mode:
+                    data = data.clone()
             with timeit("rb - extend"):
                 # Update the data buffer
                 data_reshape = data.reshape(-1)
diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py
index 1f284fc7634..334a486e7e2 100644
--- a/sota-implementations/ppo/ppo_mujoco.py
+++ b/sota-implementations/ppo/ppo_mujoco.py
@@ -226,6 +226,9 @@ def update(batch, num_network_updates):
             with torch.no_grad(), timeit("adv"):
                 torch.compiler.cudagraph_mark_step_begin()
                 data = adv_module(data)
+                if compile_mode:
+                    data = data.clone()
+
             with timeit("rb - extend"):
                 # Update the data buffer
                 data_reshape = data.reshape(-1)