From 2b39661349ea37e3c42e6dbbc8eefab4eedb3829 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 18:08:29 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- sota-implementations/td3/td3.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index a5b11d55e74..277cbd8e514 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -184,15 +184,18 @@ def update(sampled_tensordict, update_actor): q_losses, ) = ([], []) for _ in range(num_updates): - # Update actor every delayed_updates update_counter += 1 update_actor = update_counter % delayed_updates == 0 - # Sample from replay buffer - sampled_tensordict = replay_buffer.sample() + with timeit("rb - sample"): + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + + with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() + q_loss, actor_loss = update(sampled_tensordict, update_actor) - q_loss, actor_loss = update(sampled_tensordict, update_actor) q_losses.append(q_loss) if update_actor: actor_losses.append(actor_loss) @@ -218,13 +221,15 @@ def update(sampled_tensordict, update_actor): ) if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = np.mean(q_losses) + metrics_to_log["train/q_loss"] = q_losses.mean() if update_actor: - metrics_to_log["train/a_loss"] = np.mean(actor_losses) + metrics_to_log["train/a_loss"] = actor_losses.mean() # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, exploration_policy,