Implement Search & Rescue Multi-Agent Environment #259

Open
wants to merge 29 commits into
base: main
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein Nov 4, 2024
c955320
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 4, 2024
6b34657
Merge branch 'main' into main
sash-a Nov 4, 2024
988339b
fix: PR fixes (#2)
zombie-einstein Nov 5, 2024
a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 5, 2024
b4cce01
style: Run updated pre-commit
zombie-einstein Nov 6, 2024
cb6d88d
refactor: Consolidate predator prey type
zombie-einstein Nov 7, 2024
06de3a0
feat: Implement search and rescue (#3)
zombie-einstein Nov 11, 2024
34beab6
fix: PR fixes (#4)
zombie-einstein Nov 14, 2024
f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 15, 2024
072db18
refactor: PR fixes (#5)
zombie-einstein Nov 19, 2024
162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein Nov 19, 2024
4996869
Merge branch 'main' into main
zombie-einstein Nov 22, 2024
6322f61
fix: Locate targets in single pass (#8)
zombie-einstein Nov 23, 2024
4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 28, 2024
9a654b9
feat: training and customisable observations (#7)
zombie-einstein Dec 7, 2024
5021e20
feat: view all targets (#9)
zombie-einstein Dec 9, 2024
c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein Dec 11, 2024
5c509c7
Pass shape information to timesteps (#11)
zombie-einstein Dec 11, 2024
8acf242
test: extend tests and docs (#12)
zombie-einstein Dec 11, 2024
1792aa6
fix: unpin jax requirement
zombie-einstein Dec 12, 2024
1e66e78
Include agent positions in observation (#13)
zombie-einstein Dec 12, 2024
407ff79
Upgrade Esquilax and remove unused random keys (#14)
zombie-einstein Dec 27, 2024
04fe710
docs: Review docstrings and docs (#15)
zombie-einstein Jan 12, 2025
ac3f811
fix: Remove enum annotations
zombie-einstein Jan 12, 2025
943a51b
refactor: address pr comments (#16)
zombie-einstein Jan 17, 2025
6a3fdb1
Parameter tweaks
zombie-einstein Jan 17, 2025
11 changes: 11 additions & 0 deletions docs/api/environments/search_and_rescue.md
@@ -0,0 +1,11 @@
::: jumanji.environments.swarms.search_and_rescue.env.SearchAndRescue
    selection:
        members:
            - __init__
            - reset
            - step
            - observation_spec
            - action_spec
            - reward_spec
            - render
            - animate
81 changes: 81 additions & 0 deletions docs/environments/search_and_rescue.md
@@ -0,0 +1,81 @@
# 🚁 Search & Rescue

[//]: # (TODO: Add animated plot)

Multi-agent environment, modelling a group of agents searching a 2d environment
for multiple targets. Agents are individually rewarded for finding a target
that has not previously been detected.

Each agent observes a local region around itself, represented as a simple segmented
view of the locations of other agents and targets in its vicinity. The environment
is updated in the following sequence:

- The velocities of searching agents are updated, and consequently their positions.
- The positions of targets are updated.
- Targets within detection range and within an agent's view cone are marked as found.
- Agents are rewarded for locating previously unfound targets.
- Local views of the environment are generated for each searching agent.

The agents are allotted a fixed number of steps to locate the targets. The search
space is a uniform square space, wrapped at the boundaries.
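
As an illustration of the wrapped boundaries described above, here is a minimal plain-Python sketch of a position update on a square space of side `env_size`. The function name `move` and its signature are assumptions made for this example, and the environment itself operates on JAX arrays rather than tuples.

```python
import math


def move(pos, heading, speed, env_size):
    """Advance a 2d position along `heading` and wrap it onto [0, env_size).

    Hypothetical helper for illustration only, not the environment's API.
    """
    x = (pos[0] + speed * math.cos(heading)) % env_size
    y = (pos[1] + speed * math.sin(heading)) % env_size
    return (x, y)


# An agent crossing the right-hand edge re-enters from the left-hand edge.
new_pos = move((0.4, 0.2), heading=0.0, speed=0.2, env_size=0.5)
```

Because Python's `%` returns a non-negative result for a positive divisor, the same expression also handles agents leaving through the left or bottom edge.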

Many aspects of the environment can be customised:

- Agent observations can be customised by implementing the `ObservationFn` interface.
- Rewards can be customised by implementing the `RewardFn` interface.
- Target dynamics can be customised to model various search scenarios by implementing the
`TargetDynamics` interface.

## Observations

- `searcher_views`: jax array (float) of shape `(num_searchers, channels, num_vision)`.
Each agent generates an independent observation, an array of values representing the distance
along a ray from the agent to the nearest neighbour or target, with each cell representing a
ray angle (with `num_vision` rays evenly distributed over the agent's field of vision).
For example, if an agent sees another agent straight ahead and `num_vision = 5`, then
the observation array could be

```
[-1.0, -1.0, 0.5, -1.0, -1.0]
```

where `-1.0` indicates there are no agents along that ray, and `0.5` is the normalised
distance to the other agent. Channels in the segmented view are used to differentiate
between different agents/targets and can be customised. By default, the view has three
channels representing other agents, found targets, and unlocated targets respectively.
- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
remaining to be detected (i.e. 1.0 when no targets have been found).
- `step`: int in the range `[0, time_limit]`. The current simulation step.
- `positions`: jax array (float) of shape `(num_searchers, 2)`. Agent coordinates.
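
To make the "nearest neighbour or target along a ray" convention concrete, the sketch below merges two single-neighbour views into one by keeping the closest hit on each ray, with `-1.0` meaning no hit. The helper name `merge_views` is an assumption for this example, written in plain Python rather than JAX; the input and output values mirror the `test_view_reduction` case in this PR.

```python
def merge_views(view_a, view_b):
    """Combine two ray views, keeping the nearest hit per ray (-1.0 means no hit).

    Hypothetical helper for illustration only, not the environment's API.
    """
    merged = []
    for a, b in zip(view_a, view_b):
        if a < 0.0:
            # Nothing seen in view_a along this ray, so take view_b as-is.
            merged.append(b)
        elif b < 0.0:
            # Nothing seen in view_b along this ray, so take view_a as-is.
            merged.append(a)
        else:
            # Both views have a hit: the nearer one occludes the other.
            merged.append(min(a, b))
    return merged


merged = merge_views([-1.0, -1.0, 0.2, 0.2, 0.5], [-1.0, 0.2, -1.0, 0.5, 0.2])
```

Applying the same pairwise rule across all neighbours yields a single view per channel for each agent.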

## Actions

Jax array (float) of shape `(num_searchers, 2)` in the range `[-1, 1]`. Each entry in the
array represents an update to an agent's velocity in the next step. Searching agents
update their velocity each step by rotating and accelerating/decelerating, where the
values are `[rotation, acceleration]`. Values are clipped to the range `[-1, 1]`
and then scaled by max rotation and acceleration parameters, i.e. the new values each
step are given by

```
heading = heading + max_rotation * action[0]
```

and speed

```
speed = speed + max_acceleration * action[1]
```

Once applied, agent speeds are clipped to the range given by the `min_speed` and
`max_speed` parameters.
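
The full update for a single agent can be sketched in plain Python as follows. This is an illustration under stated assumptions: the maximum rotation is taken to be a fraction of pi (as the `AgentParams` docstring in this PR describes `max_rotate`), headings are assumed to wrap to `[0, 2*pi)`, and the function name and signature are made up for the example.

```python
import math


def update_velocity(heading, speed, action, max_rotate, max_accelerate, min_speed, max_speed):
    """Apply a [rotation, acceleration] action to a single agent.

    Hypothetical helper for illustration only. Assumes `max_rotate` is a
    fraction of pi and that headings wrap to [0, 2*pi).
    """
    # Clip both action components to [-1, 1] before scaling.
    rotation = max(-1.0, min(1.0, action[0]))
    acceleration = max(-1.0, min(1.0, action[1]))
    new_heading = (heading + max_rotate * math.pi * rotation) % (2.0 * math.pi)
    # Clip the resulting speed to the permitted range.
    new_speed = max(min_speed, min(max_speed, speed + max_accelerate * acceleration))
    return new_heading, new_speed


# Rotating past 2*pi wraps around; accelerating at max speed has no effect.
new_heading, new_speed = update_velocity(
    heading=1.75 * math.pi,
    speed=0.05,
    action=[1.0, 1.0],
    max_rotate=0.5,
    max_accelerate=0.01,
    min_speed=0.01,
    max_speed=0.05,
)
```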

## Rewards

Jax array (float) of shape `(num_searchers,)`. Rewards are generated for each agent individually.

Agents are rewarded +1 for locating a target that has not already been detected. Multiple
agents can newly detect the same target within a single step; by default, the reward is then
split between the locating agents. Also by default, rewards decrease linearly over time,
from +1 at the start to 0 at the final step. The reward function can be customised by
implementing the `RewardFn` interface.
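
A plain-Python sketch of the default behaviour described above, splitting each target's reward between the agents that newly found it and scaling by the time remaining. The function name, the boolean-matrix input, and the exact scaling `(time_limit - step) / time_limit` are assumptions chosen to match the description, not the environment's `RewardFn` implementation.

```python
def shared_rewards(newly_found, step, time_limit):
    """Per-searcher rewards for one step.

    `newly_found[i][j]` is True if searcher i newly detected target j this step.
    Hypothetical helper for illustration only; the linear scaling is an assumption.
    """
    n_searchers = len(newly_found)
    n_targets = len(newly_found[0]) if n_searchers else 0
    rewards = [0.0] * n_searchers
    for j in range(n_targets):
        finders = [i for i in range(n_searchers) if newly_found[i][j]]
        for i in finders:
            # Each target is worth +1 in total, shared equally between its finders.
            rewards[i] += 1.0 / len(finders)
    # Scale linearly from 1 at step 0 down to 0 at the final step.
    scale = (time_limit - step) / time_limit
    return [r * scale for r in rewards]


# Searchers 0 and 1 both find target 0; searcher 2 alone finds target 1.
found = [[True, False], [True, False], [False, True]]
rewards = shared_rewards(found, step=50, time_limit=100)
```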
7 changes: 7 additions & 0 deletions jumanji/__init__.py
@@ -142,3 +142,10 @@
# LevelBasedForaging with a random generator with 8 grid size,
# 2 agents and 2 food items and the maximum agent's level is 2.
register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging")

###
# Swarm Environments
###

# Search-and-Rescue environment
register(id="SearchAndRescue-v0", entry_point="jumanji.environments:SearchAndRescue")
sash-a marked this conversation as resolved.
1 change: 1 addition & 0 deletions jumanji/environments/__init__.py
@@ -59,6 +59,7 @@
from jumanji.environments.routing.snake.env import Snake
from jumanji.environments.routing.sokoban.env import Sokoban
from jumanji.environments.routing.tsp.env import TSP
from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue


def is_colab() -> bool:
8 changes: 4 additions & 4 deletions jumanji/environments/routing/snake/types.py
@@ -40,10 +40,10 @@ def __add__(self, other: Position) -> Position: # type: ignore[override]


 class Actions(IntEnum):
-    UP: int = 0
-    RIGHT: int = 1
-    DOWN: int = 2
-    LEFT: int = 3
+    UP = 0
+    RIGHT = 1
+    DOWN = 2
+    LEFT = 3


@dataclass
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/common/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
188 changes: 188 additions & 0 deletions jumanji/environments/swarms/common/common_test.py
@@ -0,0 +1,188 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import pytest

from jumanji.environments.swarms.common import types, updates, viewer


@pytest.fixture
def params() -> types.AgentParams:
    return types.AgentParams(
        max_rotate=0.5,
        max_accelerate=0.01,
        min_speed=0.01,
        max_speed=0.05,
        view_angle=0.5,
    )


@pytest.mark.parametrize(
    "heading, speed, actions, expected",
    [
        [0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)],
        [0.0, 0.01, [0.0, 1.0], (0.0, 0.02)],
        [0.0, 0.01, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.02, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.05, [0.0, -1.0], (0.0, 0.04)],
        [0.0, 0.05, [0.0, 1.0], (0.0, 0.05)],
    ],
)
def test_velocity_update(
    params: types.AgentParams,
    heading: float,
    speed: float,
    actions: List[float],
    expected: Tuple[float, float],
) -> None:
    state = types.AgentState(
        pos=jnp.zeros((1, 2)),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_heading, new_speed = updates.update_velocity(params, (actions, state))

    assert jnp.isclose(new_heading[0], expected[0])
    assert jnp.isclose(new_speed[0], expected[1])


@pytest.mark.parametrize(
    "pos, heading, speed, expected, env_size",
    [
        [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0],
        [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0],
        [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0],
        [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0],
        [[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5],
        [[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5],
        [[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5],
        [[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5],
    ],
)
def test_move(
    pos: List[float], heading: float, speed: float, expected: List[float], env_size: float
) -> None:
    pos = jnp.array(pos)
    new_pos = updates.move(pos, heading, speed, env_size)

    assert jnp.allclose(new_pos, jnp.array(expected))


@pytest.mark.parametrize(
    "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size",
    [
        [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0],
        [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0],
        [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0],
        [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0],
        [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0],
        [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0],
        [[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5],
        [[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5],
    ],
)
def test_state_update(
    params: types.AgentParams,
    pos: List[float],
    heading: float,
    speed: float,
    actions: List[float],
    expected_pos: List[float],
    expected_heading: float,
    expected_speed: float,
    env_size: float,
) -> None:
    state = types.AgentState(
        pos=jnp.array([pos]),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_state = updates.update_state(env_size, params, state, actions)

    assert isinstance(new_state, types.AgentState)
    assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
    assert jnp.allclose(new_state.heading, jnp.array([expected_heading]))
    assert jnp.allclose(new_state.speed, jnp.array([expected_speed]))


def test_view_reduction() -> None:
    view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5])
    view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2])
    result = updates.view_reduction_fn(view_a, view_b)
    assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2]))


@pytest.mark.parametrize(
    "pos, view_angle, env_size, expected",
    [
        [[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
        [[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]],
        [[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
        [[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.01, 0.0], 0.5, 1.0, [-1.0, 0.1, 0.1, 0.1, -1.0]],
        [[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]],
    ],
)
def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None:
    state_a = types.AgentState(
        pos=jnp.zeros((2,)),
        heading=0.0,
        speed=0.0,
    )

    state_b = types.AgentState(
        pos=jnp.array(pos),
        heading=0.0,
        speed=0.0,
    )

    obs = updates.view(
        (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
    )
    assert jnp.allclose(obs, jnp.array(expected))


def test_viewer_utils() -> None:
    f, ax = plt.subplots()
    f, ax = viewer.format_plot(f, ax, (1.0, 1.0))

    assert isinstance(f, matplotlib.figure.Figure)
    assert isinstance(ax, matplotlib.axes.Axes)

    state = types.AgentState(
        pos=jnp.zeros((3, 2)),
        heading=jnp.zeros((3,)),
        speed=jnp.zeros((3,)),
    )

    quiver = viewer.draw_agents(ax, state, "red")

    assert isinstance(quiver, matplotlib.quiver.Quiver)
53 changes: 53 additions & 0 deletions jumanji/environments/swarms/common/types.py
@@ -0,0 +1,53 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from chex import dataclass

import chex


@dataclass(frozen=True)
class AgentParams:
    """
    max_rotate: Max angle an agent can rotate during a step (a fraction of pi)
    max_accelerate: Max change in speed during a step
    min_speed: Minimum agent speed
    max_speed: Maximum agent speed
    view_angle: Agent view angle, as a fraction of pi either side of its heading
    """

    max_rotate: float
    max_accelerate: float
    min_speed: float
    max_speed: float
    view_angle: float


@dataclass
class AgentState:
    """
    State of multiple agents of a single type

    pos: 2d position of the (centre of the) agents
    heading: Heading of the agents (in radians)
    speed: Speed of the agents
    """

    pos: chex.Array  # (num_agents, 2)
    heading: chex.Array  # (num_agents,)
    speed: chex.Array  # (num_agents,)