Skip to content

Commit

Permalink
[Feature] CQL compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: 6bfb32c1e9647bd82cf72424602431da898fd81a
Pull Request resolved: #2553
  • Loading branch information
vmoens committed Nov 12, 2024
1 parent e0ae747 commit 108a21f
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 219 deletions.
9 changes: 5 additions & 4 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,12 @@ Utils
:toctree: generated/
:template: rl_template_noinherit.rst

HardUpdate
SoftUpdate
ValueEstimators
default_value_kwargs
distance_loss
group_optimizers
hold_out_net
hold_out_params
next_state_value
SoftUpdate
HardUpdate
ValueEstimators
default_value_kwargs
137 changes: 83 additions & 54 deletions sota-implementations/cql/cql_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import numpy as np
import torch
import tqdm
from torchrl._utils import logger as torchrl_logger
from tensordict.nn import CudaGraphModule

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers
from torchrl.record.loggers import generate_exp_name, get_logger

from utils import (
Expand Down Expand Up @@ -69,6 +72,9 @@ def main(cfg: "DictConfig"): # noqa: F821
# Create agent
model = make_cql_model(cfg, train_env, eval_env, device)
del train_env
if hasattr(eval_env, "start"):
# To set the number of threads to the definitive value
eval_env.start()

# Create loss
loss_module, target_net_updater = make_continuous_loss(cfg.loss, model)
Expand All @@ -81,81 +87,104 @@ def main(cfg: "DictConfig"): # noqa: F821
alpha_prime_optim,
) = make_continuous_cql_optimizer(cfg, loss_module)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)
# Group optimizers
optimizer = group_optimizers(
policy_optim, critic_optim, alpha_optim, alpha_prime_optim
)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
for i in range(gradient_steps):
pbar.update(1)
# sample data
data = replay_buffer.sample()
# compute loss
loss_vals = loss_module(data.clone().to(device))
def update(data, policy_eval_start, iteration):
loss_vals = loss_module(data.to(device))

# official cql implementation uses behavior cloning loss for first few updating steps as it helps for some tasks
if i >= policy_eval_start:
actor_loss = loss_vals["loss_actor"]
else:
actor_loss = loss_vals["loss_actor_bc"]
actor_loss = torch.where(
iteration >= policy_eval_start,
loss_vals["loss_actor"],
loss_vals["loss_actor_bc"],
)
q_loss = loss_vals["loss_qvalue"]
cql_loss = loss_vals["loss_cql"]

q_loss = q_loss + cql_loss
loss_vals["q_loss"] = q_loss

# update model
alpha_loss = loss_vals["loss_alpha"]
alpha_prime_loss = loss_vals["loss_alpha_prime"]
if alpha_prime_loss is None:
alpha_prime_loss = 0

alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss

policy_optim.zero_grad()
actor_loss.backward()
policy_optim.step()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

if alpha_prime_optim is not None:
alpha_prime_optim.zero_grad()
alpha_prime_loss.backward(retain_graph=True)
alpha_prime_optim.step()
# update qnet_target params
target_net_updater.step()

critic_optim.zero_grad()
# TODO: we have the option to compute losses independently retain is not needed?
q_loss.backward(retain_graph=False)
critic_optim.step()
return loss.detach(), loss_vals.detach()

loss = actor_loss + q_loss + alpha_loss + alpha_prime_loss
compile_mode = None
if cfg.loss.compile:
if cfg.loss.compile_mode not in (None, ""):
compile_mode = cfg.loss.compile_mode
elif cfg.loss.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"
update = torch.compile(update, mode=compile_mode)
if cfg.loss.cudagraphs:
update = CudaGraphModule(update, warmup=50)

pbar = tqdm.tqdm(total=cfg.optim.gradient_steps)

gradient_steps = cfg.optim.gradient_steps
policy_eval_start = cfg.optim.policy_eval_start
evaluation_interval = cfg.logger.eval_iter
eval_steps = cfg.logger.eval_steps

# Training loop
start_time = time.time()
policy_eval_start = torch.tensor(policy_eval_start, device=device)
for i in range(gradient_steps):
pbar.update(1)
# sample data
with timeit("sample"):
data = replay_buffer.sample()

with timeit("update"):
# compute loss
i_device = torch.tensor(i, device=device)
loss, loss_vals = update(
data.to(device), policy_eval_start=policy_eval_start, iteration=i_device
)

# log metrics
to_log = {
"loss": loss.item(),
"loss_actor_bc": loss_vals["loss_actor_bc"].item(),
"loss_actor": loss_vals["loss_actor"].item(),
"loss_qvalue": q_loss.item(),
"loss_cql": cql_loss.item(),
"loss_alpha": alpha_loss.item(),
"loss_alpha_prime": alpha_prime_loss.item(),
"loss": loss.cpu(),
**loss_vals.cpu(),
}

# update qnet_target params
target_net_updater.step()

# evaluation
if i % evaluation_interval == 0:
with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

log_metrics(logger, to_log, i)
with timeit("log/eval"):
if i % evaluation_interval == 0:
with set_exploration_type(
ExplorationType.DETERMINISTIC
), torch.no_grad():
eval_td = eval_env.rollout(
max_steps=eval_steps, policy=model[0], auto_cast_to_device=True
)
eval_env.apply(dump_video)
eval_reward = eval_td["next", "reward"].sum(1).mean().item()
to_log["evaluation_reward"] = eval_reward

with timeit("log"):
if i % 200 == 0:
to_log.update(timeit.todict(prefix="time"))
log_metrics(logger, to_log, i)
if i % 200 == 0:
timeit.print()
timeit.erase()

pbar.close()
torchrl_logger.info(f"Training time: {time.time() - start_time}")
Expand Down
Loading

0 comments on commit 108a21f

Please sign in to comment.