From 40728c69f2506b67541e2eb2a2f983b8d06895c9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 15:18:38 -0800 Subject: [PATCH] [Feature] IQL compatibility with compile ghstack-source-id: c319d4ae38bc10f4821ad91a31b53e311a2df5d1 Pull Request resolved: https://github.com/pytorch/rl/pull/2649 --- sota-implementations/iql/discrete_iql.py | 169 +++++++++++++---------- sota-implementations/iql/iql_offline.py | 84 ++++++----- sota-implementations/iql/iql_online.py | 133 ++++++++---------- sota-implementations/iql/utils.py | 4 +- 4 files changed, 205 insertions(+), 185 deletions(-) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 79cf2114d40..203b916c4a2 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -13,16 +13,20 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -87,8 +91,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create model model = make_discrete_iql_model(cfg, train_env, eval_env, device) + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + collector = make_collector( + cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode + ) # Create loss loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) @@ -97,6 +112,34 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) + del optimizer_actor, optimizer_critic, optimizer_value + + def update(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + # compute losses + actor_loss, _ = loss_module.actor_loss(sampled_tensordict) + value_loss, _ = loss_module.value_loss(sampled_tensordict) + q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict) + (actor_loss + value_loss + q_loss).backward() + optimizer.step() + + # update qnet_target params + target_net_updater.step() + return TensorDict( + metadata.update( + {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss} + ) + ).detach() + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) # Main loop collected_frames = 0 @@ -112,84 +155,52 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.collector.max_frames_per_traj - sampling_start = start_time = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - pbar.update(tensordict.numel()) + + collector_iter = iter(collector) + for _ in range(len(collector)): + with timeit("collection"): + tensordict = next(collector_iter) + current_frames = tensordict.numel() + pbar.update(current_frames) + # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("buffer - extend"): + tensordict = tensordict.reshape(-1) + + # add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - for _ in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample().clone() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict - # compute losses - actor_loss, _ = loss_module.actor_loss(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - value_loss, _ = loss_module.value_loss(sampled_tensordict) - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict) - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # update qnet_target params - target_net_updater.step() - - # update priority - if prb: - sampled_tensordict.set( - loss_module.tensor_keys.priority, - metadata.pop("td_error").detach().max(0).values, - ) - replay_buffer.update_priority(sampled_tensordict) - - training_time = time.time() - training_start + with timeit("training"): + if collected_frames >= init_random_frames: + for _ in range(num_updates): + # sample from replay buffer + with timeit("buffer - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + + with timeit("training - update"): + metadata = update(sampled_tensordict) + # update priority + if prb: + sampled_tensordict.set( + loss_module.tensor_keys.priority, + metadata.pop("td_error").detach().max(0).values, + ) + replay_buffer.update_priority(sampled_tensordict) + episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] - # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = q_loss.detach() - metrics_to_log["train/actor_loss"] = actor_loss.detach() - metrics_to_log["train/value_loss"] = value_loss.detach() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time - # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], @@ -197,18 +208,28 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + + # Logging + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = metadata["q_loss"] + metrics_to_log["train/actor_loss"] = metadata["actor_loss"] + metrics_to_log["train/value_loss"] = metadata["value_loss"] + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() + timeit.erase() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 09cf9954b86..1853bc79a5f 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -11,16 +11,19 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -85,54 +88,62 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) - - gradient_steps = cfg.optim.gradient_steps - evaluation_interval = cfg.logger.eval_iter - eval_steps = cfg.logger.eval_steps - - # Training loop - start_time = time.time() - for i in range(gradient_steps): - pbar.update(1) - # sample data - data = replay_buffer.sample() - - if data.device != device: - data = data.to(device, non_blocking=True) - + def update(data): + optimizer.zero_grad(set_to_none=True) # compute losses loss_info = loss_module(data) actor_loss = loss_info["loss_actor"] value_loss = loss_info["loss_value"] q_loss = loss_info["loss_qvalue"] - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() + (actor_loss + value_loss + q_loss).backward() + optimizer.step() # update qnet_target params target_net_updater.step() + return loss_info.detach() + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + pbar = tqdm.tqdm(range(cfg.optim.gradient_steps)) + + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + + # Training loop + for i in pbar: + # sample data + with timeit("sample"): + data = replay_buffer.sample() + data = data.to(device) - # log metrics - to_log = { - "loss_actor": actor_loss.item(), - "loss_qvalue": q_loss.item(), - "loss_value": value_loss.item(), - } + with timeit("update"): + loss_info = update(data) # evaluation + to_log = loss_info.to_dict() if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) @@ -147,7 +158,6 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 8497d24f106..de32678e6cd 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -13,16 +13,15 @@ """ from __future__ import annotations -import time - import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -97,10 +96,11 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) + del optimizer_actor, optimizer_critic, optimizer_value # Main loop collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames num_updates = int( @@ -112,82 +112,61 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.collector.max_frames_per_traj - sampling_start = start_time = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - pbar.update(tensordict.numel()) + collector_iter = iter(collector) + pbar = tqdm.tqdm(range(collector.total_frames)) + for _ in range(len(collector)): + with timeit("collection"): + tensordict = next(collector_iter) + current_frames = tensordict.numel() + pbar.update(current_frames) # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + # add to replay buffer + tensordict = tensordict.rehsape(-1) + replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames # optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - for _ in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample().clone() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict - # compute losses - loss_info = loss_module(sampled_tensordict) - actor_loss = loss_info["loss_actor"] - value_loss = loss_info["loss_value"] - q_loss = loss_info["loss_qvalue"] - - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # update qnet_target params - target_net_updater.step() - - # update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start + with timeit("training"): + if collected_frames >= init_random_frames: + for _ in range(num_updates): + with timeit("rb - sampling"): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().to(device) + + def update(sampled_tensordict): + optimizer.zero_grad() + # compute losses + loss_info = loss_module(sampled_tensordict) + actor_loss = loss_info["loss_actor"] + value_loss = loss_info["loss_value"] + q_loss = loss_info["loss_qvalue"] + + (actor_loss + value_loss + q_loss).backward() + optimizer.step() + + # update qnet_target params + target_net_updater.step() + return loss_info.detach() + + with timeit("update"): + loss_info = update(sampled_tensordict) + # update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = q_loss.detach() - metrics_to_log["train/actor_loss"] = actor_loss.detach() - metrics_to_log["train/value_loss"] = value_loss.detach() - metrics_to_log["train/entropy"] = loss_info.get("entropy").detach() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time - # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("evaluating"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], @@ -195,25 +174,33 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = loss_info["loss_qvalue"].detach() + metrics_to_log["train/actor_loss"] = loss_info["loss_actor"].detach() + metrics_to_log["train/value_loss"] = loss_info["loss_value"].detach() + metrics_to_log["train/entropy"] = loss_info.get("entropy").detach() + metrics_to_log.update(timeit.todict(prefix="time")) + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") - if __name__ == "__main__": main() diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 261cb912de0..b74eb8e8f79 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -118,7 +118,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector(cfg, train_env, actor_model_explore, compile_mode): """Make collector.""" device = cfg.collector.device if device in ("", None): @@ -134,6 +134,8 @@ def make_collector(cfg, train_env, actor_model_explore): max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=device, + compile_policy={"mode": compile_mode} if compile_mode else False, + cuda=cfg.compile.cudagraphs, ) collector.set_seed(cfg.env.seed) return collector