From 1f32a1551e863c3ae43d3e3c0bddf45d2d885aca Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Tue, 3 Dec 2024 11:39:44 +0200 Subject: [PATCH] Merge pull request #1124 from instadeepai/chore/ppo-system-cleanup - chore: remove jax.tree_map - Merge branch develop into chore/ppo-system-cleanup - fix: update ret_output shape (#1147) - Merge branch develop into chore/ppo-system-cleanup - chore: minor comment changes - fix: integration test workflow (#1145) - Merge branch develop into chore/ppo-system-cleanup - Merge branch develop into chore/ppo-system-cleanup - chore: renaming loss_X -> X_loss - fix: updated sebulba - Merge branch develop into chore/ppo-system-cleanup - chore: remove expired link to anakin notebook - chore: add back advanced usage in examples - Merge branch develop into chore/ppo-system-cleanup - feat: sable cleanup - chore: merge dev - chore: unify value/critic loss naming - refactor: loss_actor -> actor_loss - Merge branch develop into chore/ppo-system-cleanup - chore: pre-commit - feat: mat system clean up - chore: remove advanced usage - Merge branch develop into chore/ppo-system-cleanup - feat: ff mappo system clean up - feat: ff ippo system clean up - feat: rec mappo system clean up - feat: rec ippo system clean up Co-authored-by: Louay-Ben-nessir Co-authored-by: Arnol Fokam Co-authored-by: Ruan de Kock --- .github/workflows/integration_tests.yaml | 4 + README.md | 2 +- examples/Quickstart.ipynb | 2 - examples/advanced_usage/README.md | 114 +++ .../ff_ippo_store_experience.py | 685 ++++++++++++++++++ mava/evaluator.py | 4 +- mava/networks/retention.py | 8 +- mava/systems/mat/anakin/mat.py | 92 +-- mava/systems/ppo/anakin/ff_ippo.py | 122 ++-- mava/systems/ppo/anakin/ff_mappo.py | 100 ++- mava/systems/ppo/anakin/rec_ippo.py | 111 ++- mava/systems/ppo/anakin/rec_mappo.py | 100 +-- mava/systems/ppo/sebulba/ff_ippo.py | 60 +- mava/systems/ppo/types.py | 3 - mava/systems/sable/anakin/ff_sable.py | 120 +-- mava/systems/sable/anakin/rec_sable.py | 136 ++-- mava/systems/sable/types.py | 14 +- mava/types.py | 2 +- mava/utils/sebulba.py | 13 +- 19 files changed, 1129 insertions(+), 563 deletions(-) create mode 100644 examples/advanced_usage/README.md create mode 100644 examples/advanced_usage/ff_ippo_store_experience.py diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml index 492593a319..bfb5cbb80a 100644 --- a/.github/workflows/integration_tests.yaml +++ b/.github/workflows/integration_tests.yaml @@ -11,6 +11,10 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 20 + strategy: + matrix: + python-version: ["3.12", "3.11"] + steps: - name: Checkout mava uses: actions/checkout@v4 diff --git a/README.md b/README.md index 4285080d70..dee7a13311 100644 --- a/README.md +++ b/README.md @@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used ## Advanced Usage 👽 -Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information. +Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. 
This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information. ## Contributing 🤝 diff --git a/examples/Quickstart.ipynb b/examples/Quickstart.ipynb index 7febf61401..baf119cda6 100644 --- a/examples/Quickstart.ipynb +++ b/examples/Quickstart.ipynb @@ -413,8 +413,6 @@ " )\n", "\n", " # Compute the parallel mean (pmean) over the batch.\n", - " # This calculation is inspired by the Anakin architecture demo notebook.\n", - " # available at https://tinyurl.com/26tdzs5x\n", " # This pmean could be a regular mean as the batch axis is on the same device.\n", " actor_grads, actor_loss_info = jax.lax.pmean(\n", " (actor_grads, actor_loss_info), axis_name=\"batch\"\n", diff --git a/examples/advanced_usage/README.md b/examples/advanced_usage/README.md new file mode 100644 index 0000000000..e5da8511c7 --- /dev/null +++ b/examples/advanced_usage/README.md @@ -0,0 +1,114 @@ +# Advanced Mava usage 👽 +## Data recording from a PPO system 🔴 +We include here an example of an advanced use case with Mava: recording experience data from a PPO system, which can then be used for offline MARL—e.g. using the [OG-MARL](https://github.com/instadeepai/og-marl) framework. This functionality is demonstrated in [ff_ippo_store_experience.py](./ff_ippo_store_experience.py), and uses [Flashbax](https://github.com/instadeepai/flashbax)'s `Vault` feature. Vault enables efficient storage of experience data recorded in JAX-based systems, and integrates tightly with Mava and the rest of InstaDeep's MARL ecosystem. + +Firstly, a vault must be created using the structure of an experience buffer. Here, we create a dummy structure of the data we want to store: +```py +# Transition structure +dummy_flashbax_transition = { + "done": jnp.zeros((config.system.num_agents,), dtype=bool), + "action": jnp.zeros((config.system.num_agents,), dtype=jnp.int32), + "reward": jnp.zeros((config.system.num_agents,), dtype=jnp.float32), + "observation": jnp.zeros( + ( + config.system.num_agents, + env.observation_spec().agents_view.shape[1], + ), + dtype=jnp.float32, + ), + "legal_action_mask": jnp.zeros( + ( + config.system.num_agents, + config.system.num_actions, + ), + dtype=bool, + ), +} + +# Flashbax buffer +buffer = fbx.make_flat_buffer( + max_length=int(5e5), + min_length=int(1), + sample_batch_size=1, + add_sequences=True, + add_batch_size=( + n_devices + * config["system"]["num_updates_per_eval"] + * config["system"]["update_batch_size"] + * config["arch"]["num_envs"] + ), +) + +# Buffer state +buffer_state = buffer.init( + dummy_flashbax_transition, +) +``` + +We can now create a `Vault` for our data: +```py +v = Vault( + vault_name="our_system_name", + experience_structure=buffer_state.experience, + vault_uid="unique_vault_id", +) +``` + +We modify our `learn` function to additionally record our agents' trajectories, such that we can access experience data: +```py +learner_output, experience_to_store = learn(learner_state) +``` + +Because of the Anakin architecture set-up, our trajectories are stored in the incorrect dimensions for our use case. 
Hence, we transform the data and then store it in a Flashbax buffer:
```py
# Shape legend:
# D: Number of devices
# NU: Number of updates per evaluation
# UB: Update batch size
# T: Time steps per rollout
# NE: Number of environments

@jax.jit
def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]:
    """Reshape experience to match buffer."""

    # Swap the T and NE axes (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
    experience: Dict[str, chex.Array] = jax.tree.map(lambda x: x.swapaxes(3, 4), experience)
    # Merge 4 leading dimensions into 1. (D, NU, UB, NE, T, ...) -> (D * NU * UB * NE, T, ...)
    experience: Dict[str, chex.Array] = jax.tree.map(
        lambda x: x.reshape(-1, *x.shape[4:]), experience
    )
    return experience

flashbax_transition = _reshape_experience(
    {
        # (D, NU, UB, T, NE, ...)
        "done": experience_to_store.done,
        "action": experience_to_store.action,
        "reward": experience_to_store.reward,
        "observation": experience_to_store.obs.agents_view,
        "legal_action_mask": experience_to_store.obs.action_mask,
    }
)
# Add to fbx buffer
buffer_state = buffer.add(buffer_state, flashbax_transition)
```

Then, periodically, we can write this buffer state into the vault, which is stored on disk:
```py
v.write(buffer_state)
```

If we now want to use the recorded data, we can easily restore the vault in another context:
```py
v = Vault(
    vault_name="our_system_name",
    vault_uid="unique_vault_id",
)
buffer_state = v.read()
```

For a demonstration of offline MARL training, see some examples [here](https://github.com/instadeepai/og-marl/tree/feat/vault).

---
⚠️ Note: this functionality is highly experimental! The current API is likely to change.
diff --git a/examples/advanced_usage/ff_ippo_store_experience.py b/examples/advanced_usage/ff_ippo_store_experience.py
new file mode 100644
index 0000000000..58d71c795c
--- /dev/null
+++ b/examples/advanced_usage/ff_ippo_store_experience.py
@@ -0,0 +1,685 @@
+# type: ignore
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
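+
+# Example: feed-forward IPPO with experience recording to a Flashbax Vault.
+# The learner below additionally returns the rollout trajectories it collects;
+# these are reshaped, added to a Flashbax buffer, and periodically written to a
+# Vault on disk so they can be reused later (e.g. for offline MARL).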
+ +import copy +import time +from typing import Any, Callable, Dict, Tuple + +import chex +import flashbax as fbx +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flashbax.vault import Vault +from flax.core.frozen_dict import FrozenDict +from jax import tree +from omegaconf import DictConfig, OmegaConf +from optax._src.base import OptState +from rich.pretty import pprint + +from mava.evaluator import get_eval_fn, make_ff_eval_act_fn +from mava.networks import FeedForwardActor as Actor +from mava.networks import FeedForwardValueNet as Critic +from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.types import ActorApply, CriticApply, ExperimentOutput, MarlEnv, MavaState +from mava.utils.checkpointing import Checkpointer +from mava.utils.jax_utils import ( + merge_leading_dims, + unreplicate_batch_dim, + unreplicate_n_dims, +) +from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.make_env import make +from mava.utils.network_utils import get_action_head +from mava.wrappers.episode_metrics import get_final_step_metrics + +StoreExpLearnerFn = Callable[[MavaState], Tuple[ExperimentOutput[MavaState], PPOTransition]] + +# Experimental config +SAVE_VAULT = True +VAULT_NAME = "ff_ippo_rware" +VAULT_UID = None # None => timestamp +VAULT_SAVE_INTERVAL = 5 + + +def get_learner_fn( + env: MarlEnv, + apply_fns: Tuple[ActorApply, CriticApply], + update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], + config: DictConfig, +) -> StoreExpLearnerFn[LearnerState]: + """Get the learner function.""" + # Get apply and update functions for actor and critic networks. + actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + ---- + learner_state (NamedTuple): + - params (Params): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + _ (Any): The current metrics info. 
+ + """ + + def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + """Step the environment.""" + params, opt_states, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) + value = critic_apply_fn(params.critic_params, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = tree.map( + lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), + timestep.last(), + ) + info = timestep.extras["episode_metrics"] + + transition = PPOTransition( + done, action, value, timestep.reward, log_prob, last_timestep.observation, info + ) + + learner_state = LearnerState(params, opt_states, key, env_state, timestep) + return learner_state, transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + # CALCULATE ADVANTAGE + params, opt_states, key, env_state, last_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, last_timestep.observation) + + def _calculate_gae( + traj_batch: PPOTransition, last_val: chex.Array + ) -> Tuple[chex.Array, chex.Array]: + """Calculate the GAE.""" + + def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: + """Calculate the GAE for a single transition.""" + gae, next_value = gae_and_next_value + done, value, reward = ( + transition.done, + transition.value, + transition.reward, + ) + gamma = config.system.gamma + delta = reward + gamma * next_value * (1 - done) - value + gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae + return (gae, value), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val), + traj_batch, + reverse=True, + unroll=16, + ) + return advantages, advantages + traj_batch.value + + advantages, targets = _calculate_gae(traj_batch, last_val) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptState, + traj_batch: PPOTransition, + gae: chex.Array, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + ratio = jnp.exp(log_prob - traj_batch.log_prob) + gae = (gae - gae.mean()) / (gae.std() + 1e-8) + loss_actor1 = ratio * gae + loss_actor2 = ( + jnp.clip( + ratio, + 1.0 - config.system.clip_eps, + 1.0 + config.system.clip_eps, + ) + * gae + ) + loss_actor = -jnp.minimum(loss_actor1, loss_actor2) + loss_actor = loss_actor.mean() + entropy = actor_policy.entropy().mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, + critic_opt_state: OptState, + traj_batch: PPOTransition, + targets: chex.Array, + ) -> Tuple: + 
"""Calculate the critic loss.""" + # RERUN NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, opt_states.actor_opt_state, traj_batch, advantages + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, opt_states.critic_opt_state, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This pmean could be a regular mean as the batch axis is on the same device. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="batch" + ) + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = Params(actor_new_params, critic_new_params) + new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + + return (new_params, new_opt_state), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key = jax.random.split(key) + + # SHUFFLE MINIBATCHES + batch_size = config.system.rollout_length * config.arch.num_envs + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = tree.map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch) + minibatches = tree.map( + lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), + shuffled_batch, + ) + + # UPDATE MINIBATCHES + (params, opt_states), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states), minibatches + ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + + # UPDATE EPOCHS + update_state, loss_info 
= jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + metric = traj_batch.info + return learner_state, (metric, loss_info, traj_batch) + + def learner_fn( + learner_state: LearnerState, + ) -> Tuple[ExperimentOutput[LearnerState], PPOTransition]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + ---- + learner_state (NamedTuple): + - params (Params): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + + """ + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info, traj_batch) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ( + ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ), + traj_batch, + ) + + return learner_fn + + +def learner_setup( + env: MarlEnv, keys: chex.Array, config: DictConfig +) -> Tuple[StoreExpLearnerFn[LearnerState], Actor, LearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions and agents. + num_actions = env.action_dim + config.system.num_agents = env.num_agents + config.system.num_actions = num_actions + + # PRNG keys. + key, key_p = keys + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + action_head, _ = get_action_head(env.action_spec()) + actor_action_head = hydra.utils.instantiate(action_head, action_dim=env.action_dim) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(config.system.actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(config.system.critic_lr, eps=1e-5), + ) + + # Initialise observation with obs of all agents. + obs = env.observation_spec().generate_value() + init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(key_p, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. + critic_params = critic_network.init(key_p, init_x) + critic_opt_state = critic_optim.init(critic_params) + + # Load model from checkpoint if specified. 
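    # Note: only the actor and critic parameters are restored from the checkpoint;
    # the optimiser states keep the freshly initialised values from above.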
+ if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params( + input_params=Params(actor_params, critic_params) + ) + # Update the params + actor_params, critic_params = restored_params.actor_params, restored_params.critic_params + + # Pack apply and update functions. + apply_fns = (actor_network.apply, critic_network.apply) + update_fns = (actor_optim.update, critic_optim.update) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, config) + learn = jax.pmap(learn, axis_name="device") + + # Broadcast params and optimiser state to cores and batch. + broadcast = lambda x: jnp.broadcast_to( + x, (n_devices, config.system.update_batch_size, *x.shape) + ) + + actor_params = tree.map(broadcast, actor_params) + actor_opt_state = tree.map(broadcast, actor_opt_state) + critic_params = tree.map(broadcast, critic_params) + critic_opt_state = tree.map(broadcast, critic_opt_state) + + # Initialise environment states and timesteps. + key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + + # Split keys for each core. + key, *step_keys = jax.random.split(key, n_devices * config.system.update_batch_size + 1) + + # Add dimension to pmap over. + reshape_step_keys = lambda x: x.reshape( + (n_devices, config.system.update_batch_size) + x.shape[1:] + ) + step_keys = reshape_step_keys(jnp.stack(step_keys)) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + env_states = tree.map(reshape_states, env_states) + timesteps = tree.map(reshape_states, timesteps) + + params = Params(actor_params, critic_params) + opt_states = OptStates(actor_opt_state, critic_opt_state) + + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + return learn, actor_network, init_learner_state + + +def run_experiment(_config: DictConfig) -> None: + """Runs experiment.""" + _config.logger.system_name = "ff_ippo" + # Logger setup + config = copy.deepcopy(_config) + logger = MavaLogger(config) + + n_devices = len(jax.devices()) + + # Create the enviroments for train and eval. + env, eval_env = make(config=config) + + # PRNG keys. + key, key_e, key_p = jax.random.split(jax.random.PRNGKey(config.system.seed), num=3) + + # Setup learner. + learn, actor_network, learner_state = learner_setup(env, (key, key_p), config) + + # Setup evaluator. 
+ eval_keys = jax.random.split(key_e, n_devices) + eval_act_fn = make_ff_eval_act_fn(actor_network, config) + evaluator = get_eval_fn(eval_env, eval_act_fn, config, config.arch.num_eval_episodes) + + config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + # Get total_timesteps + config.system.total_timesteps = ( + n_devices + * config.system.num_updates + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + dummy_flashbax_transition = { + "done": jnp.zeros((config.system.num_agents,), dtype=bool), + "action": jnp.zeros((config.system.num_agents,), dtype=jnp.int32), + "reward": jnp.zeros((config.system.num_agents,), dtype=jnp.float32), + "observation": jnp.zeros( + ( + config.system.num_agents, + env.observation_spec().agents_view.shape[1], + ), + dtype=jnp.float32, + ), + "legal_action_mask": jnp.zeros( + ( + config.system.num_agents, + config.system.num_actions, + ), + dtype=bool, + ), + } + + buffer = fbx.make_flat_buffer( + max_length=int(5e5), # Max number of transitions to store + min_length=int(1), + sample_batch_size=1, + add_sequences=True, + add_batch_size=( + n_devices + * config.system.num_updates_per_eval + * config.system.update_batch_size + * config.arch.num_envs + ), + ) + buffer_state = buffer.init( + dummy_flashbax_transition, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=(0)) + + # Shape legend: + # D: Number of devices + # NU: Number of updates per evaluation + # UB: Update batch size + # T: Time steps per rollout + # NE: Number of environments + + @jax.jit + def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]: + """Reshape experience to match buffer.""" + # Swap the T and NE axes (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...) + experience = tree.map(lambda x: x.swapaxes(3, 4), experience) + # Merge 4 leading dimensions into 1. (D, NU, UB, NE, T ...) -> (D * NU * UB * NE, T, ...) + experience = tree.map(lambda x: x.reshape(-1, *x.shape[4:]), experience) + return experience + + # Use vault to record experience + if SAVE_VAULT: + vault = Vault( + vault_name=VAULT_NAME, + experience_structure=buffer_state.experience, + vault_uid=VAULT_UID, + # Metadata must be a python dictionary + metadata=OmegaConf.to_container(config, resolve=True), + ) + + # Run experiment for a total number of evaluations. + max_episode_return = -jnp.inf + best_params = None + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output, experience_to_store = learn(learner_state) + + # Record data into the vault + if SAVE_VAULT: + # Pack transition + flashbax_transition = _reshape_experience( + { + # (D, NU, UB, T, NE, ...) 
+ "done": experience_to_store.done, + "action": experience_to_store.action, + "reward": experience_to_store.reward, + "observation": experience_to_store.obs.agents_view, + "legal_action_mask": experience_to_store.obs.action_mask, + } + ) + # Add to fbx buffer + buffer_state = buffer_add(buffer_state, flashbax_transition) + + # Save buffer into vault + if eval_step % VAULT_SAVE_INTERVAL == 0: + write_length = vault.write(buffer_state) + print(f"(Wrote {write_length}) Vault index = {vault.vault_index}") + + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + episode_metrics, ep_completed = get_final_step_metrics(learner_output.episode_metrics) + episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + if ep_completed: + logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + + trained_params = unreplicate_batch_dim(learner_state.params.actor_params) + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + eval_metrics = evaluator(trained_params, eval_keys, {}) + jax.block_until_ready(eval_metrics) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(eval_metrics["episode_return"]) + + steps_per_eval = int(jnp.sum(eval_metrics["episode_length"])) + eval_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=steps_per_rollout * (eval_step + 1), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Final write to vault for any remaining data + vault.write(buffer_state) + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + eval_episodes = config.arch.num_absolute_metric_eval_episodes + abs_metric_evaluator = get_eval_fn(eval_env, eval_act_fn, config, eval_episodes) + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + eval_metrics = abs_metric_evaluator(best_params, eval_keys, {}) + jax.block_until_ready(eval_metrics) + + elapsed_time = time.time() - start_time + steps_per_eval = int(jnp.sum(eval_metrics["episode_length"])) + t = int(steps_per_rollout * (eval_step + 1)) + eval_metrics["steps_per_second"] = steps_per_eval / elapsed_time + logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop logger + logger.stop() + + +@hydra.main(config_path="../configs/default", config_name="ff_ippo.yaml", version_base="1.2") +def hydra_entry_point(cfg: DictConfig) -> None: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. 
+ run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}IPPO experiment completed{Style.RESET_ALL}") + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/mava/evaluator.py b/mava/evaluator.py index 6b2fda2038..f157b42d0d 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: # find the first instance of done to get the metrics at that timestep, we don't # care about subsequent steps because we only the results from the first episode done_idx = np.argmax(timesteps.last(), axis=0) - metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) + metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) del metrics["is_terminal_step"] # uneeded for logging return key, metrics @@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: metrics_array.append(metric) # flatten metrics - metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array) + metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array) return metrics def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: diff --git a/mava/networks/retention.py b/mava/networks/retention.py index a041abf339..29f159f26d 100644 --- a/mava/networks/retention.py +++ b/mava/networks/retention.py @@ -237,12 +237,12 @@ def setup(self) -> None: self.w_g = self.param( "w_g", nn.initializers.normal(stddev=1 / self.embed_dim), - (self.embed_dim, self.head_size), + (self.embed_dim, self.embed_dim), ) self.w_o = self.param( "w_o", nn.initializers.normal(stddev=1 / self.embed_dim), - (self.head_size, self.embed_dim), + (self.embed_dim, self.embed_dim), ) self.group_norm = nn.GroupNorm(num_groups=self.n_head) @@ -278,7 +278,7 @@ def __call__( if self.memory_config.timestep_positional_encoding: key, query, value = self.pe(key, query, value, step_count) - ret_output = jnp.zeros((B, C, self.head_size), dtype=value.dtype) + ret_output = jnp.zeros((B, C, self.embed_dim), dtype=value.dtype) for head in range(self.n_head): y, new_hs = self.retention_heads[head](key, query, value, hstate[:, head], dones) ret_output = ret_output.at[ @@ -304,7 +304,7 @@ def recurrent( if self.memory_config.timestep_positional_encoding: key_n, query_n, value_n = self.pe(key_n, query_n, value_n, step_count) - ret_output = jnp.zeros((B, S, self.head_size), dtype=value_n.dtype) + ret_output = jnp.zeros((B, S, self.embed_dim), dtype=value_n.dtype) for head in range(self.n_head): y, new_hs = self.retention_heads[head].recurrent( key_n, query_n, value_n, hstate[:, head] diff --git a/mava/systems/mat/anakin/mat.py b/mava/systems/mat/anakin/mat.py index 2973fe5678..1db141d198 100644 --- a/mava/systems/mat/anakin/mat.py +++ b/mava/systems/mat/anakin/mat.py @@ -37,16 +37,13 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, TimeStep, ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate @@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. 
""" - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_state, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) action, log_prob, value = actor_action_select_fn( # type: ignore params, last_timestep.observation, policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - # Repeat along the agent dimension. This is needed to handle the - # shuffling along the agent dimension during training. - info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - - # SET TRANSITION - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_state, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_state, key, env_state, last_timestep = learner_state key, last_val_key = jax.random.split(key) @@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets = batch_info @@ -184,8 +163,7 @@ def _loss_fn( entropy_key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK - + # Rerun network log_prob, value, entropy = actor_apply_fn( # type: ignore params, traj_batch.obs, @@ -193,14 +171,12 @@ def _loss_fn( entropy_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) - # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -208,28 +184,26 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = ( - loss_actor + actor_loss - config.system.ent_coef * entropy + 
config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( @@ -248,15 +222,11 @@ def _loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state) new_params = optax.apply_updates(params, actor_updates) - # PACK LOSS INFO - total_loss = actor_loss_info[0] - value_loss = actor_loss_info[1][2] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + total_loss, (actor_loss, entropy, value_loss) = actor_loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -269,7 +239,7 @@ def _loss_fn( params, opt_state, traj_batch, advantages, targets, key = update_state key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(batch_shuffle_key, batch_size) @@ -286,7 +256,7 @@ def _loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_state, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_state, entropy_key), minibatches ) @@ -296,7 +266,7 @@ def _loss_fn( update_state = params, opt_state, traj_batch, advantages, targets, key - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -304,9 +274,7 @@ def _loss_fn( params, opt_state, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_state, key, env_state, last_timestep) - metric = traj_batch.info - - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -351,7 +319,7 @@ def learner_setup( # PRNG keys. key, actor_net_key = keys - # Initialise observation: Obs for all agents. + # Get mock inputs to initialise network. 
init_x = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[None, ...], init_x) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index fd97b280d8..ed29439168 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -26,22 +26,17 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate @@ -79,11 +74,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) value = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -91,34 +88,22 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra action = actor_policy.sample(seed=policy_key) log_prob = actor_policy.log_prob(action) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -156,27 +141,26 @@ def _update_epoch(update_state: Tuple, _: Any) 
-> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network actor_policy = actor_apply_fn(actor_params, traj_batch.obs) log_prob = actor_policy.log_prob(traj_batch.action) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -184,25 +168,24 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: PPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network value = critic_apply_fn(critic_params, traj_batch.obs) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -210,32 +193,23 @@ def _critic_loss_fn( value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, - opt_states.actor_opt_state, - traj_batch, - advantages, - entropy_key, + params.actor_params, traj_batch, advantages, entropy_key ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, - opt_states.critic_opt_state, - traj_batch, - targets, + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -245,38 +219,35 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - # PACK NEW PARAMS AND OPTIMISER STATE new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -285,7 +256,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle data and create minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) @@ -296,7 +267,7 @@ def _critic_loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -306,15 +277,14 @@ def _critic_loss_fn( update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -381,7 +351,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
obs = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 26b3e17b67..5e5a0006f2 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -25,14 +25,13 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -74,39 +73,36 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) value = critic_apply_fn(params.critic_params, last_timestep.observation) action = actor_policy.sample(seed=policy_key) log_prob = actor_policy.log_prob(action) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -144,27 +140,26 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network 
actor_policy = actor_apply_fn(actor_params, traj_batch.obs) log_prob = actor_policy.log_prob(traj_batch.action) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -172,25 +167,24 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: PPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network value = critic_apply_fn(critic_params, traj_batch.obs) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -198,29 +192,26 @@ def _critic_loss_fn( value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -230,21 +221,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. 
- critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -253,14 +243,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -269,7 +258,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) @@ -280,7 +269,7 @@ def _critic_loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -290,15 +279,14 @@ def _critic_loss_fn( update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -365,7 +353,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
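The simplified loss unpacking above relies on how `jax.value_and_grad` packages auxiliary outputs. A minimal sketch (the loss function and its aux tuple are invented for illustration):

```py
import jax
import jax.numpy as jnp

# With has_aux=True the returned function yields ((loss, aux), grads), which is what
# makes the new `actor_loss, (_, entropy) = actor_loss_info` unpacking work.
def loss_fn(params, x):
    total_loss = jnp.sum(params * x) ** 2
    aux = (total_loss, jnp.float32(0.5))  # e.g. (unscaled actor loss, entropy)
    return total_loss, aux

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
loss_info, grads = grad_fn(jnp.ones(3), jnp.arange(3.0))

# loss_info == (total_loss, aux): the first element is the differentiated loss,
# the rest is whatever auxiliary data the loss function chose to return.
total_loss, (_, entropy) = loss_info
```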
obs = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 2122460807..2792383825 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -26,7 +26,6 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, get_num_eval_envs, make_rec_eval_act_fn @@ -44,6 +43,7 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, RecActorApply, RecCriticApply, ) @@ -91,7 +91,7 @@ def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerStat def _env_step( learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + ) -> Tuple[RNNLearnerState, Tuple[RNNPPOTransition, Metrics]]: """Step the environment.""" ( params, @@ -107,10 +107,7 @@ def _env_step( # Add a batch dimension to the observation. batched_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) - ac_in = ( - batched_observation, - last_done[jnp.newaxis, :], - ) + ac_in = (batched_observation, last_done[jnp.newaxis, :]) # Run the network. policy_hidden_state, actor_policy = actor_apply_fn( @@ -128,13 +125,7 @@ def _env_step( # Step the environment. env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # log episode return and length - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) transition = RNNPPOTransition( last_done, @@ -144,35 +135,23 @@ def _env_step( log_prob, last_timestep.observation, last_hstates, - info, ) learner_state = RNNLearnerState( params, opt_states, key, env_state, timestep, done, hstates ) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - hstates, - ) = learner_state + # Calculate advantage + params, opt_states, key, env_state, last_timestep, last_done, hstates = learner_state # Add a batch dimension to the observation. batched_last_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) - ac_in = ( - batched_last_observation, - last_done[jnp.newaxis, :], - ) + ac_in = (batched_last_observation, last_done[jnp.newaxis, :]) # Run the network. 
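The recurrent systems prepend a singleton time axis to the observation pytree and the done flags before calling the network, which is what `ac_in` packages above. A toy sketch with assumed shapes (8 envs, 3 agents):

```py
import jax.numpy as jnp
from jax import tree

num_envs, num_agents = 8, 3
# Toy observation pytree with leaves shaped (num_envs, num_agents, ...).
obs = {
    "agents_view": jnp.zeros((num_envs, num_agents, 5)),
    "action_mask": jnp.ones((num_envs, num_agents, 4), dtype=bool),
}
last_done = jnp.zeros((num_envs, num_agents), dtype=bool)

# The recurrent apply functions expect a leading time axis, so a singleton
# time dimension is prepended to every observation leaf and to the done flags.
batched_obs = tree.map(lambda x: x[jnp.newaxis, ...], obs)
ac_in = (batched_obs, last_done[jnp.newaxis, :])

assert batched_obs["agents_view"].shape == (1, num_envs, num_agents, 5)
assert ac_in[1].shape == (1, num_envs, num_agents)
```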
_, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) @@ -208,30 +187,29 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: RNNPPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK - + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, actor_policy = actor_apply_fn( actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done ) log_prob = actor_policy.log_prob(traj_batch.action) + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -239,28 +217,27 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) + total_loss = actor_loss - config.system.ent_coef * entropy + return total_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: RNNPPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, value = critic_apply_fn( critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done ) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -269,28 +246,25 @@ def _critic_loss_fn( value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) + return total_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -300,21 +274,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -323,14 +296,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -340,7 +312,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch = (traj_batch, advantages, targets) num_recurrent_chunks = ( config.system.rollout_length // config.system.recurrent_chunk_size @@ -365,7 +337,7 @@ def _critic_loss_fn( ) minibatches = tree.map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -389,7 +361,7 @@ def _critic_loss_fn( key, ) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -404,8 +376,7 @@ def _critic_loss_fn( last_done, hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: """Learner function. @@ -486,7 +457,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
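The recurrent minibatch shuffle above keeps whole sub-sequences together so hidden states stay meaningful. The exact reshapes are not all visible in this hunk, so the following is only an illustrative sketch of the idea, with toy sizes: split time into `recurrent_chunk_size` chunks, shuffle and split along the folded batch axis, and put minibatches first for `lax.scan`.

```py
import jax
import jax.numpy as jnp
from jax import tree

rollout_length, num_envs, chunk_size, num_minibatches = 16, 8, 4, 2
num_chunks = rollout_length // chunk_size

# Toy batch: one leaf of shape (rollout_length, num_envs).
batch = {
    "advantages": jnp.arange(rollout_length * num_envs, dtype=jnp.float32).reshape(
        rollout_length, num_envs
    )
}

def to_minibatches(x: jnp.ndarray, key: jax.Array) -> jnp.ndarray:
    # Split the time axis into contiguous chunks: (num_chunks, chunk_size, num_envs, ...).
    x = x.reshape(num_chunks, chunk_size, *x.shape[1:])
    # Fold the chunk axis into the batch axis: (chunk_size, num_chunks * num_envs, ...).
    x = jnp.swapaxes(x, 0, 1)
    x = x.reshape(chunk_size, -1, *x.shape[3:])
    # Shuffle whole sub-sequences (never individual timesteps), same permutation per leaf.
    perm = jax.random.permutation(key, x.shape[1])
    x = jnp.take(x, perm, axis=1)
    # Carve out minibatches and put that axis first so lax.scan can iterate over them.
    x = x.reshape(chunk_size, num_minibatches, -1, *x.shape[2:])
    return jnp.swapaxes(x, 0, 1)  # (num_minibatches, chunk_size, batch_per_minibatch, ...)

shuffle_key = jax.random.PRNGKey(0)
minibatches = tree.map(lambda x: to_minibatches(x, shuffle_key), batch)
assert minibatches["advantages"].shape == (
    num_minibatches, chunk_size, num_chunks * num_envs // num_minibatches
)
```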
init_obs = env.observation_spec().generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index b995ac1678..96d7d74acf 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -26,7 +26,6 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, get_num_eval_envs, make_rec_eval_act_fn @@ -44,6 +43,7 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, RecActorApply, RecCriticApply, ) @@ -91,7 +91,7 @@ def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerStat def _env_step( learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + ) -> Tuple[RNNLearnerState, Tuple[RNNPPOTransition, Metrics]]: """Step the environment.""" ( params, @@ -126,13 +126,7 @@ def _env_step( # Step the environment. env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # log episode return and length - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) transition = RNNPPOTransition( last_done, @@ -142,28 +136,19 @@ def _env_step( log_prob, last_timestep.observation, last_hstates, - info, ) learner_state = RNNLearnerState( params, opt_states, key, env_state, timestep, done, hstates ) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - hstates, - ) = learner_state + # Calculate advantage + params, opt_states, key, env_state, last_timestep, last_done, hstates = learner_state # Add a batch dimension to the observation. 
batched_last_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) @@ -204,29 +189,29 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: RNNPPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, actor_policy = actor_apply_fn( actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done ) log_prob = actor_policy.log_prob(traj_batch.action) + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -234,28 +219,27 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) + total_loss = actor_loss - config.system.ent_coef * entropy + return total_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: RNNPPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, value = critic_apply_fn( critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done ) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -264,28 +248,25 @@ def _critic_loss_fn( value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) + return total_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -295,21 +276,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -318,14 +298,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -335,7 +314,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch = (traj_batch, advantages, targets) num_recurrent_chunks = ( config.system.rollout_length // config.system.recurrent_chunk_size @@ -360,7 +339,7 @@ def _critic_loss_fn( ) minibatches = tree.map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -384,7 +363,7 @@ def _critic_loss_fn( key, ) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -399,8 +378,7 @@ def _critic_loss_fn( last_done, hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: """Learner function. @@ -482,7 +460,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
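The repeated `jax.lax.pmean` calls above average gradients and loss info over the named `batch` axis (introduced by `vmap`) and then over the `device` axis (introduced by `pmap`). A minimal runnable sketch of a collective over a named `vmap` axis; the nested mean over `device` follows the same pattern but needs multiple devices:

```py
import jax
import jax.numpy as jnp

# Gradients averaged over a named vmap axis; every batch member ends up with the mean.
def average_over_batch(grads):
    return jax.lax.pmean(grads, axis_name="batch")

grads = jnp.arange(4.0)  # one toy gradient value per batch element
averaged = jax.vmap(average_over_batch, axis_name="batch")(grads)
assert bool(jnp.all(averaged == grads.mean()))
```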
init_obs = env.observation_spec().generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 11bfd4b261..996821cd02 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -45,7 +45,7 @@ from mava.types import ( ActorApply, CriticApply, - ExperimentOutput, + Metrics, Observation, SebulbaLearnerFn, ) @@ -113,6 +113,7 @@ def act_fn( while not thread_lifetime.should_stop(): # Rollout traj: List[PPOTransition] = [] + episode_metrics: List[Dict] = [] actor_timings: Dict[str, List[float]] = defaultdict(list) with RecordTimeTo(actor_timings["rollout_time"]): for _ in range(config.system.rollout_length): @@ -142,14 +143,14 @@ def act_fn( timestep.reward, log_prob, obs_tpu, - timestep.extras["episode_metrics"], ) ) + episode_metrics.append(timestep.extras["episode_metrics"]) # send trajectories to learner with RecordTimeTo(actor_timings["rollout_put_time"]): try: - rollout_queue.put(traj, timestep, actor_timings) + rollout_queue.put(traj, timestep, (actor_timings, episode_metrics)) except queue.Full: err = "Waited too long to add to the rollout queue, killing the actor thread" warnings.warn(err, stacklevel=2) @@ -175,7 +176,7 @@ def get_learner_step_fn( def _update_step( learner_state: LearnerState, traj_batch: PPOTransition, - ) -> Tuple[LearnerState, Tuple]: + ) -> Tuple[LearnerState, Metrics]: """A single update of the network. This function calculates advantages and targets based on the trajectories @@ -216,7 +217,7 @@ def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tup last_val = critic_apply_fn(params.critic_params, final_timestep.observation) advantages, targets = _calculate_gae(traj_batch, last_val) - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + def _update_epoch(update_state: Tuple, _: Any) -> Tuple[Tuple, Metrics]: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: @@ -240,8 +241,8 @@ def _actor_loss_fn( # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -249,13 +250,13 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, traj_batch: PPOTransition, targets: chex.Array @@ -272,8 +273,8 @@ def _critic_loss_fn( value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss # Calculate actor loss key, entropy_key = jax.random.split(key) @@ -284,7 +285,7 @@ def _critic_loss_fn( # Calculate critic loss 
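The `_calculate_gae`/`_get_advantages` helpers referenced here implement the usual generalised advantage estimation recursion with a reverse `lax.scan`. A self-contained sketch with illustrative `gamma`/`gae_lambda` values and time-major toy shapes:

```py
import jax
import jax.numpy as jnp

def calculate_gae(rewards, values, dones, last_val, gamma=0.99, gae_lambda=0.95):
    """GAE over a time-major rollout, computed with a reverse scan."""
    def step(carry, xs):
        gae, next_value = carry
        done, value, reward = xs
        delta = reward + gamma * next_value * (1.0 - done) - value
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        step,
        (jnp.zeros_like(last_val), last_val),
        (dones, values, rewards),
        reverse=True,
    )
    # Value targets are the advantages plus the value baseline, as in the systems above.
    return advantages, advantages + values

rollout_length, num_envs = 8, 4
rewards = jnp.ones((rollout_length, num_envs))
values = jnp.zeros((rollout_length, num_envs))
dones = jnp.zeros((rollout_length, num_envs))
advantages, targets = calculate_gae(rewards, values, dones, last_val=jnp.zeros(num_envs))
assert advantages.shape == targets.shape == (rollout_length, num_envs)
```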
critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + value_loss_info, critic_grads = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -298,8 +299,8 @@ def _critic_loss_fn( ) # pmean over learner devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="learner_devices" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="learner_devices" ) # Update actor params and optimiser state @@ -318,12 +319,12 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) # Pack loss info - actor_total_loss, (actor_loss, entropy) = actor_loss_info - critic_total_loss, (value_loss) = critic_loss_info - total_loss = critic_total_loss + actor_total_loss + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -359,12 +360,11 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, loss_info def learner_fn( learner_state: LearnerState, traj_batch: PPOTransition - ) -> ExperimentOutput[LearnerState]: + ) -> Tuple[LearnerState, Metrics]: """Learner function. This function represents the learner, it updates the network parameters @@ -382,13 +382,9 @@ def learner_fn( # This function is shard mapped on the batch axis, but `_update_step` needs # the first axis to be time traj_batch = tree.map(switch_leading_axes, traj_batch) - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + learner_state, loss_info = _update_step(learner_state, traj_batch) - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) + return learner_state, loss_info return learner_fn @@ -412,7 +408,7 @@ def learner_thread( # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. 
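`switch_leading_axes` is applied leaf-wise because the Sebulba learner is shard-mapped over the batch axis while `_update_step` expects time first. Assuming it simply swaps the first two axes of each leaf, as the comment describes, a minimal equivalent looks like:

```py
import jax.numpy as jnp
from jax import tree

def switch_leading_axes(x: jnp.ndarray) -> jnp.ndarray:
    # Swap the first two axes, e.g. (batch, time, ...) -> (time, batch, ...).
    return jnp.swapaxes(x, 0, 1)

traj_batch = {"reward": jnp.zeros((32, 16, 4))}  # toy (batch, rollout_length, num_agents)
time_major = tree.map(switch_leading_axes, traj_batch)
assert time_major["reward"].shape == (16, 32, 4)
```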
with RecordTimeTo(learn_times["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) + traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True) # Replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as @@ -420,7 +416,7 @@ def learner_thread( learner_state = learner_state._replace(timestep=timestep) # Update the networks with RecordTimeTo(learn_times["learning_time"]): - learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batch) + learner_state, train_metrics = learn_fn(learner_state, traj_batch) metrics.append((ep_metrics, train_metrics)) rollout_times_array.append(rollout_time) @@ -515,7 +511,7 @@ def learner_setup( learn, mesh=mesh, in_specs=(learn_state_spec, data_spec), - out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), + out_specs=(learn_state_spec, data_spec), ) ) diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index c8145b1a7c..9e56e17f80 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict import chex from flax.core.frozen_dict import FrozenDict @@ -75,7 +74,6 @@ class PPOTransition(NamedTuple): reward: chex.Array log_prob: chex.Array obs: Observation - info: Dict class RNNPPOTransition(NamedTuple): @@ -88,4 +86,3 @@ class RNNPPOTransition(NamedTuple): log_prob: chex.Array obs: chex.Array hstates: HiddenStates - info: Dict diff --git a/mava/systems/sable/anakin/ff_sable.py b/mava/systems/sable/anakin/ff_sable.py index 24951079e2..33547523c4 100644 --- a/mava/systems/sable/anakin/ff_sable.py +++ b/mava/systems/sable/anakin/ff_sable.py @@ -26,7 +26,6 @@ from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict as Params from jax import tree -from jumanji.env import Environment from jumanji.types import TimeStep from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -34,13 +33,13 @@ from mava.evaluator import ActorState, EvalActFn, get_eval_fn, get_num_eval_envs from mava.networks import SableNetwork from mava.networks.utils.sable import get_init_hidden_state +from mava.systems.ppo.types import PPOTransition as Transition from mava.systems.sable.types import ( ActorApply, LearnerApply, - Transition, ) from mava.systems.sable.types import FFLearnerState as LearnerState -from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -52,7 +51,7 @@ def get_learner_fn( - env: Environment, + env: MarlEnv, apply_fns: Tuple[ActorApply, LearnerApply], update_fn: optax.TransformUpdateFn, config: DictConfig, @@ -82,11 +81,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transition]: + def _env_step( + learner_state: LearnerState, _: int + ) -> Tuple[LearnerState, Tuple[Transition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) # Apply the actor network to get the action, log_prob, value and updated hstates. 
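With `info` dropped from the transition types, episode metrics now travel as the second element of the `(transition, metrics)` pair returned by `_env_step`, and `jax.lax.scan` stacks both pytrees along the rollout axis. A toy sketch of that pattern (the step function and metric names are invented for illustration):

```py
import jax
import jax.numpy as jnp

def env_step(carry, _):
    step_count = carry + 1
    transition = {"reward": jnp.float32(1.0)}                     # stand-in for a transition
    metrics = {"episode_return": step_count.astype(jnp.float32)}  # stand-in for episode metrics
    return step_count, (transition, metrics)

rollout_length = 5
_, (traj_batch, episode_metrics) = jax.lax.scan(
    env_step, jnp.int32(0), None, length=rollout_length
)

# Both pytrees come back stacked along the rollout axis, so metrics no longer
# need to live inside the transition itself.
assert traj_batch["reward"].shape == (rollout_length,)
assert episode_metrics["episode_return"].shape == (rollout_length,)
```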
@@ -97,20 +98,10 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi key=policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - - # SET TRANSITION - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = Transition( done, action, @@ -118,23 +109,19 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi timestep.reward, log_prob, last_timestep.observation, - info, ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition - - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, - learner_state, - jnp.arange(config.system.rollout_length), - config.system.rollout_length, + return learner_state, (transition, timestep.extras["episode_metrics"]) + + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( + _env_step, learner_state, length=config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state key, last_val_key = jax.random.split(key) - _, _, current_val, _ = sable_action_select_fn( # type: ignore + _, _, last_val, _ = sable_action_select_fn( # type: ignore params, observation=last_timestep.observation, key=last_val_key, @@ -170,14 +157,13 @@ def _get_advantages( ) return advantages, advantages + traj_batch.value - advantages, targets = _calculate_gae(traj_batch, current_val) + advantages, targets = _calculate_gae(traj_batch, last_val) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets = batch_info @@ -189,7 +175,7 @@ def _loss_fn( rng_key: chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" - # RERUN NETWORK + # Rerun network value, log_prob, entropy = sable_apply_fn( # type: ignore params, observation=traj_batch.obs, @@ -198,11 +184,12 @@ def _loss_fn( rng_key=rng_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -210,50 +197,41 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - # TOTAL LOSS total_loss = ( - 
loss_actor + actor_loss - config.system.ent_coef * entropy + config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="batch") # pmean over devices. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="device") - # UPDATE PARAMS AND OPTIMISER STATE + # Update params and optimiser state updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1][0] - entropy = loss_info[1][1] - value_loss = loss_info[1][2] + total_loss, (actor_loss, entropy, value_loss) = loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -263,16 +241,9 @@ def _loss_fn( return (new_params, new_opt_state, key), loss_info - ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) = update_state + (params, opt_states, traj_batch, advantages, targets, key) = update_state - # SHUFFLE MINIBATCHES + # Shuffle minibatches key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch @@ -286,39 +257,25 @@ def _loss_fn( agent_perm = jax.random.permutation(agent_shuffle_key, config.system.num_agents) shuffled_batch = tree.map(lambda x: jnp.take(x, agent_perm, axis=1), shuffled_batch) - # SPLIT INTO MINIBATCHES + # Split into minibatches minibatches = tree.map( lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches, ) - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key) return update_state, loss_info - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -331,8 +288,7 @@ def _loss_fn( env_state, last_timestep, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. 
diff --git a/mava/systems/sable/anakin/rec_sable.py b/mava/systems/sable/anakin/rec_sable.py index 34c78fc7c2..b1eda03f00 100644 --- a/mava/systems/sable/anakin/rec_sable.py +++ b/mava/systems/sable/anakin/rec_sable.py @@ -26,7 +26,6 @@ from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict as Params from jax import tree -from jumanji.env import Environment from jumanji.types import TimeStep from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -34,14 +33,14 @@ from mava.evaluator import ActorState, EvalActFn, get_eval_fn, get_num_eval_envs from mava.networks import SableNetwork from mava.networks.utils.sable import get_init_hidden_state +from mava.systems.ppo.types import PPOTransition as Transition from mava.systems.sable.types import ( ActorApply, HiddenStates, LearnerApply, - Transition, ) from mava.systems.sable.types import RecLearnerState as LearnerState -from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -53,7 +52,7 @@ def get_learner_fn( - env: Environment, + env: MarlEnv, apply_fns: Tuple[ActorApply, LearnerApply], update_fn: optax.TransformUpdateFn, config: DictConfig, @@ -62,6 +61,7 @@ def get_learner_fn( # Get apply functions for executing and training the network. sable_action_select_fn, sable_apply_fn = apply_fns + num_envs = config.arch.num_envs def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: """A single update of the network. @@ -84,11 +84,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[Transition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep, hstates = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) # Apply the actor network to get the action, log_prob, value and updated hstates. @@ -100,58 +102,36 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - # Reset hidden state if done. 
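The hunk just below resets Sable's recurrent hidden states wherever an episode has terminated, by broadcasting the expanded `done` flags against every hidden-state leaf. A toy sketch of that reset (the hidden-state shape here is an assumption, not Sable's real layout):

```py
import jax.numpy as jnp
from jax import tree

num_envs = 4
# Per-environment terminal flags, expanded so they broadcast against hidden-state leaves.
done = jnp.array([True, False, False, True])
done = jnp.expand_dims(done, (1, 2, 3, 4))  # (num_envs, 1, 1, 1, 1)

# Toy hidden-state pytree; the real Sable hidden states have a different layout.
hstates = {"hs": jnp.ones((num_envs, 2, 3, 4, 5))}

# Zero the hidden state of every environment whose episode has just ended.
hstates = tree.map(lambda hs: jnp.where(done, jnp.zeros_like(hs), hs), hstates)
assert bool(jnp.all(hstates["hs"][0] == 0)) and bool(jnp.all(hstates["hs"][1] == 1))
```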
done = timestep.last() done = jnp.expand_dims(done, (1, 2, 3, 4)) hstates = tree.map(lambda hs: jnp.where(done, jnp.zeros_like(hs), hs), hstates) - # SET TRANSITION - prev_done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - last_timestep.last(), - ) + prev_done = last_timestep.last().repeat(env.num_agents).reshape(num_envs, -1) transition = Transition( - prev_done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + prev_done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep, hstates) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # COPY OLD HIDDEN STATES: TO BE USED IN THE TRAINING LOOP + # Copy old hidden states: to be used in the training loop prev_hstates = tree.map(lambda x: jnp.copy(x), learner_state.hstates) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, - learner_state, - jnp.arange(config.system.rollout_length), - config.system.rollout_length, + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( + _env_step, learner_state, length=config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep, updated_hstates = learner_state key, last_val_key = jax.random.split(key) - _, _, current_val, _ = sable_action_select_fn( # type: ignore + _, _, last_val, _ = sable_action_select_fn( # type: ignore params, last_timestep.observation, updated_hstates, last_val_key ) - current_done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - last_timestep.last(), - ) + last_done = last_timestep.last().repeat(env.num_agents).reshape(num_envs, -1) def _calculate_gae( traj_batch: Transition, @@ -184,14 +164,13 @@ def _get_advantages( ) return advantages, advantages + traj_batch.value - advantages, targets = _calculate_gae(traj_batch, current_val, current_done) + advantages, targets = _calculate_gae(traj_batch, last_val, last_done) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets, prev_hstates = batch_info @@ -204,7 +183,7 @@ def _loss_fn( rng_key: chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" - # RERUN NETWORK + # Rerun network value, log_prob, entropy = sable_apply_fn( # type: ignore params, traj_batch.obs, @@ -214,11 +193,12 @@ def _loss_fn( rng_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -226,29 +206,26 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, 
config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - # TOTAL LOSS total_loss = ( - loss_actor + actor_loss - config.system.ent_coef * entropy + config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) loss_info, grads = grad_fn( @@ -261,22 +238,16 @@ def _loss_fn( ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="batch") # pmean over devices. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="device") - # UPDATE PARAMS AND OPTIMISER STATE + # Update params and optimiser state updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1][0] - entropy = loss_info[1][1] - value_loss = loss_info[1][2] + total_loss, (actor_loss, entropy, value_loss) = loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -286,17 +257,9 @@ def _loss_fn( return (new_params, new_opt_state, key), loss_info - ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) = update_state + (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) = update_state - # SHUFFLE MINIBATCHES + # Shuffle minibatches key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch @@ -312,10 +275,10 @@ def _loss_fn( agent_perm = jax.random.permutation(agent_shuffle_key, config.system.num_agents) batch = tree.map(lambda x: jnp.take(x, agent_perm, axis=2), batch) - # CONCATENATE TIME AND AGENTS + # Concatenate time and agents batch = tree.map(concat_time_and_agents, batch) - # SPLIT INTO MINIBATCHES + # Split into minibatches minibatches = tree.map( lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), batch, @@ -332,28 +295,12 @@ def _loss_fn( (*minibatches, prev_hs_minibatch), ) - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) return update_state, loss_info - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -367,8 +314,7 @@ def _loss_fn( last_timestep, updated_hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. 
diff --git a/mava/systems/sable/types.py b/mava/systems/sable/types.py index c93d3bf482..7021d1bef6 100644 --- a/mava/systems/sable/types.py +++ b/mava/systems/sable/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable, Tuple from chex import Array, PRNGKey from flax.core.frozen_dict import FrozenDict @@ -58,18 +58,6 @@ class FFLearnerState(NamedTuple): timestep: TimeStep -class Transition(NamedTuple): - """Transition tuple.""" - - done: Array - action: Array - value: Array - reward: Array - log_prob: Array - obs: Array - info: Dict - - ActorApply = Callable[ [FrozenDict, Array, Array, HiddenStates, PRNGKey], Tuple[Array, Array, Array, Array, HiddenStates], diff --git a/mava/types.py b/mava/types.py index 4072629dcd..d60175f502 100644 --- a/mava/types.py +++ b/mava/types.py @@ -152,7 +152,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] -SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]] +SebulbaLearnerFn = Callable[[MavaState, MavaTransition], Tuple[MavaState, Metrics]] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index dc51140f5e..1fe441e2a8 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -27,6 +27,7 @@ # todo: remove the ppo dependencies when we make sebulba for other systems from mava.systems.ppo.types import Params, PPOTransition +from mava.types import Metrics QUEUE_PUT_TIMEOUT = 100 @@ -90,7 +91,9 @@ def run(self) -> None: except queue.Empty: continue - def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None: + def put( + self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]] + ) -> None: """Put a trajectory on the queue to be consumed by the learner.""" start_condition, end_condition = (threading.Condition(), threading.Condition()) with start_condition: @@ -101,6 +104,10 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict traj = _stack_trajectory(traj) traj, timestep = jax.device_put((traj, timestep), device=self.sharding) + time_dict, episode_metrics = metrics + # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} + episode_metrics = _stack_trajectory(episode_metrics) + # We block on the `put` to ensure that actors wait for the learners to catch up. # This ensures two things: # The actors don't get too far ahead of the learners, which could lead to off-policy data. @@ -110,7 +117,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict # We use a try-finally so the lock is released even if an exception is raised. try: self._queue.put( - (traj, timestep, time_dict), + (traj, timestep, time_dict, episode_metrics), block=True, timeout=QUEUE_PUT_TIMEOUT, ) @@ -129,7 +136,7 @@ def qsize(self) -> int: def get( self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, TimeStep, Dict]: + ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]: """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore
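The pipeline now stacks the actor's per-step episode-metric dicts before handing them to the learner, as the inline comment describes. A minimal equivalent of what `_stack_trajectory` is assumed to do for this purpose:

```py
import jax.numpy as jnp
from jax import tree

# Assumed behaviour: a list of per-step metric dicts becomes a single dict of arrays
# stacked along a new leading (time) axis.
def stack_trajectory(trajectory):
    return tree.map(lambda *leaves: jnp.stack(leaves, axis=0), *trajectory)

episode_metrics = [
    {"episode_return": jnp.float32(t), "episode_length": jnp.int32(t * 2)}
    for t in range(3)
]
stacked = stack_trajectory(episode_metrics)
assert stacked["episode_return"].shape == (3,)
assert stacked["episode_length"].shape == (3,)
```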