Skip to content

Commit

Permalink
chore: add few more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Oct 24, 2024
1 parent 9adee92 commit ab72da5
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jumanji/environments/routing/lbf/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
force_coop (bool): Whether to force cooperation among agents.
"""
assert 5 <= grid_size, "Grid size must be greater than 5."
assert 2 <= fov <= grid_size, "Field of view must be between 2 and grid_size."
assert 1 <= fov <= grid_size, "Field of view must be between 1 and grid_size."
assert num_agents > 0, "Number of agents must be positive."
assert num_food > 0, "Number of food items must be positive."
assert max_agent_level >= 2, "Maximum agent level must be at least 2."
Expand Down
4 changes: 2 additions & 2 deletions jumanji/environments/routing/lbf/tests_files/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ def lbf_env_grid_obs() -> LevelBasedForaging:


@pytest.fixture
def lbf_env_with_penalty() -> LevelBasedForaging:
def lbf_with_penalty() -> LevelBasedForaging:
return LevelBasedForaging(penalty=1.0)


@pytest.fixture
def lbf_env_no_norm_reward() -> LevelBasedForaging:
def lbf_with_no_norm_reward() -> LevelBasedForaging:
return LevelBasedForaging(normalize_reward=False)
72 changes: 72 additions & 0 deletions jumanji/environments/routing/lbf/tests_files/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,31 @@ def test_reset(lbf_environment: LevelBasedForaging, key: chex.PRNGKey) -> None:
assert timestep.reward.shape == (num_agents,)


def test_reset_grid_obs(
lbf_env_grid_obs: LevelBasedForaging, key: chex.PRNGKey
) -> None:
num_agents = lbf_env_grid_obs.num_agents

state, timestep = lbf_env_grid_obs.reset(key)
assert len(state.agents.position) == num_agents
assert len(state.food_items.position) == lbf_env_grid_obs.num_food

expected_obs_shape = (
num_agents,
3,
2 * lbf_env_grid_obs.fov + 1,
2 * lbf_env_grid_obs.fov + 1,
)
assert timestep.observation.agents_view.shape == expected_obs_shape

assert jnp.all(timestep.discount == 1.0)
assert jnp.all(timestep.reward == 0.0)
assert timestep.step_type == StepType.FIRST

assert timestep.discount.shape == (num_agents,)
assert timestep.reward.shape == (num_agents,)


def test_get_reward(
lbf_environment: LevelBasedForaging, agents: Agent, food_items: Food
) -> None:
Expand All @@ -85,6 +110,53 @@ def test_get_reward(
assert jnp.all(reward == expected_reward)


def test_reward_with_penalty(
lbf_with_penalty: LevelBasedForaging, food_items: Food
) -> None:
adj_food0_level = jnp.array([0.0, 1, 2])
adj_food1_level = jnp.array([0.0, 0.0, 4])
adj_food2_level = jnp.array([2, 0.0, 0.0])
adj_agent_levels = jnp.array([adj_food0_level, adj_food1_level, adj_food2_level])
eaten = jnp.array([True, True, True])

reward = lbf_with_penalty.get_reward(food_items, adj_agent_levels, eaten)
penalty = lbf_with_penalty.penalty

penalty_0 = jnp.where(jnp.sum(adj_food0_level) < food_items.level[0], penalty, 0)
expected_reward_food0 = (
adj_food0_level * eaten[0] * food_items.level[0] - penalty_0
) / (jnp.sum(food_items.level) * jnp.sum(adj_food0_level))
penalty_1 = jnp.where(jnp.sum(adj_food1_level) < food_items.level[1], penalty, 0)
expected_reward_food1 = (
adj_food1_level * eaten[1] * food_items.level[1] - penalty_1
) / (jnp.sum(food_items.level) * jnp.sum(adj_food1_level))
penalty_2 = jnp.where(jnp.sum(adj_food2_level) < food_items.level[2], penalty, 0)
expected_reward_food2 = (
adj_food2_level * eaten[1] * food_items.level[2] - penalty_2
) / (jnp.sum(food_items.level) * jnp.sum(adj_food2_level))
expected_reward = (
expected_reward_food0 + expected_reward_food1 + expected_reward_food2
)
assert jnp.all(reward == expected_reward)


def test_reward_with_no_norm(
lbf_with_no_norm_reward: LevelBasedForaging, agents: Agent, food_items: Food
) -> None:
adj_food0_level = jnp.array([0.0, agents.level[1], agents.level[2]])
adj_food1_level = jnp.array([0.0, 0.0, agents.level[2]])
adj_food2_level = jnp.array([0.0, 0.0, 0.0])
adj_agent_levels = jnp.array([adj_food0_level, adj_food1_level, adj_food2_level])
eaten = jnp.array([True, True, False])

reward = lbf_with_no_norm_reward.get_reward(food_items, adj_agent_levels, eaten)

expected_reward_food0 = adj_food0_level * food_items.level[0]
expected_reward_food1 = adj_food1_level * food_items.level[1]
expected_reward = expected_reward_food0 + expected_reward_food1
assert jnp.all(reward == expected_reward)


def test_step(lbf_environment: LevelBasedForaging, state: State) -> None:

num_agents = lbf_environment.num_agents
Expand Down

0 comments on commit ab72da5

Please sign in to comment.