Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: PPO system cleanup #1124

Merged
merged 25 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bad03c7
feat: rec ippo system clean up
sash-a Nov 1, 2024
5bf7188
feat: rec mappo system clean up
sash-a Nov 1, 2024
f77b782
feat: ff ippo system clean up
sash-a Nov 1, 2024
7c4da12
feat: ff mappo system clean up
sash-a Nov 1, 2024
d88191a
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Nov 4, 2024
c299eb4
chore: remove advanced usage
sash-a Nov 5, 2024
bc7236e
feat: mat system clean up
sash-a Nov 6, 2024
b3d6006
chore: pre-commit
sash-a Nov 6, 2024
a9f2050
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Nov 6, 2024
a3e5842
refactor: loss_actor -> actor_loss
sash-a Nov 7, 2024
1a6ad17
chore: unify value/critic loss naming
sash-a Nov 7, 2024
b16c44f
chore: merge dev
sash-a Nov 7, 2024
a11bc1a
feat: sable cleanup
sash-a Nov 8, 2024
394a301
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Nov 11, 2024
6491522
chore: add back advanced usage in examples
sash-a Nov 13, 2024
16c828e
chore: remove expired link to anakin notebook
sash-a Nov 13, 2024
27bdc2f
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Nov 13, 2024
e8a0c07
fix: updated sebulba
Louay-Ben-nessir Nov 14, 2024
aed2313
chore: renaming loss_X -> X_loss
Louay-Ben-nessir Nov 14, 2024
e99e4ac
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Nov 21, 2024
b3e1c01
Merge branch 'develop' into chore/ppo-system-cleanup
RuanJohn Nov 22, 2024
5f3a09a
chore: minor comment changes
RuanJohn Nov 26, 2024
b74059a
Merge branch 'develop' into chore/ppo-system-cleanup
RuanJohn Nov 26, 2024
900eca8
Merge branch 'develop' into chore/ppo-system-cleanup
sash-a Dec 3, 2024
81c108d
chore: remove jax.tree_map
sash-a Dec 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used

## Advanced Usage 👽
sash-a marked this conversation as resolved.
Show resolved Hide resolved

Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information.
Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information.

## Contributing 🤝

Expand Down
2 changes: 0 additions & 2 deletions examples/Quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,6 @@
" )\n",
"\n",
" # Compute the parallel mean (pmean) over the batch.\n",
" # This calculation is inspired by the Anakin architecture demo notebook.\n",
" # available at https://tinyurl.com/26tdzs5x\n",
" # This pmean could be a regular mean as the batch axis is on the same device.\n",
" actor_grads, actor_loss_info = jax.lax.pmean(\n",
" (actor_grads, actor_loss_info), axis_name=\"batch\"\n",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -225,8 +226,6 @@ def _critic_loss_fn(
)

# Compute the parallel mean (pmean) over the batch.
# This calculation is inspired by the Anakin architecture demo notebook.
# available at https://tinyurl.com/26tdzs5x
# This pmean could be a regular mean as the batch axis is on the same device.
actor_grads, actor_loss_info = jax.lax.pmean(
(actor_grads, actor_loss_info), axis_name="batch"
Expand Down
4 changes: 2 additions & 2 deletions mava/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
# find the first instance of done to get the metrics at that timestep, we don't
# care about subsequent steps because we only want the results from the first episode
done_idx = np.argmax(timesteps.last(), axis=0)
metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics)
del metrics["is_terminal_step"] # unneeded for logging

return key, metrics
Expand All @@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]:
metrics_array.append(metric)

# flatten metrics
metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array)
metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array)
return metrics

