Skip to content

Commit

Permalink
[Feature] GAIL compatibility with compile
Browse files Browse the repository at this point in the history
ghstack-source-id: c658dcd52733273880f6974001bed662b7e2e55d
Pull Request resolved: #2573
  • Loading branch information
vmoens committed Dec 14, 2024
1 parent f2af2b3 commit 22121c9
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 87 deletions.
2 changes: 1 addition & 1 deletion sota-implementations/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ collector:
multi_step: 0
init_random_frames: 1000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200
annealing_frames: 10000
eps_start: 1.0
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5_000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 1000


Expand Down
8 changes: 7 additions & 1 deletion sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,20 @@ def make_collector(
cudagraph=False,
):
"""Make collector."""
device = cfg.collector.device
if device in ("", None):
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
collector = SyncDataCollector(
train_env,
actor_model_explore,
init_random_frames=cfg.collector.init_random_frames,
frames_per_batch=cfg.collector.frames_per_batch,
max_frames_per_traj=cfg.collector.max_frames_per_traj,
total_frames=cfg.collector.total_frames,
device=cfg.collector.device,
device=device,
compile_policy={"mode": compile_mode} if compile else False,
cudagraph_policy=cudagraph,
)
Expand Down
5 changes: 5 additions & 0 deletions sota-implementations/gail/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ gail:
gp_lambda: 10.0
device: null

compile:
compile: False
compile_mode:
cudagraphs: False

