diff --git a/README.md b/README.md index fca9add9a..d191d94c1 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ problems. | ๐Ÿ’ฃ Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) | | ๐ŸŽฒ RubiksCube | Logic | `RubiksCube-v0`
`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) | | ๐Ÿ“ฆ BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) | -| ๐Ÿญ JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) | +| ๐Ÿญ JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) | | ๐ŸŽ’ Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) | | ๐Ÿงน Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) | | :link: Connector | Routing | `Connector-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) | diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 53bd1438d..1447dcf7c 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -51,9 +51,8 @@ # largest ones are given in the observation. register(id="BinPack-v1", entry_point="jumanji.environments:BinPack") -# Job-shop scheduling problem with 20 jobs, 10 machines, at most -# 8 operations per job, and a max operation duration of 6 timesteps. -register(id="JobShop-v0", entry_point="jumanji.environments:JobShop") +# Job-shop scheduling problem with 20 jobs, 10 machines, and known optimal makespan of 40. +register(id="JobShop-v1", entry_point="jumanji.environments:JobShop") # Knapsack problem with 50 randomly generated items, a total budget # of 12.5, and a dense reward function. diff --git a/jumanji/environments/packing/job_shop/env.py b/jumanji/environments/packing/job_shop/env.py index 1af25af87..6fceb1970 100644 --- a/jumanji/environments/packing/job_shop/env.py +++ b/jumanji/environments/packing/job_shop/env.py @@ -22,7 +22,7 @@ from jumanji import specs from jumanji.env import Environment -from jumanji.environments.packing.job_shop.generator import Generator, RandomGenerator +from jumanji.environments.packing.job_shop.generator import DenseGenerator, Generator from jumanji.environments.packing.job_shop.types import Observation, State from jumanji.environments.packing.job_shop.viewer import JobShopViewer from jumanji.types import TimeStep, restart, termination, transition @@ -98,16 +98,16 @@ def __init__( Args: generator: `Generator` whose `__call__` instantiates an environment instance. - Implemented options are ['ToyGenerator', 'RandomGenerator']. - Defaults to `RandomGenerator` with 20 jobs, 10 machines, up to 8 ops - for any given job, and a max operation duration of 6. + Implemented options are ['ToyGenerator', 'RandomGenerator', 'DenseGenerator`]. + Defaults to `DenseGenerator` with 20 jobs, 10 machines, and a makespan of 40. viewer: `Viewer` used for rendering. Defaults to `JobShopViewer`. """ - self.generator = generator or RandomGenerator( + self.generator = generator or DenseGenerator( num_jobs=20, num_machines=10, - max_num_ops=8, - max_op_duration=6, + max_num_ops=40, + max_op_duration=40, + makespan=40, ) self.num_jobs = self.generator.num_jobs self.num_machines = self.generator.num_machines diff --git a/jumanji/environments/packing/job_shop/generator.py b/jumanji/environments/packing/job_shop/generator.py index cfbec6d6c..7bcbbd75a 100644 --- a/jumanji/environments/packing/job_shop/generator.py +++ b/jumanji/environments/packing/job_shop/generator.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +from typing import Any, Tuple import chex import jax @@ -184,3 +185,281 @@ def __call__(self, key: chex.PRNGKey) -> State: ) return state + + +class DenseGenerator(Generator): + """`Generator` which creates a dense schedule of a specified makespan. This is done by: + - Specifying the `makespan` (schedule length) and the `num_machines`. + - Initialising an "empty" schedule. + - Creating a valid schedule: + 1. Randomly sample `num_machines` jobs w/o replacement. These jobs will be + scheduled on the machines in the first time step. + 2. At the next timestep, stochastically either: + - Reuse the previous jobs on the machines, or + - Randomly sample `num_machines` new jobs w/o replacement. + 3. Repeat step 2 until the desired `makespan` is reached. + - Extracting the info (duration and machine) about operations from the schedule and + padding the operations to the max number of operations. + + This generator assumes that the number of jobs is less than or equal to the number of + machines. + """ + + def __init__( + self, + num_jobs: int, + num_machines: int, + max_num_ops: int, + max_op_duration: int, + makespan: int, + prob_reuse_threshold: float = 0.5, + ): + """Instantiate a `DenseGenerator`. Note that the `makespan` is an upper + bound to both `max_num_ops` and `max_op_duration`, hence they are deleted. + + Args: + num_jobs: the number of jobs that need to be scheduled. + num_machines: the number of machines that the jobs can be scheduled on. + max_num_ops: the maximum number of operations for any given job. + max_op_duration: the maximum processing time of any given operation. + makespan: the length of the schedule. By construction, this will be the + shortest possible length of the schedule. + prob_reuse_threshold: the threshold probability of reusing the previous + job_id on a given machine. When generating the schedule, this quantity + determines how likely a machine is to try to reuse the job_id in the + previous timestep. + """ + del max_op_duration + del max_num_ops + if num_jobs < num_machines: + raise ValueError( + "The number of jobs must be greater than or equal to the number of machines." + ) + + super().__init__( + num_jobs=num_jobs, + num_machines=num_machines, + max_num_ops=makespan, + max_op_duration=makespan, + ) + self.makespan = makespan + self.prob_reuse_threshold = prob_reuse_threshold + + def __call__(self, key: chex.PRNGKey) -> State: + key, schedule_key = jax.random.split(key) + + # Generate a random, dense schedule of the specified length + schedule = self._generate_schedule(schedule_key) + + # Extract ops information from the schedule + ops_machine_ids, ops_durations = self._register_ops(schedule) + + # Initially, all machines are available (the value self.num_jobs corresponds to no-op) + machines_job_ids = jnp.full(self.num_machines, self.num_jobs, jnp.int32) + machines_remaining_times = jnp.full(self.num_machines, 0, jnp.int32) + scheduled_times = jnp.full((self.num_jobs, self.max_num_ops), -1, jnp.int32) + ops_mask = ops_machine_ids != -1 + step_count = jnp.int32(0) + + state = State( + ops_machine_ids=ops_machine_ids, + ops_durations=ops_durations, + ops_mask=ops_mask, + machines_job_ids=machines_job_ids, + machines_remaining_times=machines_remaining_times, + action_mask=None, + step_count=step_count, + scheduled_times=scheduled_times, + key=key, + ) + + return state + + def _generate_schedule(self, key: chex.PRNGKey) -> chex.Array: + """Creates a schedule given the constraints of the job shop scheduling problem. + + For example, for 3 machines, 5 jobs, and a chosen optimal makespan of 12, a schedule + may look like: + [[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3], + [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2], + [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]] + + This means + - Machine 0 processes job 1, job 0, job 1 (for two steps), etc. + - Machine 1 processes job 4, job 1, job 2, job 0, etc. + - Machine 2 processes job 0, job 2, job 3, job 4, etc. + + Importantly, since a job can only be executed on one machine at a time, this method + is written such that the schedule does not have duplicates in any column. + + Args: + key: used for stochasticity in the generation of the schedule. + + Returns: + Schedule with the specified length. Shape (num_machines, makespan). + """ + init_col_key, scan_key = jax.random.split(key) + all_job_ids = jnp.arange(self.num_jobs) + + def insert_col(carry, _): + _scan_key, _init_col = carry + _scan_key, inner_key = jax.random.split(_scan_key) + + init_job_mask = jnp.ones(shape=self.num_jobs, dtype=jnp.bool_) + init_machine_id = 0 + inner_init_carry = inner_key, init_job_mask, _init_col, init_machine_id + _, col = jax.lax.scan( + lambda inner_carry, _: self.insert_operation(inner_carry, _), + inner_init_carry, + xs=jnp.arange(self.num_machines), + ) + + return (_scan_key, col), col + + init_col = jax.random.choice( + init_col_key, + all_job_ids, + (self.num_machines,), + replace=False, + ) + init_carry = scan_key, init_col + _, schedule_transposed = jax.lax.scan( + lambda carry, _: insert_col(carry, _), + init_carry, + xs=jnp.arange(self.makespan), + ) + schedule = schedule_transposed.T + return schedule + + def _register_ops(self, schedule: chex.Array) -> Tuple[chex.Array, chex.Array]: + """Extract, for every job, the machine id and duration of each operation in the job. + + For example, for the schedule + [[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3], + [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2], + [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]] + + the ops would have the machine ids: + [[ 2, 0, 1, 2, 0, 2, -1, -1, -1, -1], + [ 0, 1, 0, 0, 2, -1, -1, -1, -1, -1], + [ 2, 1, 0, 1, 2, 1, 0, 1, -1, -1], + [ 2, 1, 0, 1, 2, 0, 1, 0, -1, -1], + [ 1, 2, 1, 0, -1, -1, -1, -1, -1, -1]] + + and the durations: + [[ 1, 1, 1, 3, 1, 2, -1, -1, -1, -1], + [ 1, 1, 2, 1, 1, -1, -1, -1, -1, -1], + [ 1, 1, 1, 1, 1, 2, 1, 1, -1, -1], + [ 1, 1, 1, 1, 1, 1, 1, 1, -1, -1], + [ 1, 1, 1, 1, -1, -1, -1, -1, -1, -1]] + + Args: + schedule: array representing which job each machine is working on at each timestep. + + Returns: + Arrays representing which machine id and duration characterising each operation. + """ + + def get_job_info( + job_id: int, _: Any + ) -> Tuple[int, Tuple[chex.Array, chex.Array]]: + """Extract the machine id and duration of every op in the specified job. + + In the above example, for job 0, this will return + - machine_ids [2, 0, 1, 2, 0, 2, -1, -1, -1, -1] + - durations [1, 1, 1, 3, 1, 2, -1, -1, -1, -1] + """ + + def get_op_info( + mask: chex.Array, + _: Any, + ) -> Tuple[chex.Array, Tuple[chex.Array, chex.Array]]: + """Extract the machine id and duration for a given operation. + + In the above example, for job 0 and operation 0, the machine id is 2 and + the duration is 1. + + Args: + mask: array which keeps track of which operations have been registered. + """ + prev_mask = mask + + # Flatten by column + mask_flat = jnp.ravel(mask, order="F") + + # Find index of the next operation + idx = jnp.argmax(mask_flat) + t_start, machine_id = jnp.divmod(idx, self.num_machines) + + # Update the mask -> the op is registered + mask = mask.at[machine_id, t_start].set(False) + + # While loop in case op has duration > 1 + init_val = (mask, machine_id, t_start + 1) + + def next_is_same_op(val: Tuple) -> chex.Array: + m, mid, t = val + return m[mid, t] + + def update_mask(val: Tuple) -> Tuple: + m, mid, t = val + m = m.at[mid, t].set(False) + return m, mid, t + 1 + + (mask, machine_id, time) = jax.lax.while_loop( + next_is_same_op, update_mask, init_val + ) + + duration = time - t_start + + # If all ops for this job are registered, return -1 for padding + all_ops_registered = ~jnp.any(prev_mask) + machine_id = jax.lax.select(all_ops_registered, -1, machine_id) + duration = jax.lax.select(all_ops_registered, -1, duration) + + return mask, (machine_id, duration) + + # Carry the mask + init_mask = jnp.array(schedule == job_id) + _, (mids, durs) = jax.lax.scan( + get_op_info, init_mask, xs=None, length=self.makespan + ) + + return job_id + 1, (mids, durs) + + # Carry the job id + init_job_id = 0 + job_id, (ops_mids, ops_durs) = jax.lax.scan( + get_job_info, init_job_id, xs=None, length=self.num_jobs + ) + return ops_mids, ops_durs + + def insert_operation(self, carry, _): + _inner_key, _job_mask, prev_col, machine_id = carry + _inner_key, reuse_key, job_key = jax.random.split(_inner_key, num=3) + + all_job_ids = jnp.arange(self.num_jobs) + + # Use the previous job on the machine with some probability and + # if the job hasn't already been scheduled in this timestep + prev_job_id = prev_col[machine_id] + reuse = jax.random.uniform(reuse_key, shape=()) >= self.prob_reuse_threshold + is_available = _job_mask[prev_job_id] + job_id = jax.lax.cond( + reuse & is_available, + lambda _: prev_job_id, + lambda _: jax.random.choice(job_key, all_job_ids, (), p=_job_mask), + None, + ) + + # Update the job mask to reflect that the chosen job can + # no longer be scheduled on any of the remaining machines + _job_mask = _job_mask.at[job_id].set(False) + + return (_inner_key, _job_mask, prev_col, machine_id + 1), job_id + + +if __name__ == "__main__": + gen = DenseGenerator(10, 5, 14, 14, 14) + key = jax.random.PRNGKey(0) + state = gen(key) diff --git a/jumanji/environments/packing/job_shop/generator_test.py b/jumanji/environments/packing/job_shop/generator_test.py index e35ebf3fe..3abaa1ac1 100644 --- a/jumanji/environments/packing/job_shop/generator_test.py +++ b/jumanji/environments/packing/job_shop/generator_test.py @@ -14,10 +14,12 @@ import chex import jax +import jax.numpy as jnp import pytest from jumanji.environments.packing.job_shop.conftest import DummyGenerator from jumanji.environments.packing.job_shop.generator import ( + DenseGenerator, RandomGenerator, ToyGenerator, ) @@ -103,3 +105,108 @@ def test_random_generator__call(self, random_generator: RandomGenerator) -> None state2 = call_fn(key=jax.random.PRNGKey(2)) assert_trees_are_different(state1, state2) + + +class TestDenseGenerator: + NUM_JOBS = 5 + NUM_MACHINES = 3 + MAKESPAN = 12 + MAX_NUM_OPS = MAKESPAN + MAX_OP_DURATION = MAKESPAN + + @pytest.fixture + def dense_generator(self) -> DenseGenerator: + return DenseGenerator( + num_jobs=self.NUM_JOBS, + num_machines=self.NUM_MACHINES, + max_num_ops=self.MAX_NUM_OPS, + max_op_duration=self.MAX_OP_DURATION, + makespan=self.MAKESPAN, + ) + + @pytest.fixture + def hardcoded_schedule(self) -> chex.Array: + return jnp.array( + [ + [1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3], + [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2], + [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0], + ] + ) + + def test_dense_generator__attributes(self, dense_generator: DenseGenerator) -> None: + """Validate that the random instance generator has the correct properties.""" + assert dense_generator.num_jobs == self.NUM_JOBS + assert dense_generator.num_machines == self.NUM_MACHINES + assert dense_generator.max_num_ops == self.MAKESPAN + assert dense_generator.max_op_duration == self.MAKESPAN + assert dense_generator.makespan == self.MAKESPAN + + key = jax.random.PRNGKey(0) + state = dense_generator(key) + assert isinstance(state, State) + + def test_dense_generator__generate_schedule( + self, dense_generator: DenseGenerator + ) -> None: + key = jax.random.PRNGKey(0) + schedule = dense_generator._generate_schedule(key) + assert schedule.shape == (self.NUM_MACHINES, self.MAKESPAN) + assert jnp.all((schedule >= 0) & (schedule < self.NUM_JOBS)) + + # For 20 different randomly generated dense schedules, check + # that no column in any schedule contains duplicates + keys = jax.random.split(key, 20) + for k in keys: + schedule = dense_generator._generate_schedule(k) + for t in range(self.MAKESPAN): + num_unique_jobs = len(jnp.unique(schedule[:, t])) + assert ( + num_unique_jobs == self.NUM_MACHINES + ), f"{num_unique_jobs} โ‰  {self.NUM_MACHINES} at t={t}." + + def test_dense_generator__register_ops( + self, + dense_generator: DenseGenerator, + hardcoded_schedule: chex.Array, + ) -> None: + ops_machine_ids, ops_durations = dense_generator._register_ops( + hardcoded_schedule + ) + assert jnp.array_equal( + ops_machine_ids, + jnp.array( + [ + [2, 0, 1, 2, 0, 2, -1, -1, -1, -1, -1, -1], + [0, 1, 0, 0, 2, -1, -1, -1, -1, -1, -1, -1], + [2, 1, 0, 1, 2, 1, 0, 1, -1, -1, -1, -1], + [2, 1, 0, 1, 2, 0, 1, 0, -1, -1, -1, -1], + [1, 2, 1, 0, -1, -1, -1, -1, -1, -1, -1, -1], + ] + ), + ) + assert jnp.array_equal( + ops_durations, + jnp.array( + [ + [1, 1, 1, 3, 1, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 2, 1, 1, -1, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, 1, 2, 1, 1, -1, -1, -1, -1], + [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1], + ] + ), + ) + + def test_dense_generator__jit_only_compiles_once( + self, dense_generator: DenseGenerator + ) -> None: + key = jax.random.PRNGKey(0) + key_first, key_second = jax.random.split(key, 2) + + # First call should compile the function. + call_fn = jax.jit(chex.assert_max_traces(dense_generator.__call__, n=1)) + _ = call_fn(key=key_first) + + # Second call should not compile the function. + _ = call_fn(key=key_second)