Merge pull request #1124 from instadeepai/chore/ppo-system-cleanup

- chore: remove jax.tree_map
- Merge branch develop into chore/ppo-system-cleanup
- fix: update ret_output shape (#1147)
- chore: minor comment changes
- fix: integration test workflow (#1145)
- chore: renaming loss_X -> X_loss
- fix: updated sebulba
- chore: remove expired link to anakin notebook
- chore: add back advanced usage in examples
- feat: sable cleanup
- chore: merge dev
- chore: unify value/critic loss naming
- refactor: loss_actor -> actor_loss
- chore: pre-commit
- feat: mat system clean up
- chore: remove advanced usage
- feat: ff mappo system clean up
- feat: ff ippo system clean up
- feat: rec mappo system clean up
- feat: rec ippo system clean up

Co-authored-by: Louay-Ben-nessir <[email protected]>
Co-authored-by: Arnol Fokam <[email protected]>
Co-authored-by: Ruan de Kock <[email protected]>
1 parent cd928d5 · commit 1f32a15
Showing 19 changed files with 1,129 additions and 563 deletions.
# Advanced Mava usage 👽

## Data recording from a PPO system 🔴

We include here an example of an advanced use case with Mava: recording experience data from a PPO system, which can then be used for offline MARL, e.g. using the [OG-MARL](https://github.com/instadeepai/og-marl) framework. This functionality is demonstrated in [ff_ippo_store_experience.py](./ff_ippo_store_experience.py), and uses [Flashbax](https://github.com/instadeepai/flashbax)'s `Vault` feature. Vault enables efficient storage of experience data recorded in JAX-based systems, and integrates tightly with Mava and the rest of InstaDeep's MARL ecosystem.

First, a vault must be created using the structure of an experience buffer. Here, we create a dummy structure of the data we want to store:

```py
# Transition structure
dummy_flashbax_transition = {
    "done": jnp.zeros((config.system.num_agents,), dtype=bool),
    "action": jnp.zeros((config.system.num_agents,), dtype=jnp.int32),
    "reward": jnp.zeros((config.system.num_agents,), dtype=jnp.float32),
    "observation": jnp.zeros(
        (
            config.system.num_agents,
            env.observation_spec().agents_view.shape[1],
        ),
        dtype=jnp.float32,
    ),
    "legal_action_mask": jnp.zeros(
        (
            config.system.num_agents,
            config.system.num_actions,
        ),
        dtype=bool,
    ),
}

# Flashbax buffer
buffer = fbx.make_flat_buffer(
    max_length=int(5e5),
    min_length=int(1),
    sample_batch_size=1,
    add_sequences=True,
    add_batch_size=(
        n_devices
        * config["system"]["num_updates_per_eval"]
        * config["system"]["update_batch_size"]
        * config["arch"]["num_envs"]
    ),
)

# Buffer state
buffer_state = buffer.init(
    dummy_flashbax_transition,
)
```
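To make the `add_batch_size` arithmetic concrete, here is a small sketch with made-up values (none of these numbers come from a real Mava config; they are purely illustrative):

```py
# Hypothetical values, for illustration only.
n_devices = 2
num_updates_per_eval = 10
update_batch_size = 4
num_envs = 16

# add_batch_size is the product of the leading batch dimensions,
# i.e. how many trajectories are added to the buffer in one call.
add_batch_size = n_devices * num_updates_per_eval * update_batch_size * num_envs
print(add_batch_size)  # 1280
```

This product is the same `D * NU * UB * NE` quantity that the reshaping step later merges into a single leading dimension.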
We can now create a `Vault` for our data:

```py
v = Vault(
    vault_name="our_system_name",
    experience_structure=buffer_state.experience,
    vault_uid="unique_vault_id",
)
```

We modify our `learn` function to additionally record our agents' trajectories, such that we can access experience data:

```py
learner_output, experience_to_store = learn(learner_state)
```

Because of the Anakin architecture set-up, our trajectories are stored in the incorrect dimensions for our use case. Hence, we transform the data, and then store it in a Flashbax buffer:

```py
# Shape legend:
# D: Number of devices
# NU: Number of updates per evaluation
# UB: Update batch size
# T: Time steps per rollout
# NE: Number of environments

@jax.jit
def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]:
    """Reshape experience to match buffer."""

    # Swap the T and NE axes: (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
    experience: Dict[str, chex.Array] = jax.tree.map(lambda x: x.swapaxes(3, 4), experience)
    # Merge the 4 leading dimensions into 1: (D, NU, UB, NE, T, ...) -> (D * NU * UB * NE, T, ...)
    experience: Dict[str, chex.Array] = jax.tree.map(
        lambda x: x.reshape(-1, *x.shape[4:]), experience
    )
    return experience

flashbax_transition = _reshape_experience(
    {
        # (D, NU, UB, T, NE, ...)
        "done": experience_to_store.done,
        "action": experience_to_store.action,
        "reward": experience_to_store.reward,
        "observation": experience_to_store.obs.agents_view,
        "legal_action_mask": experience_to_store.obs.action_mask,
    }
)
# Add to fbx buffer
buffer_state = buffer.add(buffer_state, flashbax_transition)
```
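To see the shape transformation in isolation, here is a self-contained sketch of the same swap-and-merge on a single dummy array; numpy stands in for `jax.numpy`, and all dimension sizes are made up for illustration:

```py
import numpy as np

# Illustrative sizes: D devices, NU updates, UB update batch, T steps, NE envs,
# plus a trailing feature dimension of 7.
D, NU, UB, T, NE = 2, 3, 4, 5, 6
x = np.zeros((D, NU, UB, T, NE, 7))

# Swap the T and NE axes: (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
x = x.swapaxes(3, 4)
# Merge the 4 leading dimensions: -> (D * NU * UB * NE, T, ...)
x = x.reshape(-1, *x.shape[4:])

print(x.shape)  # (144, 5, 7), since 2 * 3 * 4 * 6 = 144
```

The result is a batch of `D * NU * UB * NE` independent trajectories, each of length `T`, which matches the `add_batch_size` the buffer was created with.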
Then, periodically, we can write this buffer state into the vault, which is stored on disk:

```py
v.write(buffer_state)
```
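One way such periodic writes might be scheduled is on a fixed cadence within the training loop. The sketch below is purely illustrative and not Mava's actual API: `training_loop`, `num_evals`, and `write_every` are hypothetical names, and the vault write is replaced by a recording stub so the snippet is self-contained:

```py
def training_loop(num_evals: int, write_every: int) -> list:
    """Illustrative cadence: flush the buffer state to disk every few evaluations."""
    write_points = []
    for eval_idx in range(1, num_evals + 1):
        # ... in a real system: learn(), _reshape_experience(), buffer.add() ...
        if eval_idx % write_every == 0:
            # Stand-in for `v.write(buffer_state)`.
            write_points.append(eval_idx)
    return write_points

print(training_loop(10, 3))  # [3, 6, 9]
```

Writing only every few evaluations amortises the disk I/O, since `Vault` appends the new experience accumulated since the previous write.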
If we now want to use the recorded data, we can easily restore the vault in another context:

```py
v = Vault(
    vault_name="our_system_name",
    vault_uid="unique_vault_id",
)
buffer_state = v.read()
```

For a demonstration of offline MARL training, see some examples [here](https://github.com/instadeepai/og-marl/tree/feat/vault).

---

⚠️ Note: this functionality is highly experimental! The current API is likely to change.