feat(job_shop): implement dense schedule generator #115

Closed
25 commits
1df15f7
feat: initial attempt at optimal generator
dluo96 Mar 28, 2023
51ec62d
feat: initial attempt at an optimal makespan generator
dluo96 Mar 29, 2023
046e8a0
test: added tests for makespan generator
dluo96 Mar 29, 2023
b6b8cbd
feat: working version
dluo96 Apr 3, 2023
32ff604
docs: completed doc string
dluo96 Apr 5, 2023
e43689e
Merge branch 'main' into 90-jobshop-known-makespan-generator
dluo96 Apr 5, 2023
3e7f2b8
fix: failing test
dluo96 Apr 5, 2023
44a591f
Merge branch 'main' of github.com:instadeepai/jumanji into 90-jobshop…
dluo96 Apr 5, 2023
54b5ac6
Merge branch '90-jobshop-known-makespan-generator' of github.com:inst…
dluo96 Apr 5, 2023
d3cede1
feat: register environment with new generator
dluo96 Apr 18, 2023
f411418
Merge branch 'main' into 90-jobshop-known-makespan-generator
dluo96 Apr 18, 2023
e1db3bf
docs: updated registered version in readme
dluo96 Apr 18, 2023
b03cbb8
Merge branch '90-jobshop-known-makespan-generator' of github.com:inst…
dluo96 Apr 18, 2023
d879df9
feat: removed redundant code
dluo96 Apr 18, 2023
1050eaf
docs: modified doc string explaining why attributes are deleted
dluo96 May 13, 2023
66d3053
refactor: renamed to dense generator
dluo96 May 13, 2023
d5bb546
test: added check that no column in the schedule contains duplicates
dluo96 May 13, 2023
7de39d4
test: updated test of duplicates
dluo96 May 13, 2023
53d1b09
test: added test error message
dluo96 May 13, 2023
52deda1
test: check no duplicates for 20 different dense schedules
dluo96 May 13, 2023
8849f14
fix: setting of True in _job_mask can lead to duplicates in columns
dluo96 May 13, 2023
982e677
feat: new random col each time
dluo96 May 14, 2023
257ac07
fix: solved the bug of duplicates
dluo96 May 14, 2023
3e07217
style: fixed indentation
dluo96 May 14, 2023
bde343e
fix: change attributes in test class and define attribute for probabi…
dluo96 May 15, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -82,7 +82,7 @@ problems.
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
-| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
+| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
Collaborator

Unless we test the new generator with a full training, we should actually keep the previous generator as default (and keep it to version v0).

| 🎒 Knapsack | Packing | `Knapsack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/knapsack/) | [doc](https://instadeepai.github.io/jumanji/environments/knapsack/) |
| 🧹 Cleaner | Routing | `Cleaner-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/cleaner/) | [doc](https://instadeepai.github.io/jumanji/environments/cleaner/) |
| :link: Connector | Routing | `Connector-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/connector/) | [doc](https://instadeepai.github.io/jumanji/environments/connector/) |
5 changes: 2 additions & 3 deletions jumanji/__init__.py
@@ -51,9 +51,8 @@
# largest ones are given in the observation.
register(id="BinPack-v1", entry_point="jumanji.environments:BinPack")

-# Job-shop scheduling problem with 20 jobs, 10 machines, at most
-# 8 operations per job, and a max operation duration of 6 timesteps.
-register(id="JobShop-v0", entry_point="jumanji.environments:JobShop")
+# Job-shop scheduling problem with 20 jobs, 10 machines, and known optimal makespan of 40.
+register(id="JobShop-v1", entry_point="jumanji.environments:JobShop")

# Knapsack problem with 50 randomly generated items, a total budget
# of 12.5, and a dense reward function.
14 changes: 7 additions & 7 deletions jumanji/environments/packing/job_shop/env.py
@@ -22,7 +22,7 @@

from jumanji import specs
from jumanji.env import Environment
-from jumanji.environments.packing.job_shop.generator import Generator, RandomGenerator
+from jumanji.environments.packing.job_shop.generator import DenseGenerator, Generator
from jumanji.environments.packing.job_shop.types import Observation, State
from jumanji.environments.packing.job_shop.viewer import JobShopViewer
from jumanji.types import TimeStep, restart, termination, transition
@@ -98,16 +98,16 @@ def __init__(

         Args:
             generator: `Generator` whose `__call__` instantiates an environment instance.
-                Implemented options are ['ToyGenerator', 'RandomGenerator'].
-                Defaults to `RandomGenerator` with 20 jobs, 10 machines, up to 8 ops
-                for any given job, and a max operation duration of 6.
+                Implemented options are ['ToyGenerator', 'RandomGenerator', 'DenseGenerator'].
+                Defaults to `DenseGenerator` with 20 jobs, 10 machines, and a makespan of 40.
             viewer: `Viewer` used for rendering. Defaults to `JobShopViewer`.
         """
-        self.generator = generator or RandomGenerator(
+        self.generator = generator or DenseGenerator(
             num_jobs=20,
             num_machines=10,
-            max_num_ops=8,
-            max_op_duration=6,
+            max_num_ops=40,
+            max_op_duration=40,
+            makespan=40,
         )
         self.num_jobs = self.generator.num_jobs
         self.num_machines = self.generator.num_machines
279 changes: 279 additions & 0 deletions jumanji/environments/packing/job_shop/generator.py
@@ -13,6 +13,7 @@
# limitations under the License.

import abc
from typing import Any, Tuple

import chex
import jax
@@ -184,3 +185,281 @@ def __call__(self, key: chex.PRNGKey) -> State:
        )

        return state


class DenseGenerator(Generator):
    """`Generator` which creates a dense schedule of a specified makespan. This is done by:
        - Specifying the `makespan` (schedule length) and the `num_machines`.
        - Initialising an "empty" schedule.
        - Creating a valid schedule:
            1. Randomly sample `num_machines` jobs without replacement. These jobs will be
                scheduled on the machines in the first time step.
            2. At the next timestep, each machine stochastically either:
                - Reuses its previous job, if that job has not already been assigned to
                    another machine in this timestep, or
                - Samples a new job among those not yet assigned in this timestep.
            3. Repeat step 2 until the desired `makespan` is reached.
        - Extracting the info (duration and machine) about operations from the schedule and
            padding the operations to the max number of operations.

    This generator assumes that the number of jobs is greater than or equal to the number
    of machines; a `ValueError` is raised otherwise.
    """

    def __init__(
        self,
        num_jobs: int,
        num_machines: int,
        max_num_ops: int,
        max_op_duration: int,
Comment on lines +212 to +213
Collaborator

Do you have to have these in the init?

Contributor Author

With the way that the Generator interface is currently defined, we would have to, yes. This is because the __init__ takes those arguments.
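
As a point of comparison for this thread, Python does not require a subclass `__init__` to mirror its parent's signature; the subclass only has to forward the arguments the parent expects. The sketch below uses hypothetical stand-in classes (`BaseGenerator`, `MakespanOnlyGenerator`) rather than the real `Generator`, and it cannot show whether other callers in the codebase rely on a uniform constructor signature.

```python
class BaseGenerator:
    """Hypothetical stand-in for the `Generator` base class in this PR."""

    def __init__(
        self, num_jobs: int, num_machines: int, max_num_ops: int, max_op_duration: int
    ) -> None:
        self.num_jobs = num_jobs
        self.num_machines = num_machines
        self.max_num_ops = max_num_ops
        self.max_op_duration = max_op_duration


class MakespanOnlyGenerator(BaseGenerator):
    """Hypothetical subclass whose signature omits the two redundant arguments."""

    def __init__(self, num_jobs: int, num_machines: int, makespan: int) -> None:
        # The makespan bounds both quantities, so it is forwarded for each of them.
        super().__init__(
            num_jobs=num_jobs,
            num_machines=num_machines,
            max_num_ops=makespan,
            max_op_duration=makespan,
        )
        self.makespan = makespan


gen = MakespanOnlyGenerator(num_jobs=20, num_machines=10, makespan=40)
```

The trade-off is that such a subclass can no longer be constructed with the same positional arguments as the other generators, which may be the constraint the author has in mind.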

        makespan: int,
        prob_reuse_threshold: float = 0.5,
    ):
        """Instantiate a `DenseGenerator`. Note that `makespan` is an upper bound on both
        `max_num_ops` and `max_op_duration`, so the values passed for those two arguments
        are deleted and `makespan` is used in their place.

        Args:
            num_jobs: the number of jobs that need to be scheduled.
            num_machines: the number of machines that the jobs can be scheduled on.
            max_num_ops: the maximum number of operations for any given job. Ignored,
                see above.
            max_op_duration: the maximum processing time of any given operation. Ignored,
                see above.
            makespan: the length of the schedule. By construction, this will be the
                shortest possible length of the schedule.
            prob_reuse_threshold: the threshold against which a uniform sample in [0, 1)
                is compared for each machine. A machine tries to reuse the job_id from
                the previous timestep when the sample is greater than or equal to this
                threshold, so a lower threshold makes reuse more likely.
        """
        del max_op_duration
        del max_num_ops
Contributor

Maybe add a comment to explain why they are deleted?

Contributor Author

Sounds good, have updated the doc string - let me know if you're happy with that

        if num_jobs < num_machines:
            raise ValueError(
                "The number of jobs must be greater than or equal to the number of machines."
            )

        super().__init__(
            num_jobs=num_jobs,
            num_machines=num_machines,
            max_num_ops=makespan,
            max_op_duration=makespan,
        )
        self.makespan = makespan
        self.prob_reuse_threshold = prob_reuse_threshold

    def __call__(self, key: chex.PRNGKey) -> State:
        key, schedule_key = jax.random.split(key)

        # Generate a random, dense schedule of the specified length
        schedule = self._generate_schedule(schedule_key)

        # Extract ops information from the schedule
        ops_machine_ids, ops_durations = self._register_ops(schedule)

        # Initially, all machines are available (the value self.num_jobs corresponds to no-op)
        machines_job_ids = jnp.full(self.num_machines, self.num_jobs, jnp.int32)
        machines_remaining_times = jnp.full(self.num_machines, 0, jnp.int32)
        scheduled_times = jnp.full((self.num_jobs, self.max_num_ops), -1, jnp.int32)
        ops_mask = ops_machine_ids != -1
        step_count = jnp.int32(0)

        state = State(
            ops_machine_ids=ops_machine_ids,
            ops_durations=ops_durations,
            ops_mask=ops_mask,
            machines_job_ids=machines_job_ids,
            machines_remaining_times=machines_remaining_times,
            action_mask=None,
            step_count=step_count,
            scheduled_times=scheduled_times,
            key=key,
        )

        return state

    def _generate_schedule(self, key: chex.PRNGKey) -> chex.Array:
        """Creates a schedule given the constraints of the job shop scheduling problem.

        For example, for 3 machines, 5 jobs, and a chosen optimal makespan of 12, a schedule
        may look like:
            [[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3],
             [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2],
             [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]]

        This means
            - Machine 0 processes job 1, job 0, job 1 (for two steps), etc.
            - Machine 1 processes job 4, job 1, job 2, job 0, etc.
            - Machine 2 processes job 0, job 2, job 3, job 4, etc.

        Importantly, since a job can only be executed on one machine at a time, this method
        is written such that the schedule does not have duplicates in any column.

        Args:
            key: used for stochasticity in the generation of the schedule.

        Returns:
            Schedule with the specified length. Shape (num_machines, makespan).
        """
        init_col_key, scan_key = jax.random.split(key)
        all_job_ids = jnp.arange(self.num_jobs)

        def insert_col(carry, _):
            _scan_key, _init_col = carry
            _scan_key, inner_key = jax.random.split(_scan_key)

            init_job_mask = jnp.ones(shape=self.num_jobs, dtype=jnp.bool_)
            init_machine_id = 0
            inner_init_carry = inner_key, init_job_mask, _init_col, init_machine_id
            _, col = jax.lax.scan(
                lambda inner_carry, _: self.insert_operation(inner_carry, _),
                inner_init_carry,
                xs=jnp.arange(self.num_machines),
            )

            return (_scan_key, col), col

        init_col = jax.random.choice(
            init_col_key,
            all_job_ids,
            (self.num_machines,),
            replace=False,
        )
        init_carry = scan_key, init_col
        _, schedule_transposed = jax.lax.scan(
            lambda carry, _: insert_col(carry, _),
            init_carry,
            xs=jnp.arange(self.makespan),
        )
        schedule = schedule_transposed.T
        return schedule

    def _register_ops(self, schedule: chex.Array) -> Tuple[chex.Array, chex.Array]:
        """Extract, for every job, the machine id and duration of each operation in the job.

        For example, for the schedule
            [[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3],
             [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2],
             [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]]

        the ops would have the machine ids (padded with -1 up to the makespan of 12):
            [[ 2,  0,  1,  2,  0,  2, -1, -1, -1, -1, -1, -1],
             [ 0,  1,  0,  0,  2, -1, -1, -1, -1, -1, -1, -1],
             [ 2,  1,  0,  1,  2,  1,  0,  1, -1, -1, -1, -1],
             [ 2,  1,  0,  1,  2,  0,  1,  0, -1, -1, -1, -1],
             [ 1,  2,  1,  0, -1, -1, -1, -1, -1, -1, -1, -1]]

        and the durations:
            [[ 1,  1,  1,  3,  1,  2, -1, -1, -1, -1, -1, -1],
             [ 1,  1,  2,  1,  1, -1, -1, -1, -1, -1, -1, -1],
             [ 1,  1,  1,  1,  1,  2,  1,  1, -1, -1, -1, -1],
             [ 1,  1,  1,  1,  1,  1,  1,  1, -1, -1, -1, -1],
             [ 1,  1,  1,  1, -1, -1, -1, -1, -1, -1, -1, -1]]

        Args:
            schedule: array representing which job each machine is working on at each timestep.

        Returns:
            Arrays holding the machine id and the duration of each operation, padded
            with -1. Each has shape (num_jobs, makespan).
        """

        def get_job_info(
            job_id: int, _: Any
        ) -> Tuple[int, Tuple[chex.Array, chex.Array]]:
            """Extract the machine id and duration of every op in the specified job.

            In the above example, for job 0, this will return
                - machine_ids [2, 0, 1, 2, 0, 2, -1, -1, -1, -1, -1, -1]
                - durations   [1, 1, 1, 3, 1, 2, -1, -1, -1, -1, -1, -1]
            """

            def get_op_info(
                mask: chex.Array,
                _: Any,
            ) -> Tuple[chex.Array, Tuple[chex.Array, chex.Array]]:
                """Extract the machine id and duration for a given operation.

                In the above example, for job 0 and operation 0, the machine id is 2 and
                the duration is 1.

                Args:
                    mask: array which keeps track of which operations have been registered.
                """
                prev_mask = mask

                # Flatten by column
                mask_flat = jnp.ravel(mask, order="F")

                # Find index of the next operation
                idx = jnp.argmax(mask_flat)
                t_start, machine_id = jnp.divmod(idx, self.num_machines)

                # Update the mask -> the op is registered
                mask = mask.at[machine_id, t_start].set(False)

                # While loop in case op has duration > 1
                init_val = (mask, machine_id, t_start + 1)

                def next_is_same_op(val: Tuple) -> chex.Array:
Collaborator

These are 4 levels of nested functions. Would it be possible to refactor the code with less nesting?
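
One way to picture a flatter structure is to look at what the nested functions compute: for each job, a run-length encoding of its consecutive timesteps on the same machine. The sketch below is a plain-Python illustration of that logic with two top-level functions. The names `job_ops` and `register_ops` and the `pad_to` parameter are assumptions for illustration, and the sketch is not jit-compatible the way the `jax.lax.scan` version in the PR is.

```python
from typing import List, Tuple

# schedule[machine][t] is the job id that `machine` runs at timestep `t`.
Schedule = List[List[int]]


def job_ops(schedule: Schedule, job_id: int) -> List[Tuple[int, int]]:
    """Return (machine_id, duration) for every operation of `job_id`, in time order."""
    ops: List[Tuple[int, int]] = []
    prev_machine = -1  # machine running the job at the previous timestep (-1: none)
    for t in range(len(schedule[0])):
        # A column holds no duplicates, so at most one machine matches.
        machine = next((m for m, row in enumerate(schedule) if row[t] == job_id), -1)
        if machine != -1 and machine == prev_machine:
            # Same job on the same machine as the previous step: extend the op.
            ops[-1] = (machine, ops[-1][1] + 1)
        elif machine != -1:
            ops.append((machine, 1))
        prev_machine = machine
    return ops


def register_ops(
    schedule: Schedule, num_jobs: int, pad_to: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Collect machine ids and durations per job, padded with -1 up to `pad_to`."""
    machine_ids, durations = [], []
    for job_id in range(num_jobs):
        ops = job_ops(schedule, job_id)
        padding = [-1] * (pad_to - len(ops))
        machine_ids.append([m for m, _ in ops] + padding)
        durations.append([d for _, d in ops] + padding)
    return machine_ids, durations


# The example schedule from the `_register_ops` docstring (3 machines, 5 jobs).
SCHEDULE = [
    [1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3],
    [4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2],
    [0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0],
]
machine_ids, durations = register_ops(SCHEDULE, num_jobs=5, pad_to=12)
```

Each job is handled independently of the others here, which is also the property the later `vmap` question in this review relies on.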

                    m, mid, t = val
                    return m[mid, t]

                def update_mask(val: Tuple) -> Tuple:
                    m, mid, t = val
                    m = m.at[mid, t].set(False)
                    return m, mid, t + 1

                (mask, machine_id, time) = jax.lax.while_loop(
                    next_is_same_op, update_mask, init_val
                )

                duration = time - t_start

                # If all ops for this job are registered, return -1 for padding
                all_ops_registered = ~jnp.any(prev_mask)
                machine_id = jax.lax.select(all_ops_registered, -1, machine_id)
                duration = jax.lax.select(all_ops_registered, -1, duration)

                return mask, (machine_id, duration)

            # Carry the mask
            init_mask = jnp.array(schedule == job_id)
            _, (mids, durs) = jax.lax.scan(
                get_op_info, init_mask, xs=None, length=self.makespan
            )

            return job_id + 1, (mids, durs)

        # Carry the job id
        init_job_id = 0
        job_id, (ops_mids, ops_durs) = jax.lax.scan(
            get_job_info, init_job_id, xs=None, length=self.num_jobs
Contributor

Why not use `xs=jnp.arange(self.num_jobs)` instead of having `job_id + 1` in the carry?
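
To make the two styles in this comment concrete, the sketch below uses a small pure-Python reference loop with the same calling convention as `jax.lax.scan` (carry in, per-step output stacked). The helpers `scan`, `per_job_work`, `step_with_carry` and `step_with_xs` are hypothetical names for illustration; `per_job_work` stands in for the real per-job extraction.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Optional[Sequence[Any]] = None,
    length: Optional[int] = None,
) -> Tuple[Any, List[Any]]:
    """Pure-Python reference loop mirroring the semantics of `jax.lax.scan`."""
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def per_job_work(job_id: int) -> int:
    """Hypothetical stand-in for extracting one job's operations."""
    return job_id * 10


def step_with_carry(job_id: int, _: Any) -> Tuple[int, int]:
    # Style used in the PR: the job id is threaded through the carry.
    return job_id + 1, per_job_work(job_id)


def step_with_xs(carry: None, job_id: int) -> Tuple[None, int]:
    # Style suggested in the review: the job id arrives via `xs`, the carry is unused.
    return carry, per_job_work(job_id)


_, ys_carry = scan(step_with_carry, 0, xs=None, length=4)
_, ys_xs = scan(step_with_xs, None, xs=range(4))
```

Both calls produce the same per-job outputs. An unused carry also signals that the iterations do not depend on each other.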

Contributor

Would it be possible to use `vmap` instead of `scan` to avoid sequential evaluation?

        )
        return ops_mids, ops_durs

    def insert_operation(self, carry, _):
        _inner_key, _job_mask, prev_col, machine_id = carry
        _inner_key, reuse_key, job_key = jax.random.split(_inner_key, num=3)

        all_job_ids = jnp.arange(self.num_jobs)

        # Use the previous job on the machine with some probability and
        # if the job hasn't already been scheduled in this timestep
        prev_job_id = prev_col[machine_id]
        reuse = jax.random.uniform(reuse_key, shape=()) >= self.prob_reuse_threshold
        is_available = _job_mask[prev_job_id]
        job_id = jax.lax.cond(
            reuse & is_available,
            lambda _: prev_job_id,
            lambda _: jax.random.choice(job_key, all_job_ids, (), p=_job_mask),
            None,
        )

        # Update the job mask to reflect that the chosen job can
        # no longer be scheduled on any of the remaining machines
        _job_mask = _job_mask.at[job_id].set(False)

        return (_inner_key, _job_mask, prev_col, machine_id + 1), job_id


if __name__ == "__main__":
    gen = DenseGenerator(10, 5, 14, 14, 14)
    key = jax.random.PRNGKey(0)
    state = gen(key)
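
The generation procedure described in the `DenseGenerator` docstring can be sketched without JAX as follows. This is a plain-Python illustration under stated assumptions, not the PR's implementation: the function name `dense_schedule` and the parameter `prob_reuse` are hypothetical, and `prob_reuse` is used directly as the reuse probability, whereas the PR compares a uniform sample with `prob_reuse_threshold` using `>=`. As in the PR's scan, the initial sample only serves as the "previous" column for the first timestep.

```python
import random
from typing import List


def dense_schedule(
    num_jobs: int,
    num_machines: int,
    makespan: int,
    prob_reuse: float,
    rng: random.Random,
) -> List[List[int]]:
    """Return schedule[machine][t] in which every machine is busy at every timestep."""
    if num_jobs < num_machines:
        raise ValueError(
            "The number of jobs must be greater than or equal to the number of machines."
        )
    # Seed column: distinct jobs, one per machine.
    prev_col = rng.sample(range(num_jobs), num_machines)
    columns = []
    for _ in range(makespan):
        available = set(range(num_jobs))  # jobs not yet assigned in this timestep
        col = []
        for machine in range(num_machines):
            prev_job = prev_col[machine]
            if prev_job in available and rng.random() < prob_reuse:
                job = prev_job  # keep the machine on its previous job
            else:
                job = rng.choice(sorted(available))  # sample an unassigned job
            available.remove(job)
            col.append(job)
        columns.append(col)
        prev_col = col
    # Transpose from (makespan, num_machines) to (num_machines, makespan).
    return [list(row) for row in zip(*columns)]


schedule = dense_schedule(
    num_jobs=5, num_machines=3, makespan=12, prob_reuse=0.5, rng=random.Random(0)
)
columns_have_no_duplicates = all(
    len({row[t] for row in schedule}) == len(schedule) for t in range(12)
)
```

Because every machine is busy for all `makespan` timesteps, each machine carries a total load of exactly `makespan`, and no schedule can finish before its most loaded machine does. That is the sense in which the chosen makespan is the shortest possible by construction.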