From ab72da556226939aaef6616b3818d80180ac7f4a Mon Sep 17 00:00:00 2001 From: WiemKhlifi Date: Thu, 24 Oct 2024 15:39:08 +0100 Subject: [PATCH] chore: add few more tests --- jumanji/environments/routing/lbf/generator.py | 2 +- .../routing/lbf/tests_files/conftest.py | 4 +- .../routing/lbf/tests_files/env_test.py | 72 +++++++++++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/jumanji/environments/routing/lbf/generator.py b/jumanji/environments/routing/lbf/generator.py index ebf77b5e4..2089ecc50 100644 --- a/jumanji/environments/routing/lbf/generator.py +++ b/jumanji/environments/routing/lbf/generator.py @@ -50,7 +50,7 @@ def __init__( force_coop (bool): Whether to force cooperation among agents. """ assert 5 <= grid_size, "Grid size must be greater than 5." - assert 2 <= fov <= grid_size, "Field of view must be between 2 and grid_size." + assert 1 <= fov <= grid_size, "Field of view must be between 1 and grid_size." assert num_agents > 0, "Number of agents must be positive." assert num_food > 0, "Number of food items must be positive." assert max_agent_level >= 2, "Maximum agent level must be at least 2." diff --git a/jumanji/environments/routing/lbf/tests_files/conftest.py b/jumanji/environments/routing/lbf/tests_files/conftest.py index b13f66dde..a5a492c8f 100644 --- a/jumanji/environments/routing/lbf/tests_files/conftest.py +++ b/jumanji/environments/routing/lbf/tests_files/conftest.py @@ -196,10 +196,10 @@ def lbf_env_grid_obs() -> LevelBasedForaging: @pytest.fixture -def lbf_env_with_penalty() -> LevelBasedForaging: +def lbf_with_penalty() -> LevelBasedForaging: return LevelBasedForaging(penalty=1.0) @pytest.fixture -def lbf_env_no_norm_reward() -> LevelBasedForaging: +def lbf_with_no_norm_reward() -> LevelBasedForaging: return LevelBasedForaging(normalize_reward=False) diff --git a/jumanji/environments/routing/lbf/tests_files/env_test.py b/jumanji/environments/routing/lbf/tests_files/env_test.py index d08088588..41f4e3c45 100644 --- a/jumanji/environments/routing/lbf/tests_files/env_test.py +++ b/jumanji/environments/routing/lbf/tests_files/env_test.py @@ -64,6 +64,31 @@ def test_reset(lbf_environment: LevelBasedForaging, key: chex.PRNGKey) -> None: assert timestep.reward.shape == (num_agents,) +def test_reset_grid_obs( + lbf_env_grid_obs: LevelBasedForaging, key: chex.PRNGKey +) -> None: + num_agents = lbf_env_grid_obs.num_agents + + state, timestep = lbf_env_grid_obs.reset(key) + assert len(state.agents.position) == num_agents + assert len(state.food_items.position) == lbf_env_grid_obs.num_food + + expected_obs_shape = ( + num_agents, + 3, + 2 * lbf_env_grid_obs.fov + 1, + 2 * lbf_env_grid_obs.fov + 1, + ) + assert timestep.observation.agents_view.shape == expected_obs_shape + + assert jnp.all(timestep.discount == 1.0) + assert jnp.all(timestep.reward == 0.0) + assert timestep.step_type == StepType.FIRST + + assert timestep.discount.shape == (num_agents,) + assert timestep.reward.shape == (num_agents,) + + def test_get_reward( lbf_environment: LevelBasedForaging, agents: Agent, food_items: Food ) -> None: @@ -85,6 +110,53 @@ def test_get_reward( assert jnp.all(reward == expected_reward) +def test_reward_with_penalty( + lbf_with_penalty: LevelBasedForaging, food_items: Food +) -> None: + adj_food0_level = jnp.array([0.0, 1, 2]) + adj_food1_level = jnp.array([0.0, 0.0, 4]) + adj_food2_level = jnp.array([2, 0.0, 0.0]) + adj_agent_levels = jnp.array([adj_food0_level, adj_food1_level, adj_food2_level]) + eaten = jnp.array([True, True, True]) + + reward = lbf_with_penalty.get_reward(food_items, adj_agent_levels, eaten) + penalty = lbf_with_penalty.penalty + + penalty_0 = jnp.where(jnp.sum(adj_food0_level) < food_items.level[0], penalty, 0) + expected_reward_food0 = ( + adj_food0_level * eaten[0] * food_items.level[0] - penalty_0 + ) / (jnp.sum(food_items.level) * jnp.sum(adj_food0_level)) + penalty_1 = jnp.where(jnp.sum(adj_food1_level) < food_items.level[1], penalty, 0) + expected_reward_food1 = ( + adj_food1_level * eaten[1] * food_items.level[1] - penalty_1 + ) / (jnp.sum(food_items.level) * jnp.sum(adj_food1_level)) + penalty_2 = jnp.where(jnp.sum(adj_food2_level) < food_items.level[2], penalty, 0) + expected_reward_food2 = ( + adj_food2_level * eaten[1] * food_items.level[2] - penalty_2 + ) / (jnp.sum(food_items.level) * jnp.sum(adj_food2_level)) + expected_reward = ( + expected_reward_food0 + expected_reward_food1 + expected_reward_food2 + ) + assert jnp.all(reward == expected_reward) + + +def test_reward_with_no_norm( + lbf_with_no_norm_reward: LevelBasedForaging, agents: Agent, food_items: Food +) -> None: + adj_food0_level = jnp.array([0.0, agents.level[1], agents.level[2]]) + adj_food1_level = jnp.array([0.0, 0.0, agents.level[2]]) + adj_food2_level = jnp.array([0.0, 0.0, 0.0]) + adj_agent_levels = jnp.array([adj_food0_level, adj_food1_level, adj_food2_level]) + eaten = jnp.array([True, True, False]) + + reward = lbf_with_no_norm_reward.get_reward(food_items, adj_agent_levels, eaten) + + expected_reward_food0 = adj_food0_level * food_items.level[0] + expected_reward_food1 = adj_food1_level * food_items.level[1] + expected_reward = expected_reward_food0 + expected_reward_food1 + assert jnp.all(reward == expected_reward) + + def test_step(lbf_environment: LevelBasedForaging, state: State) -> None: num_agents = lbf_environment.num_agents