diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md
index a707a8553..6fd325191 100644
--- a/docs/environments/search_and_rescue.md
+++ b/docs/environments/search_and_rescue.md
@@ -21,10 +21,10 @@ space is a uniform square space, wrapped at the boundaries.
 
 Many aspects of the environment can be customised:
 
-- Agent observations can include targets as well as other searcher agents.
-- Rewards can be shared by agents, or can be treated completely individually for individual
-  agents and can be scaled by time-step.
-- Target dynamics can be customised to model various search scenarios.
+- Agent observations can be customised by implementing the `ObservationFn` interface.
+- Rewards can be customised by implementing the `RewardFn` interface.
+- Target dynamics can be customised to model various search scenarios by implementing the
+  `TargetDynamics` interface.
 
 ## Observations
 
@@ -42,7 +42,7 @@ Many aspects of the environment can be customised:
   where `-1.0` indicates there are no agents along that ray, and `0.5` is the normalised
   distance to the other agent. Channels in the segmented view are used to differentiate
   between different agents/targets and can be customised. By default, the view has three
-  channels representing other agents, found targets, and unlocated targets.
+  channels representing other agents, found targets, and unlocated targets, respectively.
 - `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
   remaining to be detected (i.e. 1.0 when no targets have been found).
 - `step`: int in the range `[0, time_limit]`. The current simulation step.
@@ -75,6 +75,7 @@ by the `min_speed` and `max_speed` parameters.
 Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually.
 
 Agents are rewarded +1 for locating a target that has not already been detected. It is possible
-for multiple agents to detect a target inside a step, as such rewards can either be shared
-by the locating agents, or each individual agent can get the full reward. Rewards provided can
-also be scaled by simulation step to encourage agents to develop efficient search patterns.
+for multiple agents to detect the same previously unfound target within a single step; in this
+case the reward is split between the locating agents by default. Rewards also decrease linearly
+over time by default, from +1 at the first step to 0 at the final step. The reward function can
+be customised by implementing the `RewardFn` interface.
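As a concrete illustration of the customisation hooks described in the documentation above, a custom target dynamics model only needs to implement the `TargetDynamics.__call__` signature used by `RandomWalk` in this changeset. The sketch below is hypothetical (the `DriftingTargets` class and its `drift`/`noise_std` parameters are not part of this PR); it assumes the `TargetState` fields and the `targets.replace` usage shown in `dynamics.py`:

```python
import chex
import jax
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class DriftingTargets(TargetDynamics):
    """Hypothetical dynamics: targets drift along the x-axis with small positional noise."""

    def __init__(self, drift: float, noise_std: float):
        self.drift = drift
        self.noise_std = noise_std

    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        # Constant drift plus Gaussian positional noise, wrapped at the environment boundary.
        noise = self.noise_std * jax.random.normal(key, targets.pos.shape)
        vel = jnp.zeros_like(targets.vel).at[:, 0].set(self.drift)
        pos = (targets.pos + vel + noise) % env_size
        return targets.replace(pos=pos, vel=vel)  # type: ignore
```

An instance of such a model could then be passed to `SearchAndRescue` via its `target_dynamics` constructor argument, in the same way the default `RandomWalk` is supplied in `env.py`.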
diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index 25f99e1e5..e06beb0d5 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -92,11 +92,7 @@ def update_state( headings, speeds = update_velocity(params, (actions, state)) positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size) - return types.AgentState( - pos=positions, - speed=speeds, - heading=headings, - ) + return state.replace(pos=positions, speed=speeds, heading=headings) # type: ignore def view_reduction_fn(view_a: chex.Array, view_b: chex.Array) -> chex.Array: diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py index 8158bacac..2829c3974 100644 --- a/jumanji/environments/swarms/search_and_rescue/conftest.py +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -21,49 +21,14 @@ @pytest.fixture def env() -> SearchAndRescue: - return SearchAndRescue( - target_contact_range=0.05, - searcher_max_rotate=0.2, - searcher_max_accelerate=0.01, - searcher_min_speed=0.01, - searcher_max_speed=0.05, - searcher_view_angle=0.5, - time_limit=10, + observation_fn = observations.AgentAndTargetObservationFn( + num_vision=32, + searcher_vision_range=0.1, + target_vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, ) - - -class FixtureRequest: - """Just used for typing""" - - param: observations.ObservationFn - - -@pytest.fixture( - params=[ - observations.AgentObservationFn( - num_vision=32, - vision_range=0.1, - view_angle=0.5, - agent_radius=0.01, - env_size=1.0, - ), - observations.AgentAndTargetObservationFn( - num_vision=32, - vision_range=0.1, - view_angle=0.5, - agent_radius=0.01, - env_size=1.0, - ), - observations.AgentAndAllTargetObservationFn( - num_vision=32, - vision_range=0.1, - view_angle=0.5, - agent_radius=0.01, - env_size=1.0, - ), - ] -) -def multi_obs_env(request: FixtureRequest) -> SearchAndRescue: return SearchAndRescue( target_contact_range=0.05, searcher_max_rotate=0.2, @@ -72,7 +37,7 @@ def multi_obs_env(request: FixtureRequest) -> SearchAndRescue: searcher_max_speed=0.05, searcher_view_angle=0.5, time_limit=10, - observation=request.param, + observation=observation_fn, ) diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py index 1a6897769..8d00e6bbe 100644 --- a/jumanji/environments/swarms/search_and_rescue/dynamics.py +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -16,6 +16,7 @@ import chex import jax +import jax.numpy as jnp from jumanji.environments.swarms.search_and_rescue.types import TargetState @@ -40,17 +41,20 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> class RandomWalk(TargetDynamics): - def __init__(self, step_size: float): + def __init__(self, acc_std: float, vel_max: float): """ Simple random walk target dynamics. - Target positions are updated with random steps, sampled uniformly - from the range `[-step-size, step-size]`. + Target velocities are updated with values sampled from + a normal distribution with width `acc_std`, and the + magnitude of the updated velocity clipped to `vel_max`. Args: - step_size: Maximum random step-size in each axis. + acc_std: Standard deviation of acceleration distribution + vel_max: Max velocity magnitude. 
""" - self.step_size = step_size + self.acc_std = acc_std + self.vel_max = vel_max def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: """Update target state. @@ -63,7 +67,11 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> Returns: Updated target states. """ - d_pos = jax.random.uniform(key, targets.pos.shape) - d_pos = self.step_size * 2.0 * (d_pos - 0.5) - pos = (targets.pos + d_pos) % env_size - return TargetState(pos=pos, vel=targets.vel, found=targets.found) + acc = self.acc_std * jax.random.normal(key, targets.pos.shape) + vel = targets.vel + acc + norm = jnp.sqrt(jnp.sum(vel * vel, axis=1)) + vel = jnp.where( + norm[:, jnp.newaxis] > self.vel_max, self.vel_max * vel / norm[:, jnp.newaxis], vel + ) + pos = (targets.pos + vel) % env_size + return targets.replace(pos=pos, vel=vel) # type: ignore diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics_test.py b/jumanji/environments/swarms/search_and_rescue/dynamics_test.py new file mode 100644 index 000000000..5d064c387 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/dynamics_test.py @@ -0,0 +1,43 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import chex +import jax.numpy as jnp + +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics +from jumanji.environments.swarms.search_and_rescue.types import TargetState + + +def test_random_walk_dynamics(key: chex.PRNGKey) -> None: + n_targets = 50 + vel_max = 0.01 + pos_0 = jnp.full((n_targets, 2), 0.5) + + s0 = TargetState( + pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool) + ) + + dynamics = RandomWalk(acc_std=0.005, vel_max=vel_max) + assert isinstance(dynamics, TargetDynamics) + s1 = dynamics(key, s0, 1.0) + + assert isinstance(s1, TargetState) + assert s1.pos.shape == (n_targets, 2) + assert s1.vel.shape == (n_targets, 2) + assert jnp.array_equal(s0.found, s1.found) + speed = jnp.sqrt(jnp.sum(s1.vel**2, axis=1)) + assert speed.shape == (n_targets,) + assert jnp.all(speed <= vel_max + 1e-6) + d_pos = jnp.sqrt(jnp.sum((s0.pos - s1.pos) ** 2, axis=1)) + assert jnp.all(d_pos <= vel_max + 1e-6) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index ad7c7789a..2cebb76af 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -30,10 +30,10 @@ from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator from jumanji.environments.swarms.search_and_rescue.observations import ( - AgentAndAllTargetObservationFn, + AgentAndTargetObservationFn, ObservationFn, ) -from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn +from jumanji.environments.swarms.search_and_rescue.reward import IndividualRewardFn, RewardFn from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState from jumanji.environments.swarms.search_and_rescue.viewer import SearchAndRescueViewer from jumanji.types import TimeStep, restart, termination, transition @@ -57,12 +57,10 @@ class SearchAndRescue(Environment): - observation: `Observation` searcher_views: jax array (float) of shape (num_searchers, channels, num_vision) Individual local views of positions of other agents and targets, where - channels can be used to differentiate between agents and targets. - Each entry in the view indicates the distant to another agent/target + channels can be used to differentiate between agents and targets types. + Each entry in the view indicates the distance to another agent/target along a ray from the agent, and is -1.0 if nothing is in range along the ray. - The view model can be customised using an `ObservationFn` implementation, e.g. - the view can include agents and all targets, agents and found targets,or - just other agents. + The view model can be customised by implementing the `ObservationFn` interface. targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). @@ -78,9 +76,13 @@ class SearchAndRescue(Environment): - reward: jax array (float) of shape (num_searchers,) Arrays of individual agent rewards. A reward of +1 is granted when an agent comes into contact range with a target that has not yet been found, and - the target is within the searchers view cone. 
Rewards can be shared - between agents if a target is simultaneously detected by multiple agents, - or each can be provided the full reward individually. + the target is within the searchers view cone. It is possible for multiple + agents to newly find the same target within a given step, by default + in this case the reward is split between the locating agents. By default, + rewards granted linearly decrease over time, with zero reward granted + at the environment time-limit. These defaults can be modified by flags + in `IndividualRewardFn`, or further customised by implementing the `RewardFn` + interface. - state: `State` - searchers: `AgentState` @@ -146,11 +148,13 @@ def __init__( target_dynamics: Target object dynamics model, implemented as a `TargetDynamics` interface. Defaults to `RandomWalk`. generator: Initial state `Generator` instance. Defaults to `RandomGenerator` - with 20 targets and 10 searchers. - reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where - agents share rewards if they locate a target simultaneously. + with 50 targets and 2 searchers, with targets uniformly distributed + across the environment. + reward_fn: Reward aggregation function. Defaults to `IndividualRewardFn` where + agents split rewards if they locate a target simultaneously, and + rewards linearly decrease to zero over time. observation: Agent observation view generation function. Defaults to - `AgentAndAllTargetObservationFn` where all targets (found and unfound) + `AgentAndTargetObservationFn` where all targets (found and unfound) and other searching agents are included in the generated view. """ @@ -164,13 +168,14 @@ def __init__( view_angle=searcher_view_angle, ) self.time_limit = time_limit - self._target_dynamics = target_dynamics or RandomWalk(0.001) - self.generator = generator or RandomGenerator(num_targets=50, num_searchers=2) + self._target_dynamics = target_dynamics or RandomWalk(acc_std=0.0001, vel_max=0.002) + self.generator = generator or RandomGenerator(num_targets=20, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() - self._reward_fn = reward_fn or SharedRewardFn() - self._observation = observation or AgentAndAllTargetObservationFn( + self._reward_fn = reward_fn or IndividualRewardFn() + self._observation_fn = observation or AgentAndTargetObservationFn( num_vision=64, - vision_range=0.25, + searcher_vision_range=0.25, + target_vision_range=0.1, view_angle=searcher_view_angle, agent_radius=0.02, env_size=self.generator.env_size, @@ -187,17 +192,18 @@ def __repr__(self) -> str: f" - max searcher acceleration: {self.searcher_params.max_accelerate}", f" - searcher min speed: {self.searcher_params.min_speed}", f" - searcher max speed: {self.searcher_params.max_speed}", - f" - search vision range: {self._observation.vision_range}", - f" - search view angle: {self._observation.view_angle}", + f" - search vision range: {self._observation_fn.searcher_vision_range}", + f" - target vision range: {self._observation_fn.target_vision_range}", + f" - search view angle: {self._observation_fn.view_angle}", f" - target contact range: {self.target_contact_range}", - f" - num vision: {self._observation.num_vision}", - f" - agent radius: {self._observation.agent_radius}", + f" - num vision: {self._observation_fn.num_vision}", + f" - agent radius: {self._observation_fn.agent_radius}", f" - time limit: {self.time_limit}," f" - env size: {self.generator.env_size}" f" - target dynamics: {self._target_dynamics.__class__.__name__}", f" - generator: 
{self.generator.__class__.__name__}", f" - reward fn: {self._reward_fn.__class__.__name__}", - f" - observation fn: {self._observation.__class__.__name__}", + f" - observation fn: {self._observation_fn.__class__.__name__}", ] ) @@ -276,7 +282,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser return state, timestep def _state_to_observation(self, state: State) -> Observation: - searcher_views = self._observation(state) + searcher_views = self._observation_fn(state) return Observation( searcher_views=searcher_views, targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, @@ -301,8 +307,8 @@ def observation_spec(self) -> specs.Spec[Observation]: searcher_views = specs.BoundedArray( shape=( self.num_agents, - self._observation.num_channels, - self._observation.num_vision, + self._observation_fn.num_channels, + self._observation_fn.num_vision, ), minimum=-1.0, maximum=1.0, diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 743546194..734423c8b 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -58,8 +58,8 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert isinstance(timestep.observation, Observation) assert timestep.observation.searcher_views.shape == ( env.generator.num_searchers, - env._observation.num_channels, - env._observation.num_vision, + env._observation_fn.num_channels, + env._observation_fn.num_vision, ) assert timestep.step_type == StepType.FIRST assert timestep.reward.shape == (env.generator.num_searchers,) @@ -113,26 +113,26 @@ def step( ) -def test_env_does_not_smoke(multi_obs_env: SearchAndRescue) -> None: +def test_env_does_not_smoke(env: SearchAndRescue) -> None: """Test that we can run an episode without any errors.""" - multi_obs_env.time_limit = 10 + env.time_limit = 10 def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: return jax.random.uniform( - action_key, (multi_obs_env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 + action_key, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 ) - check_env_does_not_smoke(multi_obs_env, select_action=select_action) + check_env_does_not_smoke(env, select_action=select_action) -def test_env_specs_do_not_smoke(multi_obs_env: SearchAndRescue) -> None: +def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: """Test that we can access specs without any errors.""" - check_env_specs_does_not_smoke(multi_obs_env) + check_env_specs_does_not_smoke(env) def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: # Keep targets in one location - env._target_dynamics = RandomWalk(step_size=0.0) + env._target_dynamics = RandomWalk(acc_std=0.1, vel_max=0.0) env.generator.num_targets = 1 # Agent facing wrong direction should not see target @@ -183,7 +183,7 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: # Keep targets in one location - env._target_dynamics = RandomWalk(step_size=0.0) + env._target_dynamics = RandomWalk(acc_std=0.1, vel_max=0.0) env.searcher_params = AgentParams( max_rotate=0.1, max_accelerate=0.01, diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index 6a7e587c0..98b1d2f0f 100644 --- 
a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -36,7 +36,8 @@ def __init__( self, num_channels: int, num_vision: int, - vision_range: float, + searcher_vision_range: float, + target_vision_range: float, view_angle: float, agent_radius: float, env_size: float, @@ -52,14 +53,17 @@ def __init__( Args: num_channels: Number of channels in agent view. num_vision: Size of vision array. - vision_range: Vision range. + searcher_vision_range: Range at which other searching agents + become visible. + target_vision_range: Range at which targets become visible. view_angle: Agent view angle (as a fraction of π). agent_radius: Agent/target visual radius. env_size: Environment size. """ self.num_channels = num_channels self.num_vision = num_vision - self.vision_range = vision_range + self.searcher_vision_range = searcher_vision_range + self.target_vision_range = target_vision_range self.view_angle = view_angle self.agent_radius = agent_radius self.env_size = env_size @@ -77,191 +81,7 @@ def __call__(self, state: State) -> chex.Array: """ -class AgentObservationFn(ObservationFn): - def __init__( - self, - num_vision: int, - vision_range: float, - view_angle: float, - agent_radius: float, - env_size: float, - ) -> None: - """ - Observation that only visualises other search agents in proximity. - - Args: - num_vision: Size of vision array. - vision_range: Vision range. - view_angle: Agent view angle (as a fraction of π). - agent_radius: Agent/target visual radius. - env_size: Environment size. - """ - super().__init__( - 1, - num_vision, - vision_range, - view_angle, - agent_radius, - env_size, - ) - - def __call__(self, state: State) -> chex.Array: - """ - Generate agent view/observation from state - - Args: - state: Current simulation state - - Returns: - Array of individual agent views of shape - (n-agents, 1, n-vision). - """ - searcher_views = esquilax.transforms.spatial( - view, - reduction=view_reduction((self.num_vision,)), - include_self=False, - i_range=self.vision_range, - dims=self.env_size, - )( - (self.view_angle, self.agent_radius), - state.searchers, - state.searchers, - pos=state.searchers.pos, - n_view=self.num_vision, - i_range=self.vision_range, - env_size=self.env_size, - ) - return searcher_views[:, jnp.newaxis] - - -def found_target_view( - params: Tuple[float, float], - searcher: AgentState, - target: TargetState, - *, - n_view: int, - i_range: float, - env_size: float, -) -> chex.Array: - """ - Return view of a target, dependent on target status. - - This function is intended to be mapped over agents-target - pairs by Esquilax. - - Args: - params: View angle and target visual radius. - searcher: Searcher agent state - target: Target state - n_view: Number of values in view array. - i_range: Vision range - env_size: Environment size - - Returns: - Segmented agent view of target. 
- """ - view_angle, agent_radius = params - rays = jnp.linspace( - -view_angle * jnp.pi, - view_angle * jnp.pi, - n_view, - endpoint=True, - ) - d, left, right = angular_width( - searcher.pos, - target.pos, - searcher.heading, - i_range, - agent_radius, - env_size, - ) - checks = jnp.logical_and(target.found, jnp.logical_and(left < rays, rays < right)) - obs = jnp.where(checks, d, -1.0) - return obs - - -class AgentAndTargetObservationFn(ObservationFn): - def __init__( - self, - num_vision: int, - vision_range: float, - view_angle: float, - agent_radius: float, - env_size: float, - ) -> None: - """ - Vision model that contains other agents and found targets. - - Searchers and targets are visualised as individual channels. - Targets are only included if they have been located already. - - Args: - num_vision: Size of vision array. - vision_range: Vision range. - view_angle: Agent view angle (as a fraction of π). - agent_radius: Agent/target visual radius. - env_size: Environment size. - """ - self.vision_range = vision_range - self.view_angle = view_angle - self.agent_radius = agent_radius - self.env_size = env_size - super().__init__( - 2, - num_vision, - vision_range, - view_angle, - agent_radius, - env_size, - ) - - def __call__(self, state: State) -> chex.Array: - """ - Generate agent view/observation from state - - Args: - state: Current simulation state - - Returns: - Array of individual agent views of shape - (n-agents, 2, n-vision). Other agents are shown - in channel 0, and located targets in channel 1. - """ - searcher_views = esquilax.transforms.spatial( - view, - reduction=view_reduction((self.num_vision,)), - include_self=False, - i_range=self.vision_range, - dims=self.env_size, - )( - (self.view_angle, self.agent_radius), - state.searchers, - state.searchers, - pos=state.searchers.pos, - n_view=self.num_vision, - i_range=self.vision_range, - env_size=self.env_size, - ) - target_views = esquilax.transforms.spatial( - found_target_view, - reduction=view_reduction((self.num_vision,)), - include_self=False, - i_range=self.vision_range, - dims=self.env_size, - )( - (self.view_angle, self.agent_radius), - state.searchers, - state.targets, - pos=state.searchers.pos, - pos_b=state.targets.pos, - n_view=self.num_vision, - i_range=self.vision_range, - env_size=self.env_size, - ) - return jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]]) - - -def all_target_view( +def _target_view( params: Tuple[float, float], searcher: AgentState, target: TargetState, @@ -310,11 +130,12 @@ def all_target_view( return obs -class AgentAndAllTargetObservationFn(ObservationFn): +class AgentAndTargetObservationFn(ObservationFn): def __init__( self, num_vision: int, - vision_range: float, + searcher_vision_range: float, + target_vision_range: float, view_angle: float, agent_radius: float, env_size: float, @@ -327,19 +148,18 @@ def __init__( Args: num_vision: Size of vision array. - vision_range: Vision range. + searcher_vision_range: Range at which other searching agents + become visible. + target_vision_range: Range at which targets become visible. view_angle: Agent view angle (as a fraction of π). agent_radius: Agent/target visual radius. env_size: Environment size. 
""" - self.vision_range = vision_range - self.view_angle = view_angle - self.agent_radius = agent_radius - self.env_size = env_size super().__init__( 3, num_vision, - vision_range, + searcher_vision_range, + target_vision_range, view_angle, agent_radius, env_size, @@ -362,7 +182,7 @@ def __call__(self, state: State) -> chex.Array: view, reduction=view_reduction((self.num_vision,)), include_self=False, - i_range=self.vision_range, + i_range=self.searcher_vision_range, dims=self.env_size, )( (self.view_angle, self.agent_radius), @@ -370,14 +190,14 @@ def __call__(self, state: State) -> chex.Array: state.searchers, pos=state.searchers.pos, n_view=self.num_vision, - i_range=self.vision_range, + i_range=self.searcher_vision_range, env_size=self.env_size, ) target_views = esquilax.transforms.spatial( - all_target_view, + _target_view, reduction=view_reduction((2, self.num_vision)), include_self=False, - i_range=self.vision_range, + i_range=self.target_vision_range, dims=self.env_size, )( (self.view_angle, self.agent_radius), @@ -386,7 +206,7 @@ def __call__(self, state: State) -> chex.Array: pos=state.searchers.pos, pos_b=state.targets.pos, n_view=self.num_vision, - i_range=self.vision_range, + i_range=self.target_vision_range, env_size=self.env_size, ) return jnp.hstack([searcher_views[:, jnp.newaxis], target_views]) diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py index a95dcd271..9e91488da 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations_test.py +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -26,87 +26,6 @@ VIEW_ANGLE = 0.5 -@pytest.mark.parametrize( - "searcher_positions, searcher_headings, env_size, view_updates", - [ - # Both out of view range - ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), - # Both view each other - ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), - # One facing wrong direction - ( - [[0.25, 0.5], [0.2, 0.5]], - [jnp.pi, jnp.pi], - 1.0, - [(0, 5, 0.25)], - ), - # Only see closest neighbour - ( - [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], - [jnp.pi, 0.0, 0.0], - 1.0, - [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], - ), - # Observed around wrapped edge - ( - [[0.025, 0.5], [0.975, 0.5]], - [jnp.pi, 0.0], - 1.0, - [(0, 5, 0.25), (1, 5, 0.25)], - ), - # Observed around wrapped edge of smaller env - ( - [[0.025, 0.25], [0.475, 0.25]], - [jnp.pi, 0.0], - 0.5, - [(0, 5, 0.25), (1, 5, 0.25)], - ), - ], -) -def test_searcher_view( - key: chex.PRNGKey, - searcher_positions: List[List[float]], - searcher_headings: List[float], - env_size: float, - view_updates: List[Tuple[int, int, float]], -) -> None: - """ - Test agent-only view model generates expected array with different - configurations of agents. 
- """ - - searcher_positions = jnp.array(searcher_positions) - searcher_headings = jnp.array(searcher_headings) - searcher_speed = jnp.zeros(searcher_headings.shape) - - state = State( - searchers=AgentState( - pos=searcher_positions, heading=searcher_headings, speed=searcher_speed - ), - targets=TargetState( - pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool) - ), - key=key, - ) - - observe_fn = observations.AgentObservationFn( - num_vision=11, - vision_range=VISION_RANGE, - view_angle=VIEW_ANGLE, - agent_radius=0.01, - env_size=env_size, - ) - - obs = observe_fn(state) - - expected = jnp.full((searcher_headings.shape[0], 1, observe_fn.num_vision), -1.0) - - for i, idx, val in view_updates: - expected = expected.at[i, 0, idx].set(val) - - assert jnp.all(jnp.isclose(obs, expected)) - - @pytest.mark.parametrize( "searcher_positions, searcher_headings, env_size, view_updates", [ @@ -173,16 +92,17 @@ def test_search_and_target_view_searchers( observe_fn = observations.AgentAndTargetObservationFn( num_vision=11, - vision_range=VISION_RANGE, + searcher_vision_range=VISION_RANGE, + target_vision_range=VISION_RANGE, view_angle=VIEW_ANGLE, agent_radius=0.01, env_size=env_size, ) obs = observe_fn(state) - assert obs.shape == (n_agents, 2, observe_fn.num_vision) + assert obs.shape == (n_agents, 3, observe_fn.num_vision) - expected = jnp.full((n_agents, 2, observe_fn.num_vision), -1.0) + expected = jnp.full((n_agents, 3, observe_fn.num_vision), -1.0) for i, idx, val in view_updates: expected = expected.at[i, 0, idx].set(val) @@ -190,84 +110,6 @@ def test_search_and_target_view_searchers( assert jnp.all(jnp.isclose(obs, expected)) -@pytest.mark.parametrize( - "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", - [ - # Target out of view range - ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []), - # Target in view and found - ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(5, 0.25)]), - # Target in view but not found - ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, []), - # Observed around wrapped edge - ( - [0.025, 0.5], - jnp.pi, - [0.975, 0.5], - True, - 1.0, - [(5, 0.25)], - ), - # Observed around wrapped edge of smaller env - ( - [0.025, 0.25], - jnp.pi, - [0.475, 0.25], - True, - 0.5, - [(5, 0.25)], - ), - ], -) -def test_search_and_target_view_targets( - key: chex.PRNGKey, - searcher_position: List[float], - searcher_heading: float, - target_position: List[float], - target_found: bool, - env_size: float, - view_updates: List[Tuple[int, float]], -) -> None: - """ - Test agent+target view model generates expected array with different - configurations of targets only. 
- """ - - searcher_position = jnp.array([searcher_position]) - searcher_heading = jnp.array([searcher_heading]) - searcher_speed = jnp.zeros((1,)) - target_position = jnp.array([target_position]) - target_found = jnp.array([target_found]) - - state = State( - searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), - targets=TargetState( - pos=target_position, - vel=jnp.zeros_like(target_position), - found=target_found, - ), - key=key, - ) - - observe_fn = observations.AgentAndTargetObservationFn( - num_vision=11, - vision_range=VISION_RANGE, - view_angle=VIEW_ANGLE, - agent_radius=0.01, - env_size=env_size, - ) - - obs = observe_fn(state) - assert obs.shape == (1, 2, observe_fn.num_vision) - - expected = jnp.full((1, 2, observe_fn.num_vision), -1.0) - - for idx, val in view_updates: - expected = expected.at[0, 1, idx].set(val) - - assert jnp.all(jnp.isclose(obs, expected)) - - @pytest.mark.parametrize( "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", [ @@ -306,7 +148,7 @@ def test_search_and_target_view_targets( ), ], ) -def test_search_and_all_target_view_targets( +def test_search_and_target_view_targets( key: chex.PRNGKey, searcher_position: List[float], searcher_heading: float, @@ -336,9 +178,10 @@ def test_search_and_all_target_view_targets( key=key, ) - observe_fn = observations.AgentAndAllTargetObservationFn( + observe_fn = observations.AgentAndTargetObservationFn( num_vision=11, - vision_range=VISION_RANGE, + searcher_vision_range=VISION_RANGE, + target_vision_range=VISION_RANGE, view_angle=VIEW_ANGLE, agent_radius=0.01, env_size=env_size, diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py index 756fec884..0d5f61c8d 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward.py +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -45,65 +45,44 @@ def _scale_rewards(rewards: chex.Array, step: int, time_limit: int) -> chex.Arra return scale * rewards -class SharedRewardFn(RewardFn): - """ - Calculate per agent rewards from detected targets - - Targets detected by multiple agents share rewards. Agents - can receive rewards for detecting multiple targets. +class IndividualRewardFn(RewardFn): """ + Calculate individual agent rewards from detected targets inside a step - def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: - rewards = found_targets.astype(float) - rewards = _normalise_rewards(rewards) - rewards = jnp.sum(rewards, axis=1) - return rewards - + Assigns individual rewards to each agent based on the number of newly + found targets inside a step. -class SharedScaledRewardFn(RewardFn): + Note that multiple agents can locate the same target inside a single + step. By default, rewards are split between the locating agents and + rewards linearly decrease over time (i.e. rewards are +1 per target in + the first step, decreasing to 0 at the final step). """ - Calculate per agent rewards from detected targets and scale by timestep - - Targets detected by multiple agents share rewards. Agents - can receive rewards for detecting multiple targets. - Rewards are linearly scaled by the current time step such that - rewards received are 0 at the final step. 
- """ - - def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: - rewards = found_targets.astype(float) - rewards = _normalise_rewards(rewards) - rewards = jnp.sum(rewards, axis=1) - rewards = _scale_rewards(rewards, step, time_limit) - return rewards + def __init__(self, split_rewards: bool = True, scale_rewards: bool = True): + """ + Initialise reward-function -class IndividualRewardFn(RewardFn): - """ - Calculate per agent rewards from detected targets - - Each agent that detects a target receives a +1 reward - even if a target is detected by multiple agents. - """ + Args: + split_rewards: If ``True`` rewards will be split in the case multiple agents + find a target at the same time. If ``False`` then agents will receive the + full reward irrespective of the number of agents who located the + target simultaneously. + scale_rewards: If ``True`` rewards granted will linearly decrease over time from + +1 per target to 0 at the final step. If ``False`` rewards will be fixed at + +1 per target irrespective of step. + """ + self.split_rewards = split_rewards + self.scale_rewards = scale_rewards def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) - rewards = jnp.sum(rewards, axis=1) - return rewards + if self.split_rewards: + rewards = _normalise_rewards(rewards) -class IndividualScaledRewardFn(RewardFn): - """ - Calculate per agent rewards from detected targets and scale by timestep + rewards = jnp.sum(rewards, axis=1) - Each agent that detects a target receives a +1 reward - even if a target is detected by multiple agents. - Rewards are linearly scaled by the current time step such that - rewards received are 0 at the final step. - """ + if self.scale_rewards: + rewards = _scale_rewards(rewards, step, time_limit) - def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: - rewards = found_targets.astype(float) - rewards = jnp.sum(rewards, axis=1) - rewards = _scale_rewards(rewards, step, time_limit) return rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py index bd49f69e8..e7080837b 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward_test.py +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -24,7 +24,7 @@ def target_states() -> chex.Array: def test_shared_rewards(target_states: chex.Array) -> None: - shared_rewards = reward.SharedRewardFn()(target_states, 0, 10) + shared_rewards = reward.IndividualRewardFn(scale_rewards=False)(target_states, 0, 10) assert shared_rewards.shape == (2,) assert shared_rewards.dtype == jnp.float32 @@ -32,7 +32,9 @@ def test_shared_rewards(target_states: chex.Array) -> None: def test_individual_rewards(target_states: chex.Array) -> None: - individual_rewards = reward.IndividualRewardFn()(target_states, 0, 10) + individual_rewards = reward.IndividualRewardFn(scale_rewards=False, split_rewards=False)( + target_states, 0, 10 + ) assert individual_rewards.shape == (2,) assert individual_rewards.dtype == jnp.float32 @@ -40,7 +42,7 @@ def test_individual_rewards(target_states: chex.Array) -> None: def test_shared_scaled_rewards(target_states: chex.Array) -> None: - reward_fn = reward.SharedScaledRewardFn() + reward_fn = reward.IndividualRewardFn() shared_scaled_rewards = reward_fn(target_states, 0, 10) @@ -56,7 +58,7 @@ def test_shared_scaled_rewards(target_states: chex.Array) -> None: def 
test_individual_scaled_rewards(target_states: chex.Array) -> None: - reward_fn = reward.IndividualScaledRewardFn() + reward_fn = reward.IndividualRewardFn(split_rewards=False) individual_scaled_rewards = reward_fn(target_states, 0, 10) diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py index af4f11a37..de48b9901 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils_test.py +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -15,37 +15,17 @@ from functools import partial from typing import List -import chex import jax import jax.numpy as jnp import pytest from jumanji.environments.swarms.common.types import AgentState -from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.types import TargetState from jumanji.environments.swarms.search_and_rescue.utils import ( searcher_detect_targets, ) -def test_random_walk_dynamics(key: chex.PRNGKey) -> None: - n_targets = 50 - pos_0 = jnp.full((n_targets, 2), 0.5) - - s0 = TargetState( - pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool) - ) - - dynamics = RandomWalk(0.1) - assert isinstance(dynamics, TargetDynamics) - s1 = dynamics(key, s0, 1.0) - - assert isinstance(s1, TargetState) - assert s1.pos.shape == (n_targets, 2) - assert jnp.array_equal(s0.found, s1.found) - assert jnp.all(jnp.abs(s0.pos - s1.pos) < 0.1) - - @pytest.mark.parametrize( "pos, heading, view_angle, target_state, expected, env_size", [