[Feature] TD3-bc compatibility with compile #2657

Merged
merged 14 commits into from Dec 16, 2024
5 changes: 5 additions & 0 deletions sota-implementations/td3_bc/config.yaml
@@ -43,3 +43,8 @@ logger:
eval_steps: 1000
eval_envs: 1
video: False

compile:
compile: False
compile_mode:
cudagraphs: False
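
Note: `compile_mode` is intentionally left empty so the script can pick a sensible default. Below is a minimal sketch (not part of this PR) of how the new `compile` block resolves at start-up; it mirrors the fallback logic added to `td3_bc.py` further down, and the config path is only illustrative.

```python
# Sketch only: resolving the new `compile` settings at start-up.
# Mirrors the fallback logic added to td3_bc.py below; the path is illustrative.
from omegaconf import OmegaConf

cfg = OmegaConf.load("sota-implementations/td3_bc/config.yaml")

compile_mode = None
if cfg.compile.compile:
    compile_mode = cfg.compile.compile_mode
    if compile_mode in ("", None):
        # cudagraphs capture is paired with the "default" mode,
        # otherwise "reduce-overhead" is used
        compile_mode = "default" if cfg.compile.cudagraphs else "reduce-overhead"

print(compile_mode, cfg.compile.cudagraphs)
```

With the defaults shown above (`compile: False`), `compile_mode` stays `None` and the script keeps running eagerly.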
101 changes: 66 additions & 35 deletions sota-implementations/td3_bc/td3_bc.py
@@ -11,13 +11,16 @@
"""
from __future__ import annotations

import time
import warnings

import hydra
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule

from torchrl._utils import compile_with_warmup, timeit

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
@@ -72,7 +75,16 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer, device=device)

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create agent
model, _ = make_td3_agent(cfg, eval_env, device)
@@ -83,67 +95,86 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create optimizer
optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
update_counter = 0
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
start_time = time.time()
for i in pbar:
pbar.update(1)
# Update actor every delayed_updates
update_counter += 1
update_actor = update_counter % delayed_updates == 0

# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()
if sampled_tensordict.device != device:
sampled_tensordict = sampled_tensordict.to(device)
else:
sampled_tensordict = sampled_tensordict.clone()

def update(sampled_tensordict, update_actor):
# Compute loss
q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict)

# Update critic
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()
q_loss.item()

to_log = {"q_loss": q_loss.item()}
optimizer_critic.zero_grad(set_to_none=True)

# Update actor
if update_actor:
actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
optimizer_actor.zero_grad(set_to_none=True)

# Update target params
target_net_updater.step()
else:
actorloss_metadata = {}
actor_loss = q_loss.new_zeros(())
metadata = TensorDict(actorloss_metadata)
metadata.set("q_loss", q_loss.detach())
metadata.set("actor_loss", actor_loss.detach())
return metadata

if cfg.compile.compile:
update = compile_with_warmup(update, mode=compile_mode, warmup=1)

if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

gradient_steps = cfg.optim.gradient_steps
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps
delayed_updates = cfg.optim.policy_update_delay
pbar = tqdm.tqdm(range(gradient_steps))
# Training loop
for update_counter in pbar:
timeit.printevery(num_prints=1000, total_count=gradient_steps, erase=True)

to_log["actor_loss"] = actor_loss.item()
to_log.update(actorloss_metadata)
# Update actor every delayed_updates
update_actor = update_counter % delayed_updates == 0

with timeit("rb - sample"):
# Sample from replay buffer
sampled_tensordict = replay_buffer.sample()

with timeit("update"):
torch.compiler.cudagraph_mark_step_begin()
metadata = update(sampled_tensordict, update_actor).clone()

to_log = {}
if update_actor:
to_log.update(metadata.to_dict())
else:
to_log.update(metadata.exclude("actor_loss").to_dict())

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
if update_counter % evaluation_interval == 0:
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad(), timeit("eval"):
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward
if logger is not None:
log_metrics(logger, to_log, i)
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, update_counter)

if not eval_env.is_closed:
eval_env.close()
pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")


if __name__ == "__main__":
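For reference, the pattern this file adopts, factoring the optimization step into an `update` function and optionally wrapping it with `compile_with_warmup` and `CudaGraphModule`, condensed into a self-contained sketch. The toy model, batch shape and warmup counts are made up; only the wrapper calls follow the diff above.

```python
# Sketch of the wrapping pattern used above, applied to a toy update function.
# The dummy model, batch shape and warmup counts are illustrative only.
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl._utils import compile_with_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"
use_compile = True
use_cudagraphs = torch.cuda.is_available()

model = torch.nn.Linear(8, 1, device=device)
# a capturable Adam state is needed if the step is recorded in a CUDA graph
optim = torch.optim.Adam(model.parameters(), capturable=use_cudagraphs)


def update(batch):
    loss = model(batch["obs"]).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)
    # return detached scalars in a TensorDict so the caller can safely .clone()
    return TensorDict({"loss": loss.detach()})


if use_compile:
    # same fallback as the config logic: "default" is paired with cudagraphs
    mode = "default" if use_cudagraphs else "reduce-overhead"
    update = compile_with_warmup(update, mode=mode, warmup=1)
if use_cudagraphs:
    # capture the whole update as a CUDA graph after a few warmup calls
    update = CudaGraphModule(update, in_keys=[], out_keys=[], warmup=5)

for _ in range(10):
    batch = TensorDict({"obs": torch.randn(32, 8, device=device)}, batch_size=[32])
    torch.compiler.cudagraph_mark_step_begin()
    metadata = update(batch).clone()  # clone: outputs may alias static graph buffers
```

Pairing `CudaGraphModule` with the `default` compile mode presumably avoids stacking it on top of the CUDA graphs that `reduce-overhead` already inserts, which also explains why the training loop clones the metadata returned by `update` before logging it.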
54 changes: 21 additions & 33 deletions sota-implementations/td3_bc/utils.py
@@ -7,7 +7,7 @@
import functools

import torch
from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictModule, TensorDictSequential

from torch import nn, optim
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
@@ -26,14 +26,7 @@
)
from torchrl.envs.libs.gym import GymEnv, set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianModule,
MLP,
SafeModule,
SafeSequential,
TanhModule,
ValueOperator,
)
from torchrl.modules import AdditiveGaussianModule, MLP, TanhModule, ValueOperator

from torchrl.objectives import SoftUpdate
from torchrl.objectives.td3_bc import TD3BCLoss
@@ -98,17 +91,19 @@ def make_environment(cfg, logger=None):
# ---------------------------


def make_offline_replay_buffer(rb_cfg):
def make_offline_replay_buffer(rb_cfg, device):
data = D4RLExperienceReplay(
dataset_id=rb_cfg.dataset,
split_trajs=False,
batch_size=rb_cfg.batch_size,
sampler=SamplerWithoutReplacement(drop_last=False),
# drop_last for compile
sampler=SamplerWithoutReplacement(drop_last=True),
prefetch=4,
direct_download=True,
)

data.append_transform(DoubleToFloat())
data.append_transform(lambda td: td.to(device))

return data

@@ -122,26 +117,22 @@ def make_td3_agent(cfg, train_env, device):
"""Make TD3 agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.action_spec
if train_env.batch_size:
action_spec = action_spec[(0,) * len(train_env.batch_size)]
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
"activation_class": get_activation(cfg),
}
action_spec = train_env.action_spec_unbatched.to(device)

actor_net = MLP(**actor_net_kwargs)
actor_net = MLP(
num_cells=cfg.network.hidden_sizes,
out_features=action_spec.shape[-1],
activation_class=get_activation(cfg),
device=device,
)

in_keys_actor = in_keys
actor_module = SafeModule(
actor_module = TensorDictModule(
actor_net,
in_keys=in_keys_actor,
out_keys=[
"param",
],
out_keys=["param"],
)
actor = SafeSequential(
actor = TensorDictSequential(
actor_module,
TanhModule(
in_keys=["param"],
@@ -151,22 +142,19 @@ def make_td3_agent(cfg, train_env, device):
)

# Define Critic Network
qvalue_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": 1,
"activation_class": get_activation(cfg),
}

qvalue_net = MLP(
**qvalue_net_kwargs,
num_cells=cfg.network.hidden_sizes,
out_features=1,
activation_class=get_activation(cfg),
device=device,
)

qvalue = ValueOperator(
in_keys=["action"] + in_keys,
module=qvalue_net,
)

model = nn.ModuleList([actor, qvalue]).to(device)
model = nn.ModuleList([actor, qvalue])

# init nets
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
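
And a minimal, self-contained sketch of the new actor construction in `make_td3_agent`: a plain `TensorDictModule`/`TensorDictSequential` pair instead of `SafeModule`/`SafeSequential`, with the MLP built directly on the target device. The observation/action sizes are invented for illustration, and the `spec`/bounds arguments passed in the real code are omitted here.

```python
# Sketch of the actor construction after the SafeModule -> TensorDictModule swap.
# Observation/action sizes are illustrative; the real code also passes the action spec.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import MLP, TanhModule

device = torch.device("cpu")
obs_dim, action_dim = 17, 6  # made-up shapes

actor_net = MLP(
    num_cells=[256, 256],
    out_features=action_dim,
    activation_class=torch.nn.ReLU,
    device=device,  # building on the target device avoids a later .to(device)
)
actor = TensorDictSequential(
    TensorDictModule(actor_net, in_keys=["observation"], out_keys=["param"]),
    TanhModule(in_keys=["param"], out_keys=["action"]),  # squashes into (-1, 1)
)

td = TensorDict({"observation": torch.randn(4, obs_dim, device=device)}, batch_size=[4])
print(actor(td)["action"].shape)  # torch.Size([4, 6])
```

Building the modules with `device=...` up front is what lets the diff drop the trailing `.to(device)` on the `nn.ModuleList`.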