refactor: address pr comments (#16)
* Tweak env docs

* Env tweaks/refactors

* Separate searcher and target view ranges and consolidate observation-fn

* Consolidate reward functions

* Implement randomly accelerating target dynamics

* Tweak parameters
zombie-einstein authored Jan 17, 2025
1 parent ac3f811 commit 943a51b
Showing 12 changed files with 183 additions and 540 deletions.
17 changes: 9 additions & 8 deletions docs/environments/search_and_rescue.md
@@ -21,10 +21,10 @@ space is a uniform square space, wrapped at the boundaries.

Many aspects of the environment can be customised:

- Agent observations can include targets as well as other searcher agents.
- Rewards can be shared by agents, or can be treated completely individually for individual
agents and can be scaled by time-step.
- Target dynamics can be customised to model various search scenarios.
- Agent observations can be customised by implementing the `ObservationFn` interface.
- Rewards can be customised by implementing the `RewardFn` interface.
- Target dynamics can be customised to model various search scenarios by implementing the
`TargetDynamics` interface.
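
For example, the last point can be illustrated with a minimal sketch of a custom `TargetDynamics` implementation (the `DriftDynamics` class and its parameter are hypothetical, not part of this change):

```python
import chex
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class DriftDynamics(TargetDynamics):
    """Hypothetical dynamics: targets drift at a constant velocity along the x-axis."""

    def __init__(self, drift: float):
        self.drift = drift

    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        # Constant velocity, positions wrapped at the environment boundary.
        vel = jnp.zeros_like(targets.vel).at[:, 0].set(self.drift)
        pos = (targets.pos + vel) % env_size
        return targets.replace(pos=pos, vel=vel)  # type: ignore
```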

## Observations

@@ -42,7 +42,7 @@ Many aspects of the environment can be customised:
where `-1.0` indicates there are no agents along that ray, and `0.5` is the normalised
distance to the other agent. Channels in the segmented view are used to differentiate
between different agents/targets and can be customised. By default, the view has three
channels representing other agents, found targets, and unlocated targets.
channels representing other agents, found targets, and unlocated targets respectively.
- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
remaining to be detected (i.e. 1.0 when no targets have been found).
- `step`: int in the range `[0, time_limit]`. The current simulation step.
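
As a rough illustration of how entries in `searcher_views` can be decoded (assuming values are distances normalised by the agent's vision range, as described above; the numbers here are made up):

```python
import jax.numpy as jnp

vision_range = 0.25  # illustrative value
# One agent's view: 3 channels x 8 rays, -1.0 meaning nothing visible on that ray.
view = jnp.full((3, 8), -1.0)
view = view.at[0, 3].set(0.5)  # another searcher at half the vision range on ray 3

# Convert normalised values back to absolute distances where something was seen.
distances = jnp.where(view >= 0.0, view * vision_range, jnp.inf)
```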
@@ -75,6 +75,7 @@ by the `min_speed` and `max_speed` parameters.
Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually.

Agents are rewarded +1 for locating a target that has not already been detected. It is possible
for multiple agents to detect a target inside a step, as such rewards can either be shared
by the locating agents, or each individual agent can get the full reward. Rewards provided can
also be scaled by simulation step to encourage agents to develop efficient search patterns.
for multiple agents to newly detect the same target inside a step; in this case the reward is
split between the locating agents by default. Also by default, rewards decrease linearly
over time, from +1 to 0 at the final step. The reward function can be customised by
implementing the `RewardFn` interface.
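
A small sketch of the reward logic described above (this is not the library's implementation; the `newly_found` array and parameter values are illustrative):

```python
import jax.numpy as jnp

# (num_searchers, num_targets) flags: which searchers newly detected which targets this step.
newly_found = jnp.array([[True, False], [True, False], [False, False]])
step, time_limit = 10, 400

shared = newly_found / jnp.maximum(newly_found.sum(axis=0), 1)  # split +1 between finders
rewards = shared.sum(axis=1) * (1.0 - step / time_limit)        # linear decay over time
```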
6 changes: 1 addition & 5 deletions jumanji/environments/swarms/common/updates.py
@@ -92,11 +92,7 @@ def update_state(
headings, speeds = update_velocity(params, (actions, state))
positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size)

return types.AgentState(
pos=positions,
speed=speeds,
heading=headings,
)
return state.replace(pos=positions, speed=speeds, heading=headings) # type: ignore


def view_reduction_fn(view_a: chex.Array, view_b: chex.Array) -> chex.Array:
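
The `state.replace(...)` pattern above relies on dataclass-style state containers; a minimal sketch of the idea, assuming a chex-style dataclass and using a simplified stand-in for the real `AgentState`:

```python
import chex
import jax.numpy as jnp


@chex.dataclass
class AgentState:  # simplified stand-in for the real type
    pos: chex.Array
    speed: chex.Array
    heading: chex.Array


state = AgentState(pos=jnp.zeros((2, 2)), speed=jnp.ones((2,)), heading=jnp.zeros((2,)))
# `replace` returns a new instance with only the named fields updated,
# rather than re-constructing the state field by field.
new_state = state.replace(speed=state.speed * 0.5)
```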
51 changes: 8 additions & 43 deletions jumanji/environments/swarms/search_and_rescue/conftest.py
@@ -21,49 +21,14 @@

@pytest.fixture
def env() -> SearchAndRescue:
return SearchAndRescue(
target_contact_range=0.05,
searcher_max_rotate=0.2,
searcher_max_accelerate=0.01,
searcher_min_speed=0.01,
searcher_max_speed=0.05,
searcher_view_angle=0.5,
time_limit=10,
observation_fn = observations.AgentAndTargetObservationFn(
num_vision=32,
searcher_vision_range=0.1,
target_vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
)


class FixtureRequest:
"""Just used for typing"""

param: observations.ObservationFn


@pytest.fixture(
params=[
observations.AgentObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
observations.AgentAndTargetObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
observations.AgentAndAllTargetObservationFn(
num_vision=32,
vision_range=0.1,
view_angle=0.5,
agent_radius=0.01,
env_size=1.0,
),
]
)
def multi_obs_env(request: FixtureRequest) -> SearchAndRescue:
return SearchAndRescue(
target_contact_range=0.05,
searcher_max_rotate=0.2,
@@ -72,7 +37,7 @@ def multi_obs_env(request: FixtureRequest) -> SearchAndRescue:
searcher_max_speed=0.05,
searcher_view_angle=0.5,
time_limit=10,
observation=request.param,
observation=observation_fn,
)


26 changes: 17 additions & 9 deletions jumanji/environments/swarms/search_and_rescue/dynamics.py
@@ -16,6 +16,7 @@

import chex
import jax
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.types import TargetState

@@ -40,17 +41,20 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) ->


class RandomWalk(TargetDynamics):
def __init__(self, step_size: float):
def __init__(self, acc_std: float, vel_max: float):
"""
Simple random walk target dynamics.
Target positions are updated with random steps, sampled uniformly
from the range `[-step-size, step-size]`.
Target velocities are updated with accelerations sampled from
a normal distribution with standard deviation `acc_std`, and the
magnitude of the updated velocity is clipped to `vel_max`.
Args:
step_size: Maximum random step-size in each axis.
acc_std: Standard deviation of the acceleration distribution.
vel_max: Max velocity magnitude.
"""
self.step_size = step_size
self.acc_std = acc_std
self.vel_max = vel_max

def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
"""Update target state.
@@ -63,7 +67,11 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) ->
Returns:
Updated target states.
"""
d_pos = jax.random.uniform(key, targets.pos.shape)
d_pos = self.step_size * 2.0 * (d_pos - 0.5)
pos = (targets.pos + d_pos) % env_size
return TargetState(pos=pos, vel=targets.vel, found=targets.found)
acc = self.acc_std * jax.random.normal(key, targets.pos.shape)
vel = targets.vel + acc
norm = jnp.sqrt(jnp.sum(vel * vel, axis=1))
vel = jnp.where(
norm[:, jnp.newaxis] > self.vel_max, self.vel_max * vel / norm[:, jnp.newaxis], vel
)
pos = (targets.pos + vel) % env_size
return targets.replace(pos=pos, vel=vel) # type: ignore
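
A usage sketch of the new dynamics with the environment (parameter values are illustrative):

```python
import jax

from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk
from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue

# Targets accelerate randomly each step, with speed capped at vel_max.
env = SearchAndRescue(target_dynamics=RandomWalk(acc_std=0.005, vel_max=0.01))
state, timestep = jax.jit(env.reset)(jax.random.PRNGKey(0))
```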
43 changes: 43 additions & 0 deletions jumanji/environments/swarms/search_and_rescue/dynamics_test.py
@@ -0,0 +1,43 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


def test_random_walk_dynamics(key: chex.PRNGKey) -> None:
n_targets = 50
vel_max = 0.01
pos_0 = jnp.full((n_targets, 2), 0.5)

s0 = TargetState(
pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool)
)

dynamics = RandomWalk(acc_std=0.005, vel_max=vel_max)
assert isinstance(dynamics, TargetDynamics)
s1 = dynamics(key, s0, 1.0)

assert isinstance(s1, TargetState)
assert s1.pos.shape == (n_targets, 2)
assert s1.vel.shape == (n_targets, 2)
assert jnp.array_equal(s0.found, s1.found)
speed = jnp.sqrt(jnp.sum(s1.vel**2, axis=1))
assert speed.shape == (n_targets,)
assert jnp.all(speed <= vel_max + 1e-6)
d_pos = jnp.sqrt(jnp.sum((s0.pos - s1.pos) ** 2, axis=1))
assert jnp.all(d_pos <= vel_max + 1e-6)
60 changes: 33 additions & 27 deletions jumanji/environments/swarms/search_and_rescue/env.py
@@ -30,10 +30,10 @@
from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics
from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator
from jumanji.environments.swarms.search_and_rescue.observations import (
AgentAndAllTargetObservationFn,
AgentAndTargetObservationFn,
ObservationFn,
)
from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn
from jumanji.environments.swarms.search_and_rescue.reward import IndividualRewardFn, RewardFn
from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState
from jumanji.environments.swarms.search_and_rescue.viewer import SearchAndRescueViewer
from jumanji.types import TimeStep, restart, termination, transition
@@ -57,12 +57,10 @@ class SearchAndRescue(Environment):
- observation: `Observation`
searcher_views: jax array (float) of shape (num_searchers, channels, num_vision)
Individual local views of positions of other agents and targets, where
channels can be used to differentiate between agents and targets.
Each entry in the view indicates the distant to another agent/target
channels can be used to differentiate between agent and target types.
Each entry in the view indicates the distance to another agent/target
along a ray from the agent, and is -1.0 if nothing is in range along the ray.
The view model can be customised using an `ObservationFn` implementation, e.g.
the view can include agents and all targets, agents and found targets,or
just other agents.
The view model can be customised by implementing the `ObservationFn` interface.
targets_remaining: (float) Number of targets remaining to be found from
the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates
all the targets are still to be found).
@@ -78,9 +76,13 @@ class SearchAndRescue(Environment):
- reward: jax array (float) of shape (num_searchers,)
Arrays of individual agent rewards. A reward of +1 is granted when an agent
comes into contact range with a target that has not yet been found, and
the target is within the searchers view cone. Rewards can be shared
between agents if a target is simultaneously detected by multiple agents,
or each can be provided the full reward individually.
the target is within the searcher's view cone. It is possible for multiple
agents to newly find the same target within a given step; in this case the
reward is split between the locating agents by default. Also by default,
rewards decrease linearly over time, with zero reward granted at the
environment time-limit. These defaults can be modified by flags in
`IndividualRewardFn`, or further customised by implementing the `RewardFn`
interface.
- state: `State`
- searchers: `AgentState`
@@ -146,11 +148,13 @@ def __init__(
target_dynamics: Target object dynamics model, implemented as a
`TargetDynamics` interface. Defaults to `RandomWalk`.
generator: Initial state `Generator` instance. Defaults to `RandomGenerator`
with 20 targets and 10 searchers.
reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where
agents share rewards if they locate a target simultaneously.
with 50 targets and 2 searchers, with targets uniformly distributed
across the environment.
reward_fn: Reward aggregation function. Defaults to `IndividualRewardFn` where
agents split rewards if they locate a target simultaneously, and
rewards decrease linearly to zero over time.
observation: Agent observation view generation function. Defaults to
`AgentAndAllTargetObservationFn` where all targets (found and unfound)
`AgentAndTargetObservationFn` where all targets (found and unfound)
and other searching agents are included in the generated view.
"""

@@ -164,13 +168,14 @@ def __init__(
view_angle=searcher_view_angle,
)
self.time_limit = time_limit
self._target_dynamics = target_dynamics or RandomWalk(0.001)
self.generator = generator or RandomGenerator(num_targets=50, num_searchers=2)
self._target_dynamics = target_dynamics or RandomWalk(acc_std=0.0001, vel_max=0.002)
self.generator = generator or RandomGenerator(num_targets=20, num_searchers=2)
self._viewer = viewer or SearchAndRescueViewer()
self._reward_fn = reward_fn or SharedRewardFn()
self._observation = observation or AgentAndAllTargetObservationFn(
self._reward_fn = reward_fn or IndividualRewardFn()
self._observation_fn = observation or AgentAndTargetObservationFn(
num_vision=64,
vision_range=0.25,
searcher_vision_range=0.25,
target_vision_range=0.1,
view_angle=searcher_view_angle,
agent_radius=0.02,
env_size=self.generator.env_size,
@@ -187,17 +192,18 @@ def __repr__(self) -> str:
f" - max searcher acceleration: {self.searcher_params.max_accelerate}",
f" - searcher min speed: {self.searcher_params.min_speed}",
f" - searcher max speed: {self.searcher_params.max_speed}",
f" - search vision range: {self._observation.vision_range}",
f" - search view angle: {self._observation.view_angle}",
f" - search vision range: {self._observation_fn.searcher_vision_range}",
f" - target vision range: {self._observation_fn.target_vision_range}",
f" - search view angle: {self._observation_fn.view_angle}",
f" - target contact range: {self.target_contact_range}",
f" - num vision: {self._observation.num_vision}",
f" - agent radius: {self._observation.agent_radius}",
f" - num vision: {self._observation_fn.num_vision}",
f" - agent radius: {self._observation_fn.agent_radius}",
f" - time limit: {self.time_limit},"
f" - env size: {self.generator.env_size}"
f" - target dynamics: {self._target_dynamics.__class__.__name__}",
f" - generator: {self.generator.__class__.__name__}",
f" - reward fn: {self._reward_fn.__class__.__name__}",
f" - observation fn: {self._observation.__class__.__name__}",
f" - observation fn: {self._observation_fn.__class__.__name__}",
]
)

@@ -276,7 +282,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
return state, timestep

def _state_to_observation(self, state: State) -> Observation:
searcher_views = self._observation(state)
searcher_views = self._observation_fn(state)
return Observation(
searcher_views=searcher_views,
targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets,
@@ -301,8 +307,8 @@ def observation_spec(self) -> specs.Spec[Observation]:
searcher_views = specs.BoundedArray(
shape=(
self.num_agents,
self._observation.num_channels,
self._observation.num_vision,
self._observation_fn.num_channels,
self._observation_fn.num_vision,
),
minimum=-1.0,
maximum=1.0,
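
As a usage sketch, the renamed observation function can be supplied with separate searcher and target vision ranges, mirroring the test fixture earlier in this commit (parameter values are illustrative):

```python
from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue
from jumanji.environments.swarms.search_and_rescue.observations import (
    AgentAndTargetObservationFn,
)

env = SearchAndRescue(
    observation=AgentAndTargetObservationFn(
        num_vision=64,
        searcher_vision_range=0.25,
        target_vision_range=0.1,
        view_angle=0.5,
        agent_radius=0.02,
        env_size=1.0,
    )
)
```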
20 changes: 10 additions & 10 deletions jumanji/environments/swarms/search_and_rescue/env_test.py
@@ -58,8 +58,8 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None:
assert isinstance(timestep.observation, Observation)
assert timestep.observation.searcher_views.shape == (
env.generator.num_searchers,
env._observation.num_channels,
env._observation.num_vision,
env._observation_fn.num_channels,
env._observation_fn.num_vision,
)
assert timestep.step_type == StepType.FIRST
assert timestep.reward.shape == (env.generator.num_searchers,)
@@ -113,26 +113,26 @@ def step(
)


def test_env_does_not_smoke(multi_obs_env: SearchAndRescue) -> None:
def test_env_does_not_smoke(env: SearchAndRescue) -> None:
"""Test that we can run an episode without any errors."""
multi_obs_env.time_limit = 10
env.time_limit = 10

def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array:
return jax.random.uniform(
action_key, (multi_obs_env.generator.num_searchers, 2), minval=-1.0, maxval=1.0
action_key, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0
)

check_env_does_not_smoke(multi_obs_env, select_action=select_action)
check_env_does_not_smoke(env, select_action=select_action)


def test_env_specs_do_not_smoke(multi_obs_env: SearchAndRescue) -> None:
def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None:
"""Test that we can access specs without any errors."""
check_env_specs_does_not_smoke(multi_obs_env)
check_env_specs_does_not_smoke(env)


def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None:
# Keep targets in one location
env._target_dynamics = RandomWalk(step_size=0.0)
env._target_dynamics = RandomWalk(acc_std=0.1, vel_max=0.0)
env.generator.num_targets = 1

# Agent facing wrong direction should not see target
@@ -183,7 +183,7 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None:

def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None:
# Keep targets in one location
env._target_dynamics = RandomWalk(step_size=0.0)
env._target_dynamics = RandomWalk(acc_std=0.1, vel_max=0.0)
env.searcher_params = AgentParams(
max_rotate=0.1,
max_accelerate=0.01,