Feat: unified gae (#1129)
* unified gae

* fixes

* pre-commit

* requested changes

* Replace A with N (number of agents)

Co-authored-by: Sasha Abramowitz <[email protected]>

* same

Co-authored-by: Sasha Abramowitz <[email protected]>

* cleaning

* undo mistake

* more cleaning

* merge conflicts

* uncapitalize advantage comments

---------

Co-authored-by: Sasha Abramowitz <[email protected]>
Co-authored-by: Omayma Mahjoub <[email protected]>
3 people authored Jan 7, 2025
1 parent 4076c7c commit c92fe18
Showing 7 changed files with 133 additions and 132 deletions.
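For reference, the per-system helpers removed in this commit (and, presumably, the new shared calculate_gae that replaces them) compute the standard Generalised Advantage Estimation recursion, reconstructed here from the removed code, where d_{t+1} is the done flag used for masking:

    \delta_t = r_t + \gamma \, (1 - d_{t+1}) \, V(s_{t+1}) - V(s_t)
    \hat{A}_t = \delta_t + \gamma \lambda \, (1 - d_{t+1}) \, \hat{A}_{t+1}
    V^{\mathrm{target}}_t = \hat{A}_t + V(s_t)

The recursion is evaluated backwards over the rollout with jax.lax.scan(..., reverse=True), starting from \hat{A}_T = 0, with V(s_T) = last_val and d_T taken from the new last_done flags.
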
52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_ippo.py
@@ -38,6 +38,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -78,7 +79,7 @@ def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# Select action
key, policy_key = jax.random.split(key)
@@ -93,9 +94,9 @@

done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)
transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation
last_done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, (transition, timestep.extras["episode_metrics"])

# Step environment for rollout length
@@ -104,37 +105,12 @@
)

# Calculate advantage
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -283,7 +259,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
@@ -400,9 +376,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -412,8 +392,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

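The shared helper now lives in mava/utils/multistep.py, one of the changed files not expanded on this page. A minimal sketch of what calculate_gae plausibly looks like, assuming it mirrors the recurrent helper removed further down (the unroll=16 value and anything beyond the five arguments used at the call sites are assumptions):

from typing import Tuple

import chex
import jax
import jax.numpy as jnp


def calculate_gae(
    traj_batch,             # (RNN)PPOTransition pytree with .done, .value, .reward and a leading time axis
    last_val: chex.Array,   # critic value of the observation after the rollout
    last_done: chex.Array,  # done flags of the state after the rollout
    gamma: float,
    gae_lambda: float,
) -> Tuple[chex.Array, chex.Array]:
    """Scan backwards over the rollout to compute GAE advantages and value targets."""

    def _get_advantages(carry: Tuple, transition) -> Tuple:
        gae, next_value, next_done = carry
        done, value, reward = transition.done, transition.value, transition.reward
        delta = reward + gamma * next_value * (1 - next_done) - value
        gae = delta + gamma * gae_lambda * (1 - next_done) * gae
        return (gae, value, done), gae

    _, advantages = jax.lax.scan(
        _get_advantages,
        (jnp.zeros_like(last_val), last_val, last_done),
        traj_batch,
        reverse=True,
        unroll=16,
    )
    return advantages, advantages + traj_batch.value

Passing last_done explicitly is what lets the feedforward systems share this helper: their transitions now store the done flag of the state each action was taken from, and the final step bootstraps off the flags carried in the extended LearnerState.
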
52 changes: 16 additions & 36 deletions mava/systems/ppo/anakin/ff_mappo.py
@@ -37,6 +37,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import merge_leading_dims, unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -77,7 +78,7 @@ def _env_step(
learner_state: LearnerState, _: Any
) -> Tuple[LearnerState, Tuple[PPOTransition, Metrics]]:
"""Step the environment."""
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state

# Select action
key, policy_key = jax.random.split(key)
@@ -92,9 +93,9 @@
done = timestep.last().repeat(env.num_agents).reshape(config.arch.num_envs, -1)

transition = PPOTransition(
done, action, value, timestep.reward, log_prob, last_timestep.observation
last_done, action, value, timestep.reward, log_prob, last_timestep.observation
)
learner_state = LearnerState(params, opt_states, key, env_state, timestep)
learner_state = LearnerState(params, opt_states, key, env_state, timestep, done)
return learner_state, (transition, timestep.extras["episode_metrics"])

# Step environment for rollout length
@@ -103,37 +104,12 @@
)

# Calculate advantage
params, opt_states, key, env_state, last_timestep = learner_state
params, opt_states, key, env_state, last_timestep, last_done = learner_state
last_val = critic_apply_fn(params.critic_params, last_timestep.observation)

def _calculate_gae(
traj_batch: PPOTransition, last_val: chex.Array
) -> Tuple[chex.Array, chex.Array]:
"""Calculate the GAE."""

def _get_advantages(gae_and_next_value: Tuple, transition: PPOTransition) -> Tuple:
"""Calculate the GAE for a single transition."""
gae, next_value = gae_and_next_value
done, value, reward = (
transition.done,
transition.value,
transition.reward,
)
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - done) * gae
return (gae, value), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
@@ -285,7 +261,7 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep)
learner_state = LearnerState(params, opt_states, key, env_state, last_timestep, last_done)
return learner_state, (episode_metrics, loss_info)

def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]:
@@ -402,9 +378,13 @@ def learner_setup(
params = restored_params

# Define params to be replicated across devices and batches.
dones = jnp.zeros(
(config.arch.num_envs, config.system.num_agents),
dtype=bool,
)
key, step_keys = jax.random.split(key)
opt_states = OptStates(actor_opt_state, critic_opt_state)
replicate_learner = (params, opt_states, step_keys)
replicate_learner = (params, opt_states, step_keys, dones)

# Duplicate learner for update_batch_size.
broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size, *x.shape))
@@ -414,8 +394,8 @@
replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())

# Initialise learner state.
params, opt_states, step_keys = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
params, opt_states, step_keys, dones = replicate_learner
init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps, dones)

return learn, actor_network, init_learner_state

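The ff_mappo.py change is identical to the ff_ippo.py one above. As a quick shape check against the dones initialised as (config.arch.num_envs, config.system.num_agents), a hypothetical call to the calculate_gae sketched earlier could look like this (toy sizes, stand-in transition type):

from collections import namedtuple

import jax.numpy as jnp

# Hypothetical stand-in exposing only the fields calculate_gae reads.
Transition = namedtuple("Transition", ["done", "value", "reward"])

T, num_envs, num_agents = 8, 4, 3  # rollout length and toy batch sizes
traj_batch = Transition(
    done=jnp.zeros((T, num_envs, num_agents), dtype=bool),
    value=jnp.zeros((T, num_envs, num_agents)),
    reward=jnp.ones((T, num_envs, num_agents)),
)
last_val = jnp.zeros((num_envs, num_agents))
last_done = jnp.zeros((num_envs, num_agents), dtype=bool)

advantages, targets = calculate_gae(traj_batch, last_val, last_done, 0.99, 0.95)
# advantages.shape == targets.shape == (T, num_envs, num_agents)
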
27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_ippo.py
@@ -52,6 +52,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -158,29 +159,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
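The recurrent systems remove the same inline helper. One subtlety the unification smooths over: the old feedforward helpers masked with the done flag stored on the transition itself, while the recurrent helpers masked with the flag carried back from the following step. As the two deltas appeared in the removed code:

# removed feedforward helper (ff_ippo.py / ff_mappo.py)
delta = reward + gamma * next_value * (1 - done) - value
# removed recurrent helper (rec_ippo.py / rec_mappo.py), which the unified version presumably follows
delta = reward + gamma * next_value * (1 - next_done) - value

Because the feedforward transitions now store last_done (the done flag of the state each action was taken from) rather than the resulting done, both forms end up masking with the same flag, so the switch should leave the computed advantages unchanged.
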
27 changes: 4 additions & 23 deletions mava/systems/ppo/anakin/rec_mappo.py
@@ -52,6 +52,7 @@
from mava.utils.config import check_total_timesteps
from mava.utils.jax_utils import unreplicate_batch_dim, unreplicate_n_dims
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.multistep import calculate_gae
from mava.utils.network_utils import get_action_head
from mava.utils.training import make_learning_rate
from mava.wrappers.episode_metrics import get_final_step_metrics
@@ -160,29 +161,9 @@ def _env_step(
# Squeeze out the batch dimension and mask out the value of terminal states.
last_val = last_val.squeeze(0)

def _calculate_gae(
traj_batch: RNNPPOTransition, last_val: chex.Array, last_done: chex.Array
) -> Tuple[chex.Array, chex.Array]:
def _get_advantages(
carry: Tuple[chex.Array, chex.Array, chex.Array], transition: RNNPPOTransition
) -> Tuple[Tuple[chex.Array, chex.Array, chex.Array], chex.Array]:
gae, next_value, next_done = carry
done, value, reward = transition.done, transition.value, transition.reward
gamma = config.system.gamma
delta = reward + gamma * next_value * (1 - next_done) - value
gae = delta + gamma * config.system.gae_lambda * (1 - next_done) * gae
return (gae, value, done), gae

_, advantages = jax.lax.scan(
_get_advantages,
(jnp.zeros_like(last_val), last_val, last_done),
traj_batch,
reverse=True,
unroll=16,
)
return advantages, advantages + traj_batch.value

advantages, targets = _calculate_gae(traj_batch, last_val, last_done)
advantages, targets = calculate_gae(
traj_batch, last_val, last_done, config.system.gamma, config.system.gae_lambda
)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
"""Update the network for a single epoch."""
28 changes: 14 additions & 14 deletions mava/systems/ppo/sebulba/ff_ippo.py
@@ -41,7 +41,7 @@
from mava.evaluator import make_ff_eval_act_fn
from mava.networks import FeedForwardActor as Actor
from mava.networks import FeedForwardValueNet as Critic
from mava.systems.ppo.types import LearnerState, OptStates, Params, PPOTransition
from mava.systems.ppo.types import OptStates, Params, PPOTransition, SebulbaLearnerState
from mava.types import (
ActorApply,
CriticApply,
@@ -163,7 +163,7 @@ def get_learner_step_fn(
apply_fns: Tuple[ActorApply, CriticApply],
update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
config: DictConfig,
) -> SebulbaLearnerFn[LearnerState, PPOTransition]:
) -> SebulbaLearnerFn[SebulbaLearnerState, PPOTransition]:
"""Get the learner function."""

num_envs = config.arch.num_envs
@@ -174,16 +174,16 @@
actor_update_fn, critic_update_fn = update_fns

def _update_step(
learner_state: LearnerState,
learner_state: SebulbaLearnerState,
traj_batch: PPOTransition,
) -> Tuple[LearnerState, Metrics]:
) -> Tuple[SebulbaLearnerState, Metrics]:
"""A single update of the network.
This function calculates advantages and targets based on the trajectories
from the actor and updates the actor and critic networks based on the losses.
Args:
learner_state (LearnerState): contains all the items needed for learning.
learner_state (SebulbaLearnerState): contains all the items needed for learning.
traj_batch (PPOTransition): the batch of data to learn with.
"""

@@ -359,12 +359,12 @@ def _critic_loss_fn(
)

params, opt_states, traj_batch, advantages, targets, key = update_state
learner_state = LearnerState(params, opt_states, key, None, learner_state.timestep)
learner_state = SebulbaLearnerState(params, opt_states, key, None, learner_state.timestep)
return learner_state, loss_info

def learner_fn(
learner_state: LearnerState, traj_batch: PPOTransition
) -> Tuple[LearnerState, Metrics]:
learner_state: SebulbaLearnerState, traj_batch: PPOTransition
) -> Tuple[SebulbaLearnerState, Metrics]:
"""Learner function.
This function represents the learner, it updates the network parameters
Expand All @@ -390,8 +390,8 @@ def learner_fn(


def learner_thread(
learn_fn: SebulbaLearnerFn[LearnerState, PPOTransition],
learner_state: LearnerState,
learn_fn: SebulbaLearnerFn[SebulbaLearnerState, PPOTransition],
learner_state: SebulbaLearnerState,
config: DictConfig,
eval_queue: Queue,
pipeline: Pipeline,
@@ -438,9 +438,9 @@ def learner_thread(
def learner_setup(
key: chex.PRNGKey, config: DictConfig, learner_devices: List
) -> Tuple[
SebulbaLearnerFn[LearnerState, PPOTransition],
SebulbaLearnerFn[SebulbaLearnerState, PPOTransition],
Tuple[ActorApply, CriticApply],
LearnerState,
SebulbaLearnerState,
Sharding,
]:
"""Initialise learner_fn, network and learner state."""
@@ -504,7 +504,7 @@ def learner_setup(
update_fns = (actor_optim.update, critic_optim.update)

# defines how the learner state is sharded: params, opt and key = sharded, timestep = sharded
learn_state_spec = LearnerState(model_spec, model_spec, data_spec, None, data_spec)
learn_state_spec = SebulbaLearnerState(model_spec, model_spec, data_spec, None, data_spec)
learn = get_learner_step_fn(apply_fns, update_fns, config)
learn = jax.jit(
shard_map(
@@ -537,7 +537,7 @@
)

# Initialise learner state.
init_learner_state = LearnerState(params, opt_states, step_keys, None, None) # type: ignore
init_learner_state = SebulbaLearnerState(params, opt_states, step_keys, None, None) # type: ignore
env.close()

return learn, apply_fns, init_learner_state, learner_sharding # type: ignore
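In the Sebulba system the change is a rename rather than a new field: the Anakin LearnerState now carries the agents' done flags, while the Sebulba learner keeps the old five-slot layout under the new name SebulbaLearnerState. Its definition lives in mava/systems/ppo/types.py, which is not expanded on this page; a hypothetical sketch consistent with the positional constructor calls above (field names are assumptions):

from typing import Any, NamedTuple, Optional

import chex


class SebulbaLearnerState(NamedTuple):
    """Assumed learner-side state for the Sebulba (actor/learner split) PPO system."""

    params: Any               # Params
    opt_states: Any           # OptStates
    key: chex.PRNGKey
    env_state: Optional[Any]  # always None at the call sites in this diff
    timestep: Optional[Any]   # final timestep handed over from the actor threads
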
(The remaining two changed files are not expanded on this page.)
