Skip to content

Commit

Permalink
Update
Browse files: browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 16, 2024
1 parent 03549ba commit 2b39661
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions sota-implementations/td3/td3.py
Original file line number Diff line number Diff line change
@@ -184,15 +184,18 @@ def update(sampled_tensordict, update_actor):
q_losses,
) = ([], [])
for _ in range(num_updates):

# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
q_loss, actor_loss = update(sampled_tensordict, update_actor)

q_loss, actor_loss = update(sampled_tensordict, update_actor)
q_losses.append(q_loss)
if update_actor:
actor_losses.append(actor_loss)
@@ -218,13 +221,15 @@ def update(sampled_tensordict, update_actor):
)

if collected_frames >= init_random_frames:
metrics_to_log["train/q_loss"] = np.mean(q_losses)
metrics_to_log["train/q_loss"] = q_losses.mean()
if update_actor:
metrics_to_log["train/a_loss"] = np.mean(actor_losses)
metrics_to_log["train/a_loss"] = actor_losses.mean()

# Evaluation
if abs(collected_frames % eval_iter) < frames_per_batch:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_rollout = eval_env.rollout(
eval_rollout_steps,
exploration_policy,
Expand Down

0 comments on commit 2b39661

Please sign in to comment.