replay_buffer:
dataset: halfcheetah-expert-v2
batch_size: 256
172 changes: 102 additions & 70 deletions sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@
"""
from __future__ import annotations

import warnings

import hydra
import numpy as np
import torch
import tqdm

from gail_utils import log_metrics, make_gail_discriminator, make_offline_replay_buffer
from ppo_utils import eval_model, make_env, make_ppo_models
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

from torchrl.envs import set_gym_backend
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import ClipPPOLoss, GAILLoss
from torchrl.objectives import ClipPPOLoss, GAILLoss, group_optimizers
from torchrl.objectives.value.advantages import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
Expand Down Expand Up @@ -71,18 +75,8 @@ def main(cfg: "DictConfig"): # noqa: F821
np.random.seed(cfg.env.seed)

# Create models (check utils_mujoco.py)
actor, critic = make_ppo_models(cfg.env.env_name)
actor, critic = actor.to(device), critic.to(device)

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
actor, critic = make_ppo_models(
cfg.env.env_name, compile=cfg.compile.compile, device=device
)

# Create data buffer
Expand Down Expand Up @@ -111,8 +105,36 @@ def main(cfg: "DictConfig"): # noqa: F821
)

# Create optimizers
actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.ppo.optim.lr, eps=1e-5)
actor_optim = torch.optim.Adam(
actor.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
)
critic_optim = torch.optim.Adam(
critic.parameters(), lr=torch.tensor(cfg.ppo.optim.lr, device=device), eps=1e-5
)
optim = group_optimizers(actor_optim, critic_optim)
del actor_optim, critic_optim

compile_mode = None
if cfg.compile.compile:
compile_mode = cfg.compile.compile_mode
if compile_mode in ("", None):
if cfg.compile.cudagraphs:
compile_mode = "default"
else:
compile_mode = "reduce-overhead"

# Create collector
collector = SyncDataCollector(
create_env_fn=make_env(cfg.env.env_name, device),
policy=actor,
frames_per_batch=cfg.ppo.collector.frames_per_batch,
total_frames=cfg.ppo.collector.total_frames,
device=device,
storing_device=device,
max_frames_per_traj=-1,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

# Create replay buffer
replay_buffer = make_offline_replay_buffer(cfg.replay_buffer)
Expand Down Expand Up @@ -140,32 +162,9 @@ def main(cfg: "DictConfig"): # noqa: F821
VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"])
)
test_env.eval()
num_network_updates = torch.zeros((), dtype=torch.int64, device=device)

# Training loop
collected_frames = 0
num_network_updates = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)
def update(data, expert_data, num_network_updates=num_network_updates):
# Add collector data to expert data
expert_data.set(
discriminator_loss.tensor_keys.collector_action,
Expand All @@ -178,9 +177,9 @@ def main(cfg: "DictConfig"): # noqa: F821
d_loss = discriminator_loss(expert_data)

# Backward pass
discriminator_optim.zero_grad()
d_loss.get("loss").backward()
discriminator_optim.step()
discriminator_optim.zero_grad(set_to_none=True)

# Compute discriminator reward
with torch.no_grad():
Expand All @@ -190,40 +189,25 @@ def main(cfg: "DictConfig"): # noqa: F821
# Set discriminator rewards to tensordict
data.set(("next", "reward"), d_rewards)

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]
log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)
# Update PPO
for _ in range(cfg_loss_ppo_epochs):

# Compute GAE
with torch.no_grad():
data = adv_module(data)
data_reshape = data.reshape(-1)

# Update the data buffer
data_buffer.empty()
data_buffer.extend(data_reshape)

for _, batch in enumerate(data_buffer):

# Get a data batch
batch = batch.to(device)
for batch in data_buffer:
optim.zero_grad(set_to_none=True)

# Linearly decrease the learning rate and clip epsilon
alpha = 1.0
alpha = torch.ones((), device=device)
if cfg_optim_anneal_lr:
alpha = 1 - (num_network_updates / total_network_updates)
for group in actor_optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
for group in critic_optim.param_groups:
for group in optim.param_groups:
group["lr"] = cfg_optim_lr * alpha
if cfg_loss_anneal_clip_eps:
loss_module.clip_epsilon.copy_(cfg_loss_clip_epsilon * alpha)
Expand All @@ -235,20 +219,68 @@ def main(cfg: "DictConfig"): # noqa: F821
actor_loss = loss["loss_objective"] + loss["loss_entropy"]

# Backward pass
actor_loss.backward()
critic_loss.backward()
(actor_loss + critic_loss).backward()

# Update the networks
actor_optim.step()
critic_optim.step()
actor_optim.zero_grad()
critic_optim.zero_grad()
optim.step()
return TensorDict(dloss=d_loss, alpha=alpha).detach()

if cfg.compile.compile:
update = torch.compile(update, mode=compile_mode)
if cfg.compile.cudagraphs:
warnings.warn(
"CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.",
category=UserWarning,
)
update = CudaGraphModule(update, warmup=50)

# Training loop
collected_frames = 0
pbar = tqdm.tqdm(total=cfg.ppo.collector.total_frames)

# extract cfg variables
cfg_loss_ppo_epochs = cfg.ppo.loss.ppo_epochs
cfg_optim_anneal_lr = cfg.ppo.optim.anneal_lr
cfg_optim_lr = cfg.ppo.optim.lr
cfg_loss_anneal_clip_eps = cfg.ppo.loss.anneal_clip_epsilon
cfg_loss_clip_epsilon = cfg.ppo.loss.clip_epsilon
cfg_logger_test_interval = cfg.logger.test_interval
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes

for i, data in enumerate(collector):

log_info = {}
frames_in_batch = data.numel()
collected_frames += frames_in_batch
pbar.update(data.numel())

# Update discriminator
# Get expert data
expert_data = replay_buffer.sample()
expert_data = expert_data.to(device)

metadata = update(data, expert_data)
d_loss = metadata["d_loss"]
alpha = metadata["alpha"]

# Get training rewards and episode lengths
episode_rewards = data["next", "episode_reward"][data["next", "done"]]
if len(episode_rewards) > 0:
episode_length = data["next", "step_count"][data["next", "done"]]

log_info.update(
{
"train/reward": episode_rewards.mean().item(),
"train/episode_length": episode_length.sum().item()
/ len(episode_length),
}
)

log_info.update(
{
"train/actor_loss": actor_loss.item(),
"train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"].item(),
# "train/actor_loss": actor_loss.item(),
# "train/critic_loss": critic_loss.item(),
"train/discriminator_loss": d_loss["loss"],
"train/lr": alpha * cfg_optim_lr,
"train/clip_epsilon": (
alpha * cfg_loss_clip_epsilon
Expand Down
18 changes: 11 additions & 7 deletions sota-implementations/gail/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def make_env(env_name="HalfCheetah-v4", device="cpu", from_pixels: bool = False)
# --------------------------------------------------------------------


def make_ppo_models_state(proof_environment):
def make_ppo_models_state(proof_environment, compile, device):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape
Expand All @@ -52,9 +52,10 @@ def make_ppo_models_state(proof_environment):
num_outputs = proof_environment.action_spec_unbatched.shape[-1]
distribution_class = TanhNormal
distribution_kwargs = {
"low": proof_environment.action_spec_unbatched.space.low,
"high": proof_environment.action_spec_unbatched.space.high,
"low": proof_environment.action_spec_unbatched.space.low.to(device),
"high": proof_environment.action_spec_unbatched.space.high.to(device),
"tanh_loc": False,
"safe_tanh": not compile,
}

# Define policy architecture
Expand All @@ -63,6 +64,7 @@ def make_ppo_models_state(proof_environment):
activation_class=torch.nn.Tanh,
out_features=num_outputs, # predict only loc
num_cells=[64, 64],
device=device,
)

# Initialize policy weights
Expand All @@ -87,7 +89,7 @@ def make_ppo_models_state(proof_environment):
out_keys=["loc", "scale"],
),
in_keys=["loc", "scale"],
spec=proof_environment.single_full_action_spec,
spec=proof_environment.full_action_spec_unbatched.to(device),
distribution_class=distribution_class,
distribution_kwargs=distribution_kwargs,
return_log_prob=True,
Expand Down Expand Up @@ -117,9 +119,11 @@ def make_ppo_models_state(proof_environment):
return policy_module, value_module


def make_ppo_models(env_name):
proof_environment = make_env(env_name, device="cpu")
actor, critic = make_ppo_models_state(proof_environment)
def make_ppo_models(env_name, compile, device):
proof_environment = make_env(env_name, device=device)
actor, critic = make_ppo_models_state(
proof_environment, compile=compile, device=device
)
return actor, critic


Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/discrete_iql.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
total_frames: 20000
init_random_frames: 1000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200

# logger
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/iql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ collector:
multi_step: 0
init_random_frames: 5000
env_per_collector: 1
device: cpu
device:
max_frames_per_traj: 200

# logger
Expand Down
Loading

0 comments on commit 22121c9

Please sign in to comment.