Merge pull request #1124 from instadeepai/chore/ppo-system-cleanup
- chore: remove jax.tree_map
- Merge branch develop into chore/ppo-system-cleanup
- fix: update ret_output shape (#1147)
- Merge branch develop into chore/ppo-system-cleanup
- chore: minor comment changes
- fix: integration test workflow (#1145)
- Merge branch develop into chore/ppo-system-cleanup
- Merge branch develop into chore/ppo-system-cleanup
- chore: renaming loss_X -> X_loss
- fix: updated sebulba
- Merge branch develop into chore/ppo-system-cleanup
- chore: remove expired link to anakin notebook
- chore: add back advanced usage in examples
- Merge branch develop into chore/ppo-system-cleanup
- feat: sable cleanup
- chore: merge dev
- chore: unify value/critic loss naming
- refactor: loss_actor -> actor_loss
- Merge branch develop into chore/ppo-system-cleanup
- chore: pre-commit
- feat: mat system clean up
- chore: remove advanced usage
- Merge branch develop into chore/ppo-system-cleanup
- feat: ff mappo system clean up
- feat: ff ippo system clean up
- feat: rec mappo system clean up
- feat: rec ippo system clean up

Co-authored-by: Louay-Ben-nessir <[email protected]>
Co-authored-by: Arnol Fokam <[email protected]>
Co-authored-by: Ruan de Kock <[email protected]>
4 people committed Dec 3, 2024
1 parent cd928d5 commit 1f32a15
Showing 19 changed files with 1,129 additions and 563 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/integration_tests.yaml
@@ -11,6 +11,10 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 20

strategy:
matrix:
python-version: ["3.12", "3.11"]

steps:
- name: Checkout mava
uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion README.md
@@ -186,7 +186,7 @@ Additionally, we also have a [Quickstart notebook][quickstart] that can be used

## Advanced Usage 👽

Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./mava/advanced_usage/) for more information.
Mava can be used in a wide array of advanced systems. As an example, we demonstrate recording experience data from one of our PPO systems into a [Flashbax](https://github.com/instadeepai/flashbax) `Vault`. This vault can then easily be integrated into offline MARL systems, such as those found in [OG-MARL](https://github.com/instadeepai/og-marl). See the [Advanced README](./examples/advanced_usage/README.md) for more information.

## Contributing 🤝

2 changes: 0 additions & 2 deletions examples/Quickstart.ipynb
@@ -413,8 +413,6 @@
" )\n",
"\n",
" # Compute the parallel mean (pmean) over the batch.\n",
" # This calculation is inspired by the Anakin architecture demo notebook.\n",
" # available at https://tinyurl.com/26tdzs5x\n",
" # This pmean could be a regular mean as the batch axis is on the same device.\n",
" actor_grads, actor_loss_info = jax.lax.pmean(\n",
" (actor_grads, actor_loss_info), axis_name=\"batch\"\n",
114 changes: 114 additions & 0 deletions examples/advanced_usage/README.md
@@ -0,0 +1,114 @@
# Advanced Mava usage 👽
## Data recording from a PPO system 🔴
We include here an example of an advanced use case with Mava: recording experience data from a PPO system, which can then be used for offline MARL—e.g. using the [OG-MARL](https://github.com/instadeepai/og-marl) framework. This functionality is demonstrated in [ff_ippo_store_experience.py](./ff_ippo_store_experience.py), and uses [Flashbax](https://github.com/instadeepai/flashbax)'s `Vault` feature. Vault enables efficient storage of experience data recorded in JAX-based systems, and integrates tightly with Mava and the rest of InstaDeep's MARL ecosystem.

Firstly, a vault must be created using the structure of an experience buffer. Here, we create a dummy structure of the data we want to store:
```py
# Transition structure
dummy_flashbax_transition = {
"done": jnp.zeros((config.system.num_agents,), dtype=bool),
"action": jnp.zeros((config.system.num_agents,), dtype=jnp.int32),
"reward": jnp.zeros((config.system.num_agents,), dtype=jnp.float32),
"observation": jnp.zeros(
(
config.system.num_agents,
env.observation_spec().agents_view.shape[1],
),
dtype=jnp.float32,
),
"legal_action_mask": jnp.zeros(
(
config.system.num_agents,
config.system.num_actions,
),
dtype=bool,
),
}

# Flashbax buffer
buffer = fbx.make_flat_buffer(
max_length=int(5e5),
min_length=int(1),
sample_batch_size=1,
add_sequences=True,
add_batch_size=(
n_devices
* config["system"]["num_updates_per_eval"]
* config["system"]["update_batch_size"]
* config["arch"]["num_envs"]
),
)

# Buffer state
buffer_state = buffer.init(
dummy_flashbax_transition,
)
```

We can now create a `Vault` for our data:
```py
v = Vault(
vault_name="our_system_name",
experience_structure=buffer_state.experience,
vault_uid="unique_vault_id",
)
```

We modify our `learn` function to additionally record our agents' trajectories, such that we can access experience data:
```py
learner_output, experience_to_store = learn(learner_state)
```

Because of the Anakin architecture set-up, our trajectories are stored with their axes ordered incorrectly for our use case. Hence, we transform the data and then store it in a Flashbax buffer:
```py
# Shape legend:
# D: Number of devices
# NU: Number of updates per evaluation
# UB: Update batch size
# T: Time steps per rollout
# NE: Number of environments

@jax.jit
def _reshape_experience(experience: Dict[str, chex.Array]) -> Dict[str, chex.Array]:
"""Reshape experience to match buffer."""

# Swap the T and NE axes (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
experience: Dict[str, chex.Array] = jax.tree.map(lambda x: x.swapaxes(3, 4), experience)
# Merge 4 leading dimensions into 1. (D, NU, UB, NE, T, ...) -> (D * NU * UB * NE, T, ...)
experience: Dict[str, chex.Array] = jax.tree.map(
lambda x: x.reshape(-1, *x.shape[4:]), experience
)
return experience

flashbax_transition = _reshape_experience(
{
# (D, NU, UB, T, NE, ...)
"done": experience_to_store.done,
"action": experience_to_store.action,
"reward": experience_to_store.reward,
"observation": experience_to_store.obs.agents_view,
"legal_action_mask": experience_to_store.obs.action_mask,
}
)
# Add to fbx buffer
buffer_state = buffer.add(buffer_state, flashbax_transition)
```
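As a standalone sanity check of the axis manipulation above, the same swap-and-merge can be traced with plain NumPy arrays (the small dimension sizes below are purely illustrative):

```python
import numpy as np

# Illustrative sizes: D=2 devices, NU=3 updates per evaluation, UB=4 update
# batch size, T=5 timesteps, NE=6 environments, with a trailing feature axis of 7.
x = np.zeros((2, 3, 4, 5, 6, 7))

# Swap the T and NE axes: (D, NU, UB, T, NE, ...) -> (D, NU, UB, NE, T, ...)
x = x.swapaxes(3, 4)

# Merge the 4 leading axes: (D, NU, UB, NE, T, ...) -> (D * NU * UB * NE, T, ...)
x = x.reshape(-1, *x.shape[4:])

print(x.shape)  # (144, 5, 7), since 2 * 3 * 4 * 6 = 144
```

The merged leading axis is exactly what `add_batch_size` accounts for when the buffer is created above.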

Then, periodically, we can write this buffer state into the vault, which is stored on disk:
```py
v.write(buffer_state)
```
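In a full run, the write would typically happen on a fixed cadence inside the outer evaluation loop. Below is a minimal, runnable sketch of that control flow, with stub functions standing in for `learn`, `buffer.add`, and `v.write` (the interval and all stub names are illustrative, not part of the Mava API):

```python
# Stubs standing in for the real jitted `learn`, `buffer.add`, and `v.write`.
write_interval = 2  # illustrative: flush the buffer to disk every 2 evaluations
writes_to_disk = []

def learn_stub(state):
    # Returns (new_state, experience), mirroring the modified `learn` above.
    return state + 1, {"step": state}

def vault_write_stub(buffer_state):
    # Stands in for `v.write(buffer_state)`.
    writes_to_disk.append(len(buffer_state))

learner_state, buffer_state = 0, []
for eval_step in range(1, 5):
    learner_state, experience = learn_stub(learner_state)
    buffer_state = buffer_state + [experience]  # stands in for `buffer.add(...)`
    if eval_step % write_interval == 0:
        vault_write_stub(buffer_state)

print(writes_to_disk)  # [2, 4]: two writes, holding 2 then 4 stored batches
```

Writing on an interval, rather than every update, amortises the cost of flushing to disk across many training steps.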

If we now want to use the recorded data, we can easily restore the vault in another context:
```py
v = Vault(
vault_name="our_system_name",
vault_uid="unique_vault_id",
)
buffer_state = v.read()
```

For a demonstration of offline MARL training, see some examples [here](https://github.com/instadeepai/og-marl/tree/feat/vault).
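For instance, once `buffer_state = v.read()` has been restored, an offline learner typically flattens the stored `(batch, time, ...)` arrays into individual transitions and samples minibatches from them. A NumPy sketch with stand-in arrays (the shapes and field subset below mirror the dummy transition above but are illustrative):

```python
import numpy as np

# Stand-ins for restored experience arrays of shape (add_batch, time, ...).
experience = {
    "action": np.zeros((8, 16), dtype=np.int32),
    "reward": np.zeros((8, 16), dtype=np.float32),
    "observation": np.zeros((8, 16, 4), dtype=np.float32),
}

# Flatten (batch, time) into a single transition axis.
flat = {k: a.reshape(-1, *a.shape[2:]) for k, a in experience.items()}

# Draw a random minibatch of 32 transitions.
rng = np.random.default_rng(0)
idx = rng.integers(0, flat["action"].shape[0], size=32)
minibatch = {k: a[idx] for k, a in flat.items()}

print(minibatch["observation"].shape)  # (32, 4)
```

Note that flattening across time discards trajectory structure; sequence-based offline methods would instead sample contiguous slices along the time axis.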

---
⚠️ Note: this functionality is highly experimental! The current API is likely to change.
