diff --git a/mava/systems/ppo/anakin/ff_ippo.py b/mava/systems/ppo/anakin/ff_ippo.py index 07ef5d39d..e55c9076c 100644 --- a/mava/systems/ppo/anakin/ff_ippo.py +++ b/mava/systems/ppo/anakin/ff_ippo.py @@ -38,6 +38,7 @@ from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -78,7 +79,7 @@ def _env_step( learner_state: LearnerState, _: Any ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state # Select action key, policy_key = jax.random.split(key) @@ -93,9 +94,9 @@ def _env_step( done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation + last_done, action, value, timestep.reward, log_prob, last_timestep.observation ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) return learner_state, (transition, timestep.extras["episode_metrics"]) # Step environment for rollout length @@ -104,37 +105,12 @@ def _env_step( ) # Calculate advantage - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae - return (gae, value), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -283,7 +259,7 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done) return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: @@ -400,9 +376,13 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. 
+ dones = jnp.zeros( + (config.arch.num_envs, config.system.num_agents), + dtype=bool, + ) key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) + replicate_learner = (params, opt_states, step_keys, dones) # Duplicate learner for update_batch_size. broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape)) @@ -412,8 +392,8 @@ def learner_setup( replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) # Initialise learner state. - params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + params, opt_states, step_keys, dones = replicate_learner + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/anakin/ff_mappo.py b/mava/systems/ppo/anakin/ff_mappo.py index f511f9de7..552789cc7 100644 --- a/mava/systems/ppo/anakin/ff_mappo.py +++ b/mava/systems/ppo/anakin/ff_mappo.py @@ -37,6 +37,7 @@ from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -77,7 +78,7 @@ def _env_step( learner_state: LearnerState, _: Any ) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]: """Step the environment.""" - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state # Select action key, policy_key = jax.random.split(key) @@ -92,9 +93,9 @@ def _env_step( done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1) transition = PPOTransition( - done, action, value, timestep.reward, log_prob, last_timestep.observation + last_done, action, value, timestep.reward, log_prob, last_timestep.observation ) - learner_state = LearnerState(params, opt_states, key, env_state, timestep) + learner_state = LearnerState(params, opt_states, key, env_state, timestep, done) return learner_state, (transition, timestep.extras["episode_metrics"]) # Step environment for rollout length @@ -103,37 +104,12 @@ def _env_step( ) # Calculate advantage - params, opt_states, key, env_state, last_timestep = learner_state + params, opt_states, key, env_state, last_timestep, last_done = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) - def _calculate_gae( - traj_batch: PPOTransition, last_val: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - """Calculate the GAE.""" - - def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple: - """Calculate the GAE for a single transition.""" - gae, next_value = gae_and_next_value - done, value, reward = ( - transition.done, - transition.value, - transition.reward, - ) - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae - return (gae, value), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + 
traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" @@ -285,7 +261,7 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done) return learner_state, (episode_metrics, loss_info) def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: @@ -402,9 +378,13 @@ def learner_setup( params = restored_params # Define params to be replicated across devices and batches. + dones = jnp.zeros( + (config.arch.num_envs, config.system.num_agents), + dtype=bool, + ) key, step_keys = jax.random.split(key) opt_states = OptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states, step_keys) + replicate_learner = (params, opt_states, step_keys, dones) # Duplicate learner for update_batch_size. broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape)) @@ -414,8 +394,8 @@ def learner_setup( replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) # Initialise learner state. - params, opt_states, step_keys = replicate_learner - init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps) + params, opt_states, step_keys, dones = replicate_learner + init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones) return learn, actor_network, init_learner_state diff --git a/mava/systems/ppo/anakin/rec_ippo.py b/mava/systems/ppo/anakin/rec_ippo.py index 90310fa9e..3bbefeddf 100644 --- a/mava/systems/ppo/anakin/rec_ippo.py +++ b/mava/systems/ppo/anakin/rec_ippo.py @@ -52,6 +52,7 @@ from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -158,29 +159,9 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. 
last_val = last_val.squeeze(0) - def _calculate_gae( - traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/anakin/rec_mappo.py b/mava/systems/ppo/anakin/rec_mappo.py index e7466142a..b0e34a886 100644 --- a/mava/systems/ppo/anakin/rec_mappo.py +++ b/mava/systems/ppo/anakin/rec_mappo.py @@ -52,6 +52,7 @@ from mava.utils.config import check_total_timesteps from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims from mava.utils.logger import LogEvent, MavaLogger +from mava.utils.multistep import calculate_gae from mava.utils.network_utils import get_action_head from mava.utils.training import make_learning_rate from mava.wrappers.episode_metrics import get_final_step_metrics @@ -160,29 +161,9 @@ def _env_step( # Squeeze out the batch dimension and mask out the value of terminal states. 
last_val = last_val.squeeze(0) - def _calculate_gae( - traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array - ) -> Tuple[chex.Array, chex.Array]: - def _get_advantages( - carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition - ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: - gae, next_value, next_done = carry - done, value, reward = transition.done, transition.value, transition.reward - gamma = config.system.gamma - delta = reward + gamma * next_value * (1 - next_done) - value - gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae - return (gae, value, done), gae - - _, advantages = jax.lax.scan( - _get_advantages, - (jnp.zeros_like(last_val), last_val, last_done), - traj_batch, - reverse=True, - unroll=16, - ) - return advantages, advantages + traj_batch.value - - advantages, targets = _calculate_gae(traj_batch, last_val, last_done) + advantages, targets = calculate_gae( + traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda + ) def _update_epoch(update_state: Tuple, _: Any) -> Tuple: """Update the network for a single epoch.""" diff --git a/mava/systems/ppo/sebulba/ff_ippo.py b/mava/systems/ppo/sebulba/ff_ippo.py index 996821cd0..c79b654d2 100644 --- a/mava/systems/ppo/sebulba/ff_ippo.py +++ b/mava/systems/ppo/sebulba/ff_ippo.py @@ -41,7 +41,7 @@ from mava.evaluator import make_ff_eval_act_fn from mava.networks import FeedForwardActor as Actor from mava.networks import FeedForwardValueNet as Critic -from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition +from mava.systems.ppo.types import OptStates, Params, PPOTransition, SebulbaLearnerState from mava.types import ( ActorApply, CriticApply, @@ -163,7 +163,7 @@ def get_learner_step_fn( apply_fns: Tuple[ActorApply, CriticApply], update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn], config: DictConfig, -) -> SebulbaLearnerFn[LearnerState, PPOTransition]: +) -> SebulbaLearnerFn[SebulbaLearnerState, PPOTransition]: """Get the learner function.""" num_envs = config.arch.num_envs @@ -174,16 +174,16 @@ def get_learner_step_fn( actor_update_fn, critic_update_fn = update_fns def _update_step( - learner_state: LearnerState, + learner_state: SebulbaLearnerState, traj_batch: PPOTransition, - ) -> Tuple[LearnerState, Metrics]: + ) -> Tuple[SebulbaLearnerState, Metrics]: """A single update of the network. This function calculates advantages and targets based on the trajectories from the actor and updates the actor and critic networks based on the losses. Args: - learner_state (LearnerState): contains all the items needed for learning. + learner_state (SebulbaLearnerState): contains all the items needed for learning. traj_batch (PPOTransition): the batch of data to learn with. """ @@ -359,12 +359,12 @@ def _critic_loss_fn( ) params, opt_states, traj_batch, advantages, targets, key = update_state - learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep) + learner_state = SebulbaLearnerState(params, opt_states, key, None, learner_state.timestep) return learner_state, loss_info def learner_fn( - learner_state: LearnerState, traj_batch: PPOTransition - ) -> Tuple[LearnerState, Metrics]: + learner_state: SebulbaLearnerState, traj_batch: PPOTransition + ) -> Tuple[SebulbaLearnerState, Metrics]: """Learner function. 
This function represents the learner, it updates the network parameters @@ -390,8 +390,8 @@ def learner_fn( def learner_thread( - learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition], - learner_state: LearnerState, + learn_fn: SebulbaLearnerFn[SebulbaLearnerState, PPOTransition], + learner_state: SebulbaLearnerState, config: DictConfig, eval_queue: Queue, pipeline: Pipeline, @@ -438,9 +438,9 @@ def learner_thread( def learner_setup( key: chex.PRNGKey, config: DictConfig, learner_devices: List ) -> Tuple[ - SebulbaLearnerFn[LearnerState, PPOTransition], + SebulbaLearnerFn[SebulbaLearnerState, PPOTransition], Tuple[ActorApply, CriticApply], - LearnerState, + SebulbaLearnerState, Sharding, ]: """Initialise learner_fn, network and learner state.""" @@ -504,7 +504,7 @@ def learner_setup( update_fns = (actor_optim.update, critic_optim.update) # defines how the learner state is sharded: params, opt and key = sharded, timestep = sharded - learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec) + learn_state_spec = SebulbaLearnerState(model_spec, model_spec, data_spec, None, data_spec) learn = get_learner_step_fn(apply_fns, update_fns, config) learn = jax.jit( shard_map( @@ -537,7 +537,7 @@ def learner_setup( ) # Initialise learner state. - init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore + init_learner_state = SebulbaLearnerState(params, opt_states, step_keys, None, None) # type: ignore env.close() return learn, apply_fns, init_learner_state, learner_sharding # type: ignore diff --git a/mava/systems/ppo/types.py b/mava/systems/ppo/types.py index 9e56e17f8..a94f6b7a3 100644 --- a/mava/systems/ppo/types.py +++ b/mava/systems/ppo/types.py @@ -51,6 +51,17 @@ class LearnerState(NamedTuple): key: chex.PRNGKey env_state: State timestep: TimeStep + dones: Done + + +class SebulbaLearnerState(NamedTuple): + """State of the learner.""" + + params: Params + opt_states: OptStates + key: chex.PRNGKey + env_state: State + timestep: TimeStep class RNNLearnerState(NamedTuple): diff --git a/mava/utils/multistep.py b/mava/utils/multistep.py new file mode 100644 index 000000000..955f0ccd0 --- /dev/null +++ b/mava/utils/multistep.py @@ -0,0 +1,68 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +import chex +import jax +import jax.numpy as jnp + +from mava.systems.ppo.types import PPOTransition, RNNPPOTransition + + +def calculate_gae( + traj_batch: Union[PPOTransition, RNNPPOTransition], + last_val: chex.Array, + last_done: chex.Array, + gamma: float, + gae_lambda: float, + unroll: int = 16, +) -> Tuple[chex.Array, chex.Array]: + """Computes truncated generalized advantage estimates. + + The advantages are computed in a backwards fashion according to the equation: + Âₜ = δₜ + (γλ) * δₜ₊₁ + ... + ... + (γλ)ᵏ⁻ᵗ⁺¹ * δₖ₋₁ + where δₜ = rₜ₊₁ + γₜ₊₁ * v(sₜ₊₁) - v(sₜ). 
+ See Proximal Policy Optimization Algorithms, Schulman et al.: + https://arxiv.org/abs/1707.06347 + + Args: + traj_batch (B, T, N, ...): a batch of trajectories. + last_val (B, N): value of the final timestep. + last_done (B, N): whether the last timestep was terminated or truncated. + gamma (float): discount factor. + gae_lambda (float): GAE mixing parameter. + unroll (int): how much XLA should unroll the scan used to calculate GAE. + + Returns Tuple[(B, T, N), (B, T, N)]: advantages and target values. + """ + + def _get_advantages( + carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition + ) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]: + gae, next_value, next_done = carry + done, value, reward = transition.done, transition.value, transition.reward + + delta = reward + gamma * next_value * (1 - next_done) - value + gae = delta + gamma * gae_lambda * (1 - next_done) * gae + return (gae, value, done), gae + + _, advantages = jax.lax.scan( + _get_advantages, + (jnp.zeros_like(last_val), last_val, last_done), + traj_batch, + reverse=True, + unroll=unroll, + ) + return advantages, advantages + traj_batch.value
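
Below is a minimal usage sketch (not part of the diff above) of the new calculate_gae helper. It assumes only what the diff shows: calculate_gae reads the .done, .value and .reward fields of the transition pytree, scans over its leading (rollout/time) axis, and bootstraps from last_val masked by last_done. The Transition NamedTuple and the dummy shapes are hypothetical stand-ins for PPOTransition, used purely to exercise the function.

from typing import NamedTuple

import chex
import jax
import jax.numpy as jnp

from mava.utils.multistep import calculate_gae


class Transition(NamedTuple):
    """Hypothetical stand-in for PPOTransition; calculate_gae only reads these fields."""

    done: chex.Array
    value: chex.Array
    reward: chex.Array


T, E, N = 8, 4, 3  # rollout length, parallel envs, agents (illustrative sizes)
value_key, reward_key = jax.random.split(jax.random.PRNGKey(0))

traj_batch = Transition(
    done=jnp.zeros((T, E, N), dtype=bool),            # no episode ends mid-rollout
    value=jax.random.normal(value_key, (T, E, N)),    # critic estimates v(s_t)
    reward=jax.random.normal(reward_key, (T, E, N)),  # rewards r_t
)
last_val = jnp.zeros((E, N))               # bootstrap value for the state after the rollout
last_done = jnp.zeros((E, N), dtype=bool)  # True where the final step ended an episode

advantages, targets = calculate_gae(
    traj_batch, last_val, last_done, gamma=0.99, gae_lambda=0.95
)
assert advantages.shape == (T, E, N)
# By construction, targets are the value-regression targets: advantages + v(s_t).
assert jnp.allclose(targets, advantages + traj_batch.value)

Note how last_done enters the scan carry: the bootstrap from last_val is zeroed wherever the final step of the rollout ended an episode. This is also why the feed-forward learner states in this PR now thread dones between update steps, so the first transition of the next rollout sees the correct last_done.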