def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics:
Expand Down
92 changes: 30 additions & 62 deletions mava/systems/mat/anakin/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,13 @@
ExperimentOutput,
LearnerFn,
MarlEnv,
Metrics,
TimeStep,
)
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import (
merge_leading_dims,
unreplicate_batch_dim,
unreplicate_n_dims,
)
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
Expand Down Expand Up @@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup
_ (Any): The current metrics info.
"""

def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]:
def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_state, key, env_state, last_timestep = learner_state

# SELECT ACTION
# Select action
key, policy_key = jax.random.split(key)
action, log_prob, value = actor_action_select_fn( # type: ignore
params,
last_timestep.observation,
policy_key,
)
# STEP ENVIRONMENT
# Step environment
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
# Repeat along the agent dimension. This is needed to handle the
# shuffling along the agent dimension during training.
info = tree.map(
lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1),
timestep.extras["episode_metrics"],
)

# SET TRANSITION
done = tree.map(
lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1),
timestep.last(),
)
done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
transition = PPOTransition(
done,
action,
value,
timestep.reward,
log_prob,
last_timestep.observation,
info,
sash-a marked this conversation as resolved.
Show resolved Hide resolved
done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_state, key, env_state, timestep)
return learner_state, transition
return learner_state, (transition, timestep.extras["episode_metrics"])

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
learner_state, traj_batch = jax.lax.scan(
# Step environment for rollout length
learner_state, (traj_batch, episode_metrics) = jax.lax.scan(
_env_step, learner_state, None, config.system.rollout_length
)

# CALCULATE ADVANTAGE
# Calculate advantage
params, opt_state, key, env_state, last_timestep = learner_state

key, last_val_key = jax.random.split(key)
Expand Down Expand Up @@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple:

def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple:
"""Update the network for a single minibatch."""

# UNPACK TRAIN STATE AND BATCH INFO
params, opt_state, key = train_state
traj_batch, advantages, targets = batch_info

Expand All @@ -184,52 +163,47 @@ def _loss_fn(
entropy_key: chex.PRNGKey,
) -> Tuple:
"""Calculate the actor loss."""
# RERUN NETWORK

# Rerun network
log_prob, value, entropy = actor_apply_fn( # type: ignore
params,
traj_batch.obs,
traj_batch.action,
entropy_key,
)

# CALCULATE ACTOR LOSS
# Calculate actor loss
ratio = jnp.exp(log_prob - traj_batch.log_prob)

# Normalise advantage at minibatch level
gae = (gae - gae.mean()) / (gae.std() + 1e-8)

loss_actor1 = ratio * gae
loss_actor2 = (
actor_loss1 = ratio * gae
actor_loss2 = (
jnp.clip(
ratio,
1.0 - config.system.clip_eps,
1.0 + config.system.clip_eps,
)
* gae
)
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
actor_loss = -jnp.minimum(actor_loss1, actor_loss2)
actor_loss = actor_loss.mean()
entropy = entropy.mean()

# CALCULATE VALUE LOSS
# Clipped MSE loss
value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
-config.system.clip_eps, config.system.clip_eps
)

# MSE LOSS
value_losses = jnp.square(value - value_targets)
value_losses_clipped = jnp.square(value_pred_clipped - value_targets)
value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

total_loss = (
loss_actor
actor_loss
- config.system.ent_coef * entropy
+ config.system.vf_coef * value_loss
)
return total_loss, (loss_actor, entropy, value_loss)
return total_loss, (actor_loss, entropy, value_loss)

# CALCULATE ACTOR LOSS
# Calculate loss
key, entropy_key = jax.random.split(key)
actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
actor_loss_info, actor_grads = actor_grad_fn(
Expand All @@ -248,15 +222,11 @@ def _loss_fn(
(actor_grads, actor_loss_info), axis_name="device"
)

# UPDATE ACTOR PARAMS AND OPTIMISER STATE
# Update params and optimiser state
actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state)
new_params = optax.apply_updates(params, actor_updates)

# PACK LOSS INFO
total_loss = actor_loss_info[0]
value_loss = actor_loss_info[1][2]
actor_loss = actor_loss_info[1][0]
entropy = actor_loss_info[1][1]
sash-a marked this conversation as resolved.
Show resolved Hide resolved
total_loss, (actor_loss, entropy, value_loss) = actor_loss_info
loss_info = {
"total_loss": total_loss,
"value_loss": value_loss,
Expand All @@ -269,7 +239,7 @@ def _loss_fn(
params, opt_state, traj_batch, advantages, targets, key = update_state
key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4)

# SHUFFLE MINIBATCHES
# Shuffle minibatches
batch_size = config.system.rollout_length * config.arch.num_envs
permutation = jax.random.permutation(batch_shuffle_key, batch_size)

Expand All @@ -286,7 +256,7 @@ def _loss_fn(
shuffled_batch,
)

# UPDATE MINIBATCHES
# Update minibatches
(params, opt_state, entropy_key), loss_info = jax.lax.scan(
_update_minibatch, (params, opt_state, entropy_key), minibatches
)
Expand All @@ -296,17 +266,15 @@ def _loss_fn(

update_state = params, opt_state, traj_batch, advantages, targets, key

# UPDATE EPOCHS
# Update epochs
update_state, loss_info = jax.lax.scan(
_update_epoch, update_state, None, config.system.ppo_epochs
)

params, opt_state, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_state, key, env_state, last_timestep)

metric = traj_batch.info

return learner_state, (metric, loss_info)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
"""Learner function.
Expand Down Expand Up @@ -351,7 +319,7 @@ def learner_setup(
# PRNG keys.
key, actor_net_key = keys

# Initialise observation: Obs for all agents.
# Get mock inputs to initialise network.
init_x = env.observation_spec().generate_value()
init_x = tree.map(lambda x: x[None, ...], init_x)

Expand Down
Loading
Loading