diff --git a/README.md b/README.md index 4285080d7..dee7a1331 100644 --- a/README.md +++ b/README.md @@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used ## Advanced Usage 👽 -Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information. +Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information. ## Contributing 🤝 diff --git a/examples/Quickstart.ipynb b/examples/Quickstart.ipynb index 7febf6140..baf119cda 100644 --- a/examples/Quickstart.ipynb +++ b/examples/Quickstart.ipynb @@ -413,8 +413,6 @@ " )\n", "\n", " # Compute the parallel mean (pmean) over the batch.\n", - " # This calculation is inspired by the Anakin architecture demo notebook.\n", - " # available at https://tinyurl.com/26tdzs5x\n", " # This pmean could be a regular mean as the batch axis is on the same device.\n", " actor_grads, actor_loss_info = jax.lax.pmean(\n", " (actor_grads, actor_loss_info), axis_name=\"batch\"\n", diff --git a/mava/advanced_usage/README.md b/examples/advanced_usage/README.md similarity index 100% rename from mava/advanced_usage/README.md rename to examples/advanced_usage/README.md diff --git a/mava/advanced_usage/ff_ippo_store_experience.py b/examples/advanced_usage/ff_ippo_store_experience.py similarity index 99% rename from mava/advanced_usage/ff_ippo_store_experience.py rename to examples/advanced_usage/ff_ippo_store_experience.py index 0c33f33f4..58d71c795 100644 --- a/mava/advanced_usage/ff_ippo_store_experience.py +++ b/examples/advanced_usage/ff_ippo_store_experience.py @@ -1,3 +1,4 @@ +# type: ignore # Copyright 2022 InstaDeep Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -225,8 +226,6 @@ def _critic_loss_fn( ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
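For context on the pmean comment above, here is a minimal standalone sketch (not part of the patch) showing that `jax.lax.pmean` over a vmapped "batch" axis reduces to an ordinary mean when that axis lives on a single device; the toy loss and shapes are invented:

# Illustrative only: pmean across a vmapped axis named "batch".
import jax
import jax.numpy as jnp

def grad_with_pmean(x):
    # d/dx of sum(x**2) is 2x; average the gradient across the named batch axis.
    grad = jax.grad(lambda v: jnp.sum(v**2))(x)
    return jax.lax.pmean(grad, axis_name="batch")

xs = jnp.arange(4.0).reshape(4, 1)
averaged = jax.vmap(grad_with_pmean, axis_name="batch")(xs)
# Every row of `averaged` equals jnp.mean(2 * xs, axis=0) == 3.0, i.e. a plain mean.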
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" diff --git a/mava/evaluator.py b/mava/evaluator.py index 6b2fda203..f157b42d0 100644 --- a/mava/evaluator.py +++ b/mava/evaluator.py @@ -293,7 +293,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: # find the first instance of done to get the metrics at that timestep, we don't # care about subsequent steps because we only the results from the first episode done_idx = np.argmax(timesteps.last(), axis=0) - metrics = jax.tree_map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) + metrics = tree.map(lambda m: m[done_idx, np.arange(n_parallel_envs)], metrics) del metrics["is_terminal_step"] # uneeded for logging return key, metrics @@ -307,7 +307,7 @@ def _episode(key: PRNGKey) -> Tuple[PRNGKey, Metrics]: metrics_array.append(metric) # flatten metrics - metrics: Metrics = jax.tree_map(lambda *x: np.array(x).reshape(-1), *metrics_array) + metrics: Metrics = tree.map(lambda *x: np.array(x).reshape(-1), *metrics_array) return metrics def timed_eval_fn(params: FrozenDict, key: PRNGKey, init_act_state: ActorState) -> Metrics: diff --git a/mava/systems/mat/anakin/mat.py b/mava/systems/mat/anakin/mat.py index 2973fe567..1db141d19 100644 --- a/mava/systems/mat/anakin/mat.py +++ b/mava/systems/mat/anakin/mat.py @@ -37,16 +37,13 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, TimeStep, ) from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate @@ -83,51 +80,35 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup _ (Any): The current metrics info. """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_state, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) action, log_prob, value = actor_action_select_fn( # type: ignore params, last_timestep.observation, policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - # Repeat along the agent dimension. This is needed to handle the - # shuffling along the agent dimension during training. 
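The `_env_step` change above returns episode metrics as a second output of the scanned function instead of storing them inside each transition. A toy sketch (names invented) of how `jax.lax.scan` stacks every element of a tuple output, which is what makes this pattern work:

# Illustrative only: scan stacks each element of the per-step output tuple.
import jax

def env_step(carry, _):
    carry = carry + 1
    transition = {"reward": carry * 1.0}
    metrics = {"episode_return": carry * 10.0}
    return carry, (transition, metrics)

final_carry, (traj, episode_metrics) = jax.lax.scan(env_step, 0, None, length=3)
# traj["reward"] == [1., 2., 3.] and episode_metrics["episode_return"] == [10., 20., 30.],
# each stacked along a new leading time axis of length 3.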
- info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - - # SET TRANSITION - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_state, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_state, key, env_state, last_timestep = learner_state key, last_val_key = jax.random.split(key) @@ -171,8 +152,6 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets = batch_info @@ -184,8 +163,7 @@ def _loss_fn( entropy_key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK - + # Rerun network log_prob, value, entropy = actor_apply_fn( # type: ignore params, traj_batch.obs, @@ -193,14 +171,12 @@ def _loss_fn( entropy_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) - # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -208,28 +184,26 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = ( - loss_actor + actor_loss - config.system.ent_coef * entropy + config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( @@ -248,15 +222,11 @@ def _loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, new_opt_state = actor_update_fn(actor_grads, opt_state) new_params = optax.apply_updates(params, actor_updates) - # PACK LOSS INFO - total_loss = actor_loss_info[0] - value_loss = actor_loss_info[1][2] - actor_loss = actor_loss_info[1][0] - entropy 
= actor_loss_info[1][1] + total_loss, (actor_loss, entropy, value_loss) = actor_loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -269,7 +239,7 @@ def _loss_fn( params, opt_state, traj_batch, advantages, targets, key = update_state key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(batch_shuffle_key, batch_size) @@ -286,7 +256,7 @@ def _loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_state, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_state, entropy_key), minibatches ) @@ -296,7 +266,7 @@ def _loss_fn( update_state = params, opt_state, traj_batch, advantages, targets, key - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -304,9 +274,7 @@ def _loss_fn( params, opt_state, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_state, key, env_state, last_timestep) - metric = traj_batch.info - - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -351,7 +319,7 @@ def learner_setup( # PRNG keys. key, actor_net_key = keys - # Initialise observation: Obs for all agents. + # Get mock inputs to initialise network. init_x = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[None, ...], init_x) diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index fd97b280d..ed2943916 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -26,22 +26,17 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps -from mava.utils.jax_utils import ( - merge_leading_dims, - unreplicate_batch_dim, - unreplicate_n_dims, -) +from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate @@ -79,11 +74,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, 
last_timestep.observation) value = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -91,34 +88,22 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra action = actor_policy.sample(seed=policy_key) log_prob = actor_policy.log_prob(action) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -156,27 +141,26 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network actor_policy = actor_apply_fn(actor_params, traj_batch.obs) log_prob = actor_policy.log_prob(traj_batch.action) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -184,25 +168,24 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: PPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network value = critic_apply_fn(critic_params, traj_batch.obs) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -210,32 +193,23 @@ def _critic_loss_fn( value_losses_clipped = 
jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, - opt_states.actor_opt_state, - traj_batch, - advantages, - entropy_key, + params.actor_params, traj_batch, advantages, entropy_key ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, - opt_states.critic_opt_state, - traj_batch, - targets, + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -245,38 +219,35 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) critic_new_params = optax.apply_updates(params.critic_params, critic_updates) - # PACK NEW PARAMS AND OPTIMISER STATE new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -285,7 +256,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle data and create minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) @@ -296,7 +267,7 @@ def _critic_loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), 
loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -306,15 +277,14 @@ def _critic_loss_fn( update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -381,7 +351,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. obs = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index 26b3e17b6..5e5a0006f 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -25,14 +25,13 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition -from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -74,39 +73,36 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) value = critic_apply_fn(params.critic_params, last_timestep.observation) action = actor_policy.sample(seed=policy_key) log_prob = actor_policy.log_prob(action) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation, info + done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # 
Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -144,27 +140,26 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network actor_policy = actor_apply_fn(actor_params, traj_batch.obs) log_prob = actor_policy.log_prob(traj_batch.action) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -172,25 +167,24 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: PPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network value = critic_apply_fn(critic_params, traj_batch.obs) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -198,29 +192,26 @@ def _critic_loss_fn( value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. 
- # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -230,21 +221,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -253,14 +243,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -269,7 +258,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch_size = config.system.rollout_length * config.arch.num_envs permutation = jax.random.permutation(shuffle_key, batch_size) batch = (traj_batch, advantages, targets) @@ -280,7 +269,7 @@ def _critic_loss_fn( shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -290,15 +279,14 @@ def _critic_loss_fn( update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. @@ -365,7 +353,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
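The "Get mock inputs to initialise network" comment above refers to building a dummy observation purely to create parameters. A hedged, self-contained sketch with a stand-in Flax module (not Mava's actual networks; the observation size is made up):

# Illustrative only: Flax params are created from a dummy input of the right shape.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyCritic(nn.Module):
    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = nn.relu(nn.Dense(64)(obs))
        return nn.Dense(1)(x)

dummy_obs = jnp.zeros((1, 8))  # (batch, obs_dim); the values themselves are never used
params = TinyCritic().init(jax.random.PRNGKey(0), dummy_obs)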
obs = env.observation_spec().generate_value() init_x = tree.map(lambda x: x[jnp.newaxis, ...], obs) diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 212246080..279238382 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -26,7 +26,6 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, get_num_eval_envs, make_rec_eval_act_fn @@ -44,6 +43,7 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, RecActorApply, RecCriticApply, ) @@ -91,7 +91,7 @@ def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerStat def _env_step( learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + ) -> Tuple[RNNLearnerState, Tuple[RNNPPOTransition, Metrics]]: """Step the environment.""" ( params, @@ -107,10 +107,7 @@ def _env_step( # Add a batch dimension to the observation. batched_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) - ac_in = ( - batched_observation, - last_done[jnp.newaxis, :], - ) + ac_in = (batched_observation, last_done[jnp.newaxis, :]) # Run the network. policy_hidden_state, actor_policy = actor_apply_fn( @@ -128,13 +125,7 @@ def _env_step( # Step the environment. env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # log episode return and length - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) transition = RNNPPOTransition( last_done, @@ -144,35 +135,23 @@ def _env_step( log_prob, last_timestep.observation, last_hstates, - info, ) learner_state = RNNLearnerState( params, opt_states, key, env_state, timestep, done, hstates ) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - hstates, - ) = learner_state + # Calculate advantage + params, opt_states, key, env_state, last_timestep, last_done, hstates = learner_state # Add a batch dimension to the observation. batched_last_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) - ac_in = ( - batched_last_observation, - last_done[jnp.newaxis, :], - ) + ac_in = (batched_last_observation, last_done[jnp.newaxis, :]) # Run the network. 
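For reference, a standalone sketch of the generalised advantage estimation that typically consumes `last_val` at this point; the reverse `jax.lax.scan` mirrors the shape of `_calculate_gae`, but the hyperparameters and array shapes here are illustrative rather than Mava's exact implementation:

# Illustrative only: GAE accumulated backwards in time with a reverse scan.
import jax
import jax.numpy as jnp

def calculate_gae(rewards, values, dones, last_val, gamma=0.99, gae_lambda=0.95):
    def _get_advantages(carry, xs):
        gae, next_value = carry
        reward, value, done = xs
        delta = reward + gamma * next_value * (1.0 - done) - value
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val),
        (rewards, values, dones),
        reverse=True,
    )
    return advantages, advantages + values  # advantages and value targets

adv, targets = calculate_gae(
    rewards=jnp.ones(5), values=jnp.zeros(5), dones=jnp.zeros(5), last_val=jnp.array(0.0)
)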
_, last_val = critic_apply_fn(params.critic_params, hstates.critic_hidden_state, ac_in) @@ -208,30 +187,29 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: RNNPPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK - + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, actor_policy = actor_apply_fn( actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done ) log_prob = actor_policy.log_prob(traj_batch.action) + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -239,28 +217,27 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) + total_loss = actor_loss - config.system.ent_coef * entropy + return total_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: RNNPPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, value = critic_apply_fn( critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done ) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -269,28 +246,25 @@ def _critic_loss_fn( value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) + return total_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
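A compact sketch of the clipped PPO surrogate computed by `_actor_loss_fn` above, with made-up inputs; it mirrors the ratio/clip/minimum structure of the patched code but is not the code itself:

# Illustrative only: clipped PPO actor objective on toy values.
import jax.numpy as jnp

def clipped_actor_loss(log_prob, old_log_prob, gae, clip_eps=0.2):
    ratio = jnp.exp(log_prob - old_log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)  # normalise advantage at minibatch level
    unclipped = ratio * gae
    clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    return -jnp.minimum(unclipped, clipped).mean()

loss = clipped_actor_loss(
    log_prob=jnp.array([-1.0, -0.5]),
    old_log_prob=jnp.array([-1.2, -0.4]),
    gae=jnp.array([1.0, -1.0]),
)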
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -300,21 +274,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -323,14 +296,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -340,7 +312,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch = (traj_batch, advantages, targets) num_recurrent_chunks = ( config.system.rollout_length // config.system.recurrent_chunk_size @@ -365,7 +337,7 @@ def _critic_loss_fn( ) minibatches = tree.map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -389,7 +361,7 @@ def _critic_loss_fn( key, ) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -404,8 +376,7 @@ def _critic_loss_fn( last_done, hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: """Learner function. @@ -486,7 +457,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
init_obs = env.observation_spec().generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index b995ac167..96d7d74ac 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -26,7 +26,6 @@ from flax.core.frozen_dict import FrozenDict from jax import tree from omegaconf import DictConfig, OmegaConf -from optax._src.base import OptState from rich.pretty import pprint from mava.evaluator import get_eval_fn, get_num_eval_envs, make_rec_eval_act_fn @@ -44,6 +43,7 @@ ExperimentOutput, LearnerFn, MarlEnv, + Metrics, RecActorApply, RecCriticApply, ) @@ -91,7 +91,7 @@ def _update_step(learner_state: RNNLearnerState, _: Any) -> Tuple[RNNLearnerStat def _env_step( learner_state: RNNLearnerState, _: Any - ) -> Tuple[RNNLearnerState, RNNPPOTransition]: + ) -> Tuple[RNNLearnerState, Tuple[RNNPPOTransition, Metrics]]: """Step the environment.""" ( params, @@ -126,13 +126,7 @@ def _env_step( # Step the environment. env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # log episode return and length - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) - info = timestep.extras["episode_metrics"] - + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) hstates = HiddenStates(policy_hidden_state, critic_hidden_state) transition = RNNPPOTransition( last_done, @@ -142,28 +136,19 @@ def _env_step( log_prob, last_timestep.observation, last_hstates, - info, ) learner_state = RNNLearnerState( params, opt_states, key, env_state, timestep, done, hstates ) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( _env_step, learner_state, None, config.system.rollout_length ) - # CALCULATE ADVANTAGE - ( - params, - opt_states, - key, - env_state, - last_timestep, - last_done, - hstates, - ) = learner_state + # Calculate advantage + params, opt_states, key, env_state, last_timestep, last_done, hstates = learner_state # Add a batch dimension to the observation. 
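A quick check (illustrative sizes only) that the new one-liner for `done` above matches the previous `tree.map`/`jnp.repeat` construction: a per-env terminal flag of shape (num_envs,) is repeated per agent and reshaped to (num_envs, num_agents).

# Illustrative only: equivalence of the two `done` constructions.
import jax.numpy as jnp

num_envs, num_agents = 4, 3
last = jnp.array([True, False, False, True])                   # timestep.last(), shape (num_envs,)
done = last.repeat(num_agents).reshape(num_envs, -1)           # new one-liner, shape (4, 3)
old_done = jnp.repeat(last, num_agents).reshape(num_envs, -1)  # previous formulation
assert bool((done == old_done).all())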
batched_last_observation = tree.map(lambda x: x[jnp.newaxis, :], last_timestep.observation) @@ -204,29 +189,29 @@ def _update_epoch(update_state: Tuple, _: Any) -> Tuple: def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( actor_params: FrozenDict, - actor_opt_state: OptState, traj_batch: RNNPPOTransition, gae: chex.Array, key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, actor_policy = actor_apply_fn( actor_params, traj_batch.hstates.policy_hidden_state[0], obs_and_done ) log_prob = actor_policy.log_prob(traj_batch.action) + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -234,28 +219,27 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss = loss_actor - config.system.ent_coef * entropy - return total_loss, (loss_actor, entropy) + total_loss = actor_loss - config.system.ent_coef * entropy + return total_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, - critic_opt_state: OptState, traj_batch: RNNPPOTransition, targets: chex.Array, ) -> Tuple: """Calculate the critic loss.""" - # RERUN NETWORK + # Rerun network obs_and_done = (traj_batch.obs, traj_batch.done) _, value = critic_apply_fn( critic_params, traj_batch.hstates.critic_hidden_state[0], obs_and_done ) - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) @@ -264,28 +248,25 @@ def _critic_loss_fn( value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() total_loss = config.system.vf_coef * value_loss - return total_loss, (value_loss) + return total_loss, value_loss - # CALCULATE ACTOR LOSS + # Calculate actor loss key, entropy_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( params.actor_params, - opt_states.actor_opt_state, traj_batch, advantages, entropy_key, ) - # CALCULATE CRITIC LOSS + # Calculate critic loss critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( - params.critic_params, opt_states.critic_opt_state, traj_batch, targets + value_loss_info, critic_grads = critic_grad_fn( + params.critic_params, traj_batch, targets ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. 
actor_grads, actor_loss_info = jax.lax.pmean( (actor_grads, actor_loss_info), axis_name="batch" @@ -295,21 +276,20 @@ def _critic_loss_fn( (actor_grads, actor_loss_info), axis_name="device" ) - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="batch" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="batch" ) # pmean over devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="device" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="device" ) - # UPDATE ACTOR PARAMS AND OPTIMISER STATE + # Update params and optimiser state actor_updates, actor_new_opt_state = actor_update_fn( actor_grads, opt_states.actor_opt_state ) actor_new_params = optax.apply_updates(params.actor_params, actor_updates) - # UPDATE CRITIC PARAMS AND OPTIMISER STATE critic_updates, critic_new_opt_state = critic_update_fn( critic_grads, opt_states.critic_opt_state ) @@ -318,14 +298,13 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) - # PACK LOSS INFO - total_loss = actor_loss_info[0] + critic_loss_info[0] - value_loss = critic_loss_info[1] - actor_loss = actor_loss_info[1][0] - entropy = actor_loss_info[1][1] + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -335,7 +314,7 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key, entropy_key = jax.random.split(key, 3) - # SHUFFLE MINIBATCHES + # Shuffle minibatches batch = (traj_batch, advantages, targets) num_recurrent_chunks = ( config.system.rollout_length // config.system.recurrent_chunk_size @@ -360,7 +339,7 @@ def _critic_loss_fn( ) minibatches = tree.map(lambda x: jnp.swapaxes(x, 1, 0), reshaped_batch) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches ) @@ -384,7 +363,7 @@ def _critic_loss_fn( key, ) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -399,8 +378,7 @@ def _critic_loss_fn( last_done, hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: RNNLearnerState) -> ExperimentOutput[RNNLearnerState]: """Learner function. @@ -482,7 +460,7 @@ def learner_setup( optax.adam(critic_lr, eps=1e-5), ) - # Initialise observation with obs of all agents. + # Get mock inputs to initialise network. 
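A standalone sketch of the permutation-based minibatching used in `_update_epoch` above; the batch size, field names, and minibatch count here are invented:

# Illustrative only: shuffle the flattened batch, then split it into minibatches.
import jax
import jax.numpy as jnp
from jax import tree

batch_size, num_minibatches = 8, 2
batch = {"obs": jnp.arange(8.0).reshape(8, 1), "advantages": jnp.arange(8.0)}

permutation = jax.random.permutation(jax.random.PRNGKey(0), batch_size)
shuffled_batch = tree.map(lambda x: jnp.take(x, permutation, axis=0), batch)
minibatches = tree.map(
    lambda x: x.reshape(num_minibatches, -1, *x.shape[1:]), shuffled_batch
)
# minibatches["obs"].shape == (2, 4, 1): scanning a minibatch-update function over
# the leading axis then processes one shuffled minibatch per step.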
init_obs = env.observation_spec().generate_value() init_obs = tree.map( lambda x: jnp.repeat(x[jnp.newaxis, ...], config.arch.num_envs, axis=0), diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 11bfd4b26..996821cd0 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -45,7 +45,7 @@ from mava.types import ( ActorApply, CriticApply, - ExperimentOutput, + Metrics, Observation, SebulbaLearnerFn, ) @@ -113,6 +113,7 @@ def act_fn( while not thread_lifetime.should_stop(): # Rollout traj: List[PPOTransition] = [] + episode_metrics: List[Dict] = [] actor_timings: Dict[str, List[float]] = defaultdict(list) with RecordTimeTo(actor_timings["rollout_time"]): for _ in range(config.system.rollout_length): @@ -142,14 +143,14 @@ def act_fn( timestep.reward, log_prob, obs_tpu, - timestep.extras["episode_metrics"], ) ) + episode_metrics.append(timestep.extras["episode_metrics"]) # send trajectories to learner with RecordTimeTo(actor_timings["rollout_put_time"]): try: - rollout_queue.put(traj, timestep, actor_timings) + rollout_queue.put(traj, timestep, (actor_timings, episode_metrics)) except queue.Full: err = "Waited too long to add to the rollout queue, killing the actor thread" warnings.warn(err, stacklevel=2) @@ -175,7 +176,7 @@ def get_learner_step_fn( def _update_step( learner_state: LearnerState, traj_batch: PPOTransition, - ) -> Tuple[LearnerState, Tuple]: + ) -> Tuple[LearnerState, Metrics]: """A single update of the network. This function calculates advantages and targets based on the trajectories @@ -216,7 +217,7 @@ def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tup last_val = critic_apply_fn(params.critic_params, final_timestep.observation) advantages, targets = _calculate_gae(traj_batch, last_val) - def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + def _update_epoch(update_state: Tuple, _: Any) -> Tuple[Tuple, Metrics]: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: @@ -240,8 +241,8 @@ def _actor_loss_fn( # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -249,13 +250,13 @@ def _actor_loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() # The seed will be used in the TanhTransformedDistribution: entropy = actor_policy.entropy(seed=key).mean() - total_loss_actor = loss_actor - config.system.ent_coef * entropy - return total_loss_actor, (loss_actor, entropy) + total_actor_loss = actor_loss - config.system.ent_coef * entropy + return total_actor_loss, (actor_loss, entropy) def _critic_loss_fn( critic_params: FrozenDict, traj_batch: PPOTransition, targets: chex.Array @@ -272,8 +273,8 @@ def _critic_loss_fn( value_losses_clipped = jnp.square(value_pred_clipped - targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - critic_total_loss = config.system.vf_coef * value_loss - return critic_total_loss, (value_loss) + total_value_loss = config.system.vf_coef * value_loss + return total_value_loss, value_loss # Calculate actor loss key, entropy_key = jax.random.split(key) @@ -284,7 +285,7 @@ def _critic_loss_fn( # Calculate critic loss 
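As a reminder of the `has_aux=True` convention relied on when unpacking the loss info below: the loss function returns `(loss, aux)`, and `jax.value_and_grad` then returns `((loss, aux), grads)`. A minimal sketch with a toy critic loss (not Mava's):

# Illustrative only: value_and_grad with an auxiliary (unscaled) loss output.
import jax
import jax.numpy as jnp

def critic_loss_fn(value, target, vf_coef=0.5):
    unscaled_value_loss = jnp.mean((value - target) ** 2)
    return vf_coef * unscaled_value_loss, unscaled_value_loss  # (loss to differentiate, aux)

grad_fn = jax.value_and_grad(critic_loss_fn, has_aux=True)
(scaled_loss, unscaled_loss), grads = grad_fn(jnp.array(1.0), jnp.array(0.0))
# scaled_loss == 0.5, unscaled_loss == 1.0, grads == 1.0 (gradient w.r.t. `value` only).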
critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) - critic_loss_info, critic_grads = critic_grad_fn( + value_loss_info, critic_grads = critic_grad_fn( params.critic_params, traj_batch, targets ) @@ -298,8 +299,8 @@ def _critic_loss_fn( ) # pmean over learner devices. - critic_grads, critic_loss_info = jax.lax.pmean( - (critic_grads, critic_loss_info), axis_name="learner_devices" + critic_grads, value_loss_info = jax.lax.pmean( + (critic_grads, value_loss_info), axis_name="learner_devices" ) # Update actor params and optimiser state @@ -318,12 +319,12 @@ def _critic_loss_fn( new_params = Params(actor_new_params, critic_new_params) new_opt_state = OptStates(actor_new_opt_state, critic_new_opt_state) # Pack loss info - actor_total_loss, (actor_loss, entropy) = actor_loss_info - critic_total_loss, (value_loss) = critic_loss_info - total_loss = critic_total_loss + actor_total_loss + actor_loss, (_, entropy) = actor_loss_info + value_loss, unscaled_value_loss = value_loss_info + total_loss = actor_loss + value_loss loss_info = { "total_loss": total_loss, - "value_loss": value_loss, + "value_loss": unscaled_value_loss, "actor_loss": actor_loss, "entropy": entropy, } @@ -359,12 +360,11 @@ def _critic_loss_fn( params, opt_states, traj_batch, advantages, targets, key = update_state learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, loss_info def learner_fn( learner_state: LearnerState, traj_batch: PPOTransition - ) -> ExperimentOutput[LearnerState]: + ) -> Tuple[LearnerState, Metrics]: """Learner function. This function represents the learner, it updates the network parameters @@ -382,13 +382,9 @@ def learner_fn( # This function is shard mapped on the batch axis, but `_update_step` needs # the first axis to be time traj_batch = tree.map(switch_leading_axes, traj_batch) - learner_state, (episode_info, loss_info) = _update_step(learner_state, traj_batch) + learner_state, loss_info = _update_step(learner_state, traj_batch) - return ExperimentOutput( - learner_state=learner_state, - episode_metrics=episode_info, - train_metrics=loss_info, - ) + return learner_state, loss_info return learner_fn @@ -412,7 +408,7 @@ def learner_thread( # Get the trajectory batch from the pipeline # This is blocking so it will wait until the pipeline has data. 
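The pipeline hand-off above uses Mava's own `Pipeline` class; purely as a rough illustration of the blocking producer/consumer pattern, here is a sketch with the standard-library `queue.Queue` as a stand-in and invented payloads:

# Illustrative only: actor thread puts a rollout tuple, learner blocks on get.
import queue
import threading

pipeline = queue.Queue(maxsize=4)

def actor_thread() -> None:
    traj = ["transition_0", "transition_1"]        # stand-in trajectory
    timestep = "final_timestep"                    # stand-in timestep
    actor_timings = {"rollout_time": [0.1]}
    episode_metrics = [{"episode_return": 1.0}]
    pipeline.put((traj, timestep, actor_timings, episode_metrics))

threading.Thread(target=actor_thread).start()
# Blocks until the actor has produced a rollout, mirroring `pipeline.get(block=True)`.
traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True)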
with RecordTimeTo(learn_times["rollout_get_time"]): - traj_batch, timestep, rollout_time = pipeline.get(block=True) + traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True) # Replace the timestep in the learner state with the latest timestep # This means the learner has access to the entire trajectory as well as @@ -420,7 +416,7 @@ def learner_thread( learner_state = learner_state._replace(timestep=timestep) # Update the networks with RecordTimeTo(learn_times["learning_time"]): - learner_state, ep_metrics, train_metrics = learn_fn(learner_state, traj_batch) + learner_state, train_metrics = learn_fn(learner_state, traj_batch) metrics.append((ep_metrics, train_metrics)) rollout_times_array.append(rollout_time) @@ -515,7 +511,7 @@ def learner_setup( learn, mesh=mesh, in_specs=(learn_state_spec, data_spec), - out_specs=ExperimentOutput(learn_state_spec, data_spec, data_spec), + out_specs=(learn_state_spec, data_spec), ) ) diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index c8145b1a7..9e56e17f8 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict import chex from flax.core.frozen_dict import FrozenDict @@ -75,7 +74,6 @@ class PPOTransition(NamedTuple): reward: chex.Array log_prob: chex.Array obs: Observation - info: Dict class RNNPPOTransition(NamedTuple): @@ -88,4 +86,3 @@ class RNNPPOTransition(NamedTuple): log_prob: chex.Array obs: chex.Array hstates: HiddenStates - info: Dict diff --git a/mava/systems/sable/anakin/ff_sable.py b/mava/systems/sable/anakin/ff_sable.py index 24951079e..33547523c 100644 --- a/mava/systems/sable/anakin/ff_sable.py +++ b/mava/systems/sable/anakin/ff_sable.py @@ -26,7 +26,6 @@ from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict as Params from jax import tree -from jumanji.env import Environment from jumanji.types import TimeStep from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -34,13 +33,13 @@ from mava.evaluator import ActorState, EvalActFn, get_eval_fn, get_num_eval_envs from mava.networks import SableNetwork from mava.networks.utils.sable import get_init_hidden_state +from mava.systems.ppo.types import PPOTransition as Transition from mava.systems.sable.types import ( ActorApply, LearnerApply, - Transition, ) from mava.systems.sable.types import FFLearnerState as LearnerState -from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -52,7 +51,7 @@ def get_learner_fn( - env: Environment, + env: MarlEnv, apply_fns: Tuple[ActorApply, LearnerApply], update_fn: optax.TransformUpdateFn, config: DictConfig, @@ -82,11 +81,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transition]: + def _env_step( + learner_state: LearnerState, _: int + ) -> Tuple[LearnerState, Tuple[Transition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) # Apply the actor network to get the action, log_prob, value and updated hstates. 
@@ -97,20 +98,10 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi key=policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - - # SET TRANSITION - done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - timestep.last(), - ) + done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = Transition( done, action, @@ -118,23 +109,19 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi timestep.reward, log_prob, last_timestep.observation, - info, ) learner_state = LearnerState(params, opt_states, key, env_state, timestep) - return learner_state, transition - - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, - learner_state, - jnp.arange(config.system.rollout_length), - config.system.rollout_length, + return learner_state, (transition, timestep.extras["episode_metrics"]) + + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( + _env_step, learner_state, length=config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep = learner_state key, last_val_key = jax.random.split(key) - _, _, current_val, _ = sable_action_select_fn( # type: ignore + _, _, last_val, _ = sable_action_select_fn( # type: ignore params, observation=last_timestep.observation, key=last_val_key, @@ -170,14 +157,13 @@ def _get_advantages( ) return advantages, advantages + traj_batch.value - advantages, targets = _calculate_gae(traj_batch, current_val) + advantages, targets = _calculate_gae(traj_batch, last_val) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets = batch_info @@ -189,7 +175,7 @@ def _loss_fn( rng_key: chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" - # RERUN NETWORK + # Rerun network value, log_prob, entropy = sable_apply_fn( # type: ignore params, observation=traj_batch.obs, @@ -198,11 +184,12 @@ def _loss_fn( rng_key=rng_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -210,50 +197,41 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - # TOTAL LOSS total_loss = ( - 
loss_actor + actor_loss - config.system.ent_coef * entropy + config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) loss_info, grads = grad_fn(params, traj_batch, advantages, targets, entropy_key) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="batch") # pmean over devices. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="device") - # UPDATE PARAMS AND OPTIMISER STATE + # Update params and optimiser state updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1][0] - entropy = loss_info[1][1] - value_loss = loss_info[1][2] + total_loss, (actor_loss, entropy, value_loss) = loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -263,16 +241,9 @@ def _loss_fn( return (new_params, new_opt_state, key), loss_info - ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) = update_state + (params, opt_states, traj_batch, advantages, targets, key) = update_state - # SHUFFLE MINIBATCHES + # Shuffle minibatches key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch @@ -286,39 +257,25 @@ def _loss_fn( agent_perm = jax.random.permutation(agent_shuffle_key, config.system.num_agents) shuffled_batch = tree.map(lambda x: jnp.take(x, agent_perm, axis=1), shuffled_batch) - # SPLIT INTO MINIBATCHES + # Split into minibatches minibatches = tree.map( lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), shuffled_batch, ) - # UPDATE MINIBATCHES + # Update minibatches (params, opt_states, entropy_key), loss_info = jax.lax.scan( _update_minibatch, (params, opt_states, entropy_key), minibatches, ) - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key) return update_state, loss_info - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -331,8 +288,7 @@ def _loss_fn( env_state, last_timestep, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. 
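
The ff_sable hunks above stop storing episode metrics inside each `Transition` and instead return them as the second element of the `_env_step` output, so `jax.lax.scan` stacks them along the time axis for logging. A minimal sketch of that pattern, with a dummy step function and made-up shapes rather than the real Mava types:

    # Toy sketch of the scan pattern used above; the step function and shapes
    # are illustrative stand-ins, not taken from Mava.
    import jax
    import jax.numpy as jnp

    num_envs, rollout_length = 4, 8

    def env_step(carry, _):
        step = carry
        transition = {"reward": jnp.ones(num_envs)}                # stand-in for Transition
        metrics = {"episode_return": jnp.zeros(num_envs) + step}   # per-env episode metrics
        return step + 1, (transition, metrics)

    _, (traj_batch, episode_metrics) = jax.lax.scan(
        env_step, jnp.array(0), None, length=rollout_length
    )
    # scan stacks every leaf of the per-step output along a new leading time axis,
    # so the metrics no longer need to be repeated per agent inside the transition.
    assert traj_batch["reward"].shape == (rollout_length, num_envs)
    assert episode_metrics["episode_return"].shape == (rollout_length, num_envs)
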
diff --git a/mava/systems/sable/anakin/rec_sable.py b/mava/systems/sable/anakin/rec_sable.py index 34c78fc7c..b1eda03f0 100644 --- a/mava/systems/sable/anakin/rec_sable.py +++ b/mava/systems/sable/anakin/rec_sable.py @@ -26,7 +26,6 @@ from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict as Params from jax import tree -from jumanji.env import Environment from jumanji.types import TimeStep from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -34,14 +33,14 @@ from mava.evaluator import ActorState, EvalActFn, get_eval_fn, get_num_eval_envs from mava.networks import SableNetwork from mava.networks.utils.sable import get_init_hidden_state +from mava.systems.ppo.types import PPOTransition as Transition from mava.systems.sable.types import ( ActorApply, HiddenStates, LearnerApply, - Transition, ) from mava.systems.sable.types import RecLearnerState as LearnerState -from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv +from mava.types import Action, ExperimentOutput, LearnerFn, MarlEnv, Metrics from mava.utils import make_env as environments from mava.utils.checkpointing import Checkpointer from mava.utils.config import check_total_timesteps @@ -53,7 +52,7 @@ def get_learner_fn( - env: Environment, + env: MarlEnv, apply_fns: Tuple[ActorApply, LearnerApply], update_fn: optax.TransformUpdateFn, config: DictConfig, @@ -62,6 +61,7 @@ def get_learner_fn( # Get apply functions for executing and training the network. sable_action_select_fn, sable_apply_fn = apply_fns + num_envs = config.arch.num_envs def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: """A single update of the network. @@ -84,11 +84,13 @@ def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tup """ - def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transition]: + def _env_step( + learner_state: LearnerState, _: Any + ) -> Tuple[LearnerState, Tuple[Transition, Metrics]]: """Step the environment.""" params, opt_states, key, env_state, last_timestep, hstates = learner_state - # SELECT ACTION + # Select action key, policy_key = jax.random.split(key) # Apply the actor network to get the action, log_prob, value and updated hstates. @@ -100,58 +102,36 @@ def _env_step(learner_state: LearnerState, _: int) -> Tuple[LearnerState, Transi policy_key, ) - # STEP ENVIRONMENT + # Step environment env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) - # LOG EPISODE METRICS - info = tree.map( - lambda x: jnp.repeat(x[..., jnp.newaxis], config.system.num_agents, axis=-1), - timestep.extras["episode_metrics"], - ) - # Reset hidden state if done. 
done = timestep.last() done = jnp.expand_dims(done, (1, 2, 3, 4)) hstates = tree.map(lambda hs: jnp.where(done, jnp.zeros_like(hs), hs), hstates) - # SET TRANSITION - prev_done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - last_timestep.last(), - ) + prev_done = last_timestep.last().repeat(env.num_agents).reshape(num_envs, -1) transition = Transition( - prev_done, - action, - value, - timestep.reward, - log_prob, - last_timestep.observation, - info, + prev_done, action, value, timestep.reward, log_prob, last_timestep.observation ) learner_state = LearnerState(params, opt_states, key, env_state, timestep, hstates) - return learner_state, transition + return learner_state, (transition, timestep.extras["episode_metrics"]) - # COPY OLD HIDDEN STATES: TO BE USED IN THE TRAINING LOOP + # Copy old hidden states: to be used in the training loop prev_hstates = tree.map(lambda x: jnp.copy(x), learner_state.hstates) - # STEP ENVIRONMENT FOR ROLLOUT LENGTH - learner_state, traj_batch = jax.lax.scan( - _env_step, - learner_state, - jnp.arange(config.system.rollout_length), - config.system.rollout_length, + # Step environment for rollout length + learner_state, (traj_batch, episode_metrics) = jax.lax.scan( + _env_step, learner_state, length=config.system.rollout_length ) - # CALCULATE ADVANTAGE + # Calculate advantage params, opt_states, key, env_state, last_timestep, updated_hstates = learner_state key, last_val_key = jax.random.split(key) - _, _, current_val, _ = sable_action_select_fn( # type: ignore + _, _, last_val, _ = sable_action_select_fn( # type: ignore params, last_timestep.observation, updated_hstates, last_val_key ) - current_done = tree.map( - lambda x: jnp.repeat(x, config.system.num_agents).reshape(config.arch.num_envs, -1), - last_timestep.last(), - ) + last_done = last_timestep.last().repeat(env.num_agents).reshape(num_envs, -1) def _calculate_gae( traj_batch: Transition, @@ -184,14 +164,13 @@ def _get_advantages( ) return advantages, advantages + traj_batch.value - advantages, targets = _calculate_gae(traj_batch, current_val, current_done) + advantages, targets = _calculate_gae(traj_batch, last_val, last_done) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" - # UNPACK TRAIN STATE AND BATCH INFO params, opt_state, key = train_state traj_batch, advantages, targets, prev_hstates = batch_info @@ -204,7 +183,7 @@ def _loss_fn( rng_key: chex.PRNGKey, ) -> Tuple: """Calculate Sable loss.""" - # RERUN NETWORK + # Rerun network value, log_prob, entropy = sable_apply_fn( # type: ignore params, traj_batch.obs, @@ -214,11 +193,12 @@ def _loss_fn( rng_key, ) - # CALCULATE ACTOR LOSS + # Calculate actor loss ratio = jnp.exp(log_prob - traj_batch.log_prob) + # Nomalise advantage at minibatch level gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( + actor_loss1 = ratio * gae + actor_loss2 = ( jnp.clip( ratio, 1.0 - config.system.clip_eps, @@ -226,29 +206,26 @@ def _loss_fn( ) * gae ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() + actor_loss = -jnp.minimum(actor_loss1, actor_loss2) + actor_loss = actor_loss.mean() entropy = entropy.mean() - # CALCULATE VALUE LOSS + # Clipped MSE loss value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( -config.system.clip_eps, 
config.system.clip_eps ) - - # MSE LOSS value_losses = jnp.square(value - value_targets) value_losses_clipped = jnp.square(value_pred_clipped - value_targets) value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() - # TOTAL LOSS total_loss = ( - loss_actor + actor_loss - config.system.ent_coef * entropy + config.system.vf_coef * value_loss ) - return total_loss, (loss_actor, entropy, value_loss) + return total_loss, (actor_loss, entropy, value_loss) - # CALCULATE ACTOR LOSS + # Calculate loss key, entropy_key = jax.random.split(key) grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) loss_info, grads = grad_fn( @@ -261,22 +238,16 @@ def _loss_fn( ) # Compute the parallel mean (pmean) over the batch. - # This calculation is inspired by the Anakin architecture demo notebook. - # available at https://tinyurl.com/26tdzs5x # This pmean could be a regular mean as the batch axis is on the same device. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="batch") # pmean over devices. grads, loss_info = jax.lax.pmean((grads, loss_info), axis_name="device") - # UPDATE PARAMS AND OPTIMISER STATE + # Update params and optimiser state updates, new_opt_state = update_fn(grads, opt_state) new_params = optax.apply_updates(params, updates) - # PACK LOSS INFO - total_loss = loss_info[0] - actor_loss = loss_info[1][0] - entropy = loss_info[1][1] - value_loss = loss_info[1][2] + total_loss, (actor_loss, entropy, value_loss) = loss_info loss_info = { "total_loss": total_loss, "value_loss": value_loss, @@ -286,17 +257,9 @@ def _loss_fn( return (new_params, new_opt_state, key), loss_info - ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) = update_state + (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) = update_state - # SHUFFLE MINIBATCHES + # Shuffle minibatches key, batch_shuffle_key, agent_shuffle_key, entropy_key = jax.random.split(key, 4) # Shuffle batch @@ -312,10 +275,10 @@ def _loss_fn( agent_perm = jax.random.permutation(agent_shuffle_key, config.system.num_agents) batch = tree.map(lambda x: jnp.take(x, agent_perm, axis=2), batch) - # CONCATENATE TIME AND AGENTS + # Concatenate time and agents batch = tree.map(concat_time_and_agents, batch) - # SPLIT INTO MINIBATCHES + # Split into minibatches minibatches = tree.map( lambda x: jnp.reshape(x, (config.system.num_minibatches, -1, *x.shape[1:])), batch, @@ -332,28 +295,12 @@ def _loss_fn( (*minibatches, prev_hs_minibatch), ) - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) return update_state, loss_info - update_state = ( - params, - opt_states, - traj_batch, - advantages, - targets, - key, - prev_hstates, - ) + update_state = (params, opt_states, traj_batch, advantages, targets, key, prev_hstates) - # UPDATE EPOCHS + # Update epochs update_state, loss_info = jax.lax.scan( _update_epoch, update_state, None, config.system.ppo_epochs ) @@ -367,8 +314,7 @@ def _loss_fn( last_timestep, updated_hstates, ) - metric = traj_batch.info - return learner_state, (metric, loss_info) + return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: """Learner function. 
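
Both Sable systems now build per-agent done flags with a plain `Array.repeat` instead of a `tree.map` over `jnp.repeat`. A small sketch with toy shapes (not Mava code) showing the two forms agree:

    import jax.numpy as jnp

    num_envs, num_agents = 3, 2
    done = jnp.array([True, False, True])                            # (num_envs,), e.g. timestep.last()
    per_agent_done = done.repeat(num_agents).reshape(num_envs, -1)   # (num_envs, num_agents)

    # Matches the removed formulation that used jnp.repeat inside tree.map.
    old_style = jnp.repeat(done, num_agents).reshape(num_envs, -1)
    assert (per_agent_done == old_style).all()
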
diff --git a/mava/systems/sable/types.py b/mava/systems/sable/types.py index c93d3bf48..7021d1bef 100644 --- a/mava/systems/sable/types.py +++ b/mava/systems/sable/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable, Tuple from chex import Array, PRNGKey from flax.core.frozen_dict import FrozenDict @@ -58,18 +58,6 @@ class FFLearnerState(NamedTuple): timestep: TimeStep -class Transition(NamedTuple): - """Transition tuple.""" - - done: Array - action: Array - value: Array - reward: Array - log_prob: Array - obs: Array - info: Dict - - ActorApply = Callable[ [FrozenDict, Array, Array, HiddenStates, PRNGKey], Tuple[Array, Array, Array, Array, HiddenStates], diff --git a/mava/types.py b/mava/types.py index 4072629dc..d60175f50 100644 --- a/mava/types.py +++ b/mava/types.py @@ -152,7 +152,7 @@ class ExperimentOutput(NamedTuple, Generic[MavaState]): LearnerFn = Callable[[MavaState], ExperimentOutput[MavaState]] -SebulbaLearnerFn = Callable[[MavaState, MavaTransition], ExperimentOutput[MavaState]] +SebulbaLearnerFn = Callable[[MavaState, MavaTransition], Tuple[MavaState, Metrics]] ActorApply = Callable[[FrozenDict, Observation], Distribution] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ diff --git a/mava/utils/sebulba.py b/mava/utils/sebulba.py index dc51140f5..1fe441e2a 100644 --- a/mava/utils/sebulba.py +++ b/mava/utils/sebulba.py @@ -27,6 +27,7 @@ # todo: remove the ppo dependencies when we make sebulba for other systems from mava.systems.ppo.types import Params, PPOTransition +from mava.types import Metrics QUEUE_PUT_TIMEOUT = 100 @@ -90,7 +91,9 @@ def run(self) -> None: except queue.Empty: continue - def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict) -> None: + def put( + self, traj: Sequence[PPOTransition], timestep: TimeStep, metrics: Tuple[Dict, List[Dict]] + ) -> None: """Put a trajectory on the queue to be consumed by the learner.""" start_condition, end_condition = (threading.Condition(), threading.Condition()) with start_condition: @@ -101,6 +104,10 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict traj = _stack_trajectory(traj) traj, timestep = jax.device_put((traj, timestep), device=self.sharding) + time_dict, episode_metrics = metrics + # [{'metric1' : value1, ...} * rollout_len -> {'metric1' : [value1, value2, ...], ...} + episode_metrics = _stack_trajectory(episode_metrics) + # We block on the `put` to ensure that actors wait for the learners to catch up. # This ensures two things: # The actors don't get too far ahead of the learners, which could lead to off-policy data. @@ -110,7 +117,7 @@ def put(self, traj: Sequence[PPOTransition], timestep: TimeStep, time_dict: Dict # We use a try-finally so the lock is released even if an exception is raised. try: self._queue.put( - (traj, timestep, time_dict), + (traj, timestep, time_dict, episode_metrics), block=True, timeout=QUEUE_PUT_TIMEOUT, ) @@ -129,7 +136,7 @@ def qsize(self) -> int: def get( self, block: bool = True, timeout: Union[float, None] = None - ) -> Tuple[PPOTransition, TimeStep, Dict]: + ) -> Tuple[PPOTransition, TimeStep, Dict, Metrics]: """Get a trajectory from the pipeline.""" return self._queue.get(block, timeout) # type: ignore
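
With the `Pipeline` changes above, actors put the timing dict and the stacked episode metrics onto the queue together, and the learner thread unpacks a 4-tuple from `get`. A toy sketch of that producer/consumer contract using only the standard-library queue, with placeholder values standing in for the real pytrees:

    import queue
    from typing import Any, Dict, Tuple

    pipeline: "queue.Queue[Tuple[Any, Any, Dict, Dict]]" = queue.Queue(maxsize=4)

    # Actor side: episode metrics travel alongside the timing dict.
    traj_batch, timestep = object(), object()          # placeholders for real pytrees
    time_dict = {"rollout_time": 0.1}                  # placeholder timing info
    episode_metrics = {"episode_return": [1.0, 2.0]}   # placeholder stacked metrics
    pipeline.put((traj_batch, timestep, time_dict, episode_metrics), block=True, timeout=100)

    # Learner side: mirrors `pipeline.get(block=True)` in learner_thread above.
    traj_batch, timestep, rollout_time, ep_metrics = pipeline.get(block=True)
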