-
Notifications
You must be signed in to change notification settings - Fork 85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(job_shop): implement dense schedule generator #115
Changes from all commits
1df15f7
51ec62d
046e8a0
b6b8cbd
32ff604
e43689e
3e7f2b8
44a591f
54b5ac6
d3cede1
f411418
e1db3bf
b03cbb8
d879df9
1050eaf
66d3053
d5bb546
7de39d4
53d1b09
52deda1
8849f14
982e677
257ac07
3e07217
bde343e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
# limitations under the License. | ||
|
||
import abc | ||
from typing import Any, Tuple | ||
|
||
import chex | ||
import jax | ||
|
@@ -184,3 +185,281 @@ def __call__(self, key: chex.PRNGKey) -> State: | |
) | ||
|
||
return state | ||
|
||
|
||
class DenseGenerator(Generator): | ||
"""`Generator` which creates a dense schedule of a specified makespan. This is done by: | ||
- Specifying the `makespan` (schedule length) and the `num_machines`. | ||
- Initialising an "empty" schedule. | ||
- Creating a valid schedule: | ||
1. Randomly sample `num_machines` jobs w/o replacement. These jobs will be | ||
scheduled on the machines in the first time step. | ||
2. At the next timestep, stochastically either: | ||
- Reuse the previous jobs on the machines, or | ||
- Randomly sample `num_machines` new jobs w/o replacement. | ||
3. Repeat step 2 until the desired `makespan` is reached. | ||
- Extracting the info (duration and machine) about operations from the schedule and | ||
padding the operations to the max number of operations. | ||
|
||
This generator assumes that the number of jobs is less than or equal to the number of | ||
machines. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_jobs: int, | ||
num_machines: int, | ||
max_num_ops: int, | ||
max_op_duration: int, | ||
Comment on lines
+212
to
+213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have to have these in the init? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the way that the |
||
makespan: int, | ||
prob_reuse_threshold: float = 0.5, | ||
): | ||
"""Instantiate a `DenseGenerator`. Note that the `makespan` is an upper | ||
bound to both `max_num_ops` and `max_op_duration`, hence they are deleted. | ||
|
||
Args: | ||
num_jobs: the number of jobs that need to be scheduled. | ||
num_machines: the number of machines that the jobs can be scheduled on. | ||
max_num_ops: the maximum number of operations for any given job. | ||
max_op_duration: the maximum processing time of any given operation. | ||
makespan: the length of the schedule. By construction, this will be the | ||
shortest possible length of the schedule. | ||
prob_reuse_threshold: the threshold probability of reusing the previous | ||
job_id on a given machine. When generating the schedule, this quantity | ||
determines how likely a machine is to try to reuse the job_id in the | ||
previous timestep. | ||
""" | ||
del max_op_duration | ||
del max_num_ops | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a comment to explain why they are deleted ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, have updated the doc string - let me know if you're happy with that |
||
if num_jobs < num_machines: | ||
raise ValueError( | ||
"The number of jobs must be greater than or equal to the number of machines." | ||
) | ||
|
||
super().__init__( | ||
num_jobs=num_jobs, | ||
num_machines=num_machines, | ||
max_num_ops=makespan, | ||
max_op_duration=makespan, | ||
) | ||
self.makespan = makespan | ||
self.prob_reuse_threshold = prob_reuse_threshold | ||
|
||
def __call__(self, key: chex.PRNGKey) -> State: | ||
key, schedule_key = jax.random.split(key) | ||
|
||
# Generate a random, dense schedule of the specified length | ||
schedule = self._generate_schedule(schedule_key) | ||
|
||
# Extract ops information from the schedule | ||
ops_machine_ids, ops_durations = self._register_ops(schedule) | ||
|
||
# Initially, all machines are available (the value self.num_jobs corresponds to no-op) | ||
machines_job_ids = jnp.full(self.num_machines, self.num_jobs, jnp.int32) | ||
machines_remaining_times = jnp.full(self.num_machines, 0, jnp.int32) | ||
scheduled_times = jnp.full((self.num_jobs, self.max_num_ops), -1, jnp.int32) | ||
ops_mask = ops_machine_ids != -1 | ||
step_count = jnp.int32(0) | ||
|
||
state = State( | ||
ops_machine_ids=ops_machine_ids, | ||
ops_durations=ops_durations, | ||
ops_mask=ops_mask, | ||
machines_job_ids=machines_job_ids, | ||
machines_remaining_times=machines_remaining_times, | ||
action_mask=None, | ||
step_count=step_count, | ||
scheduled_times=scheduled_times, | ||
key=key, | ||
) | ||
|
||
return state | ||
|
||
def _generate_schedule(self, key: chex.PRNGKey) -> chex.Array: | ||
"""Creates a schedule given the constraints of the job shop scheduling problem. | ||
|
||
For example, for 3 machines, 5 jobs, and a chosen optimal makespan of 12, a schedule | ||
may look like: | ||
[[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3], | ||
[4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2], | ||
[0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]] | ||
|
||
This means | ||
- Machine 0 processes job 1, job 0, job 1 (for two steps), etc. | ||
- Machine 1 processes job 4, job 1, job 2, job 0, etc. | ||
- Machine 2 processes job 0, job 2, job 3, job 4, etc. | ||
|
||
Importantly, since a job can only be executed on one machine at a time, this method | ||
is written such that the schedule does not have duplicates in any column. | ||
|
||
Args: | ||
key: used for stochasticity in the generation of the schedule. | ||
|
||
Returns: | ||
Schedule with the specified length. Shape (num_machines, makespan). | ||
""" | ||
init_col_key, scan_key = jax.random.split(key) | ||
all_job_ids = jnp.arange(self.num_jobs) | ||
|
||
def insert_col(carry, _): | ||
_scan_key, _init_col = carry | ||
_scan_key, inner_key = jax.random.split(_scan_key) | ||
|
||
init_job_mask = jnp.ones(shape=self.num_jobs, dtype=jnp.bool_) | ||
init_machine_id = 0 | ||
inner_init_carry = inner_key, init_job_mask, _init_col, init_machine_id | ||
_, col = jax.lax.scan( | ||
lambda inner_carry, _: self.insert_operation(inner_carry, _), | ||
inner_init_carry, | ||
xs=jnp.arange(self.num_machines), | ||
) | ||
|
||
return (_scan_key, col), col | ||
|
||
init_col = jax.random.choice( | ||
init_col_key, | ||
all_job_ids, | ||
(self.num_machines,), | ||
replace=False, | ||
) | ||
init_carry = scan_key, init_col | ||
_, schedule_transposed = jax.lax.scan( | ||
lambda carry, _: insert_col(carry, _), | ||
init_carry, | ||
xs=jnp.arange(self.makespan), | ||
) | ||
schedule = schedule_transposed.T | ||
return schedule | ||
|
||
def _register_ops(self, schedule: chex.Array) -> Tuple[chex.Array, chex.Array]: | ||
"""Extract, for every job, the machine id and duration of each operation in the job. | ||
|
||
For example, for the schedule | ||
[[1, 0, 1, 1, 2, 3, 4, 0, 1, 3, 2, 3], | ||
[4, 1, 2, 0, 3, 4, 2, 3, 2, 2, 3, 2], | ||
[0, 2, 3, 4, 0, 0, 0, 2, 3, 1, 0, 0]] | ||
|
||
the ops would have the machine ids: | ||
[[ 2, 0, 1, 2, 0, 2, -1, -1, -1, -1], | ||
[ 0, 1, 0, 0, 2, -1, -1, -1, -1, -1], | ||
[ 2, 1, 0, 1, 2, 1, 0, 1, -1, -1], | ||
[ 2, 1, 0, 1, 2, 0, 1, 0, -1, -1], | ||
[ 1, 2, 1, 0, -1, -1, -1, -1, -1, -1]] | ||
|
||
and the durations: | ||
[[ 1, 1, 1, 3, 1, 2, -1, -1, -1, -1], | ||
[ 1, 1, 2, 1, 1, -1, -1, -1, -1, -1], | ||
[ 1, 1, 1, 1, 1, 2, 1, 1, -1, -1], | ||
[ 1, 1, 1, 1, 1, 1, 1, 1, -1, -1], | ||
[ 1, 1, 1, 1, -1, -1, -1, -1, -1, -1]] | ||
|
||
Args: | ||
schedule: array representing which job each machine is working on at each timestep. | ||
|
||
Returns: | ||
Arrays representing which machine id and duration characterising each operation. | ||
""" | ||
|
||
def get_job_info( | ||
job_id: int, _: Any | ||
) -> Tuple[int, Tuple[chex.Array, chex.Array]]: | ||
"""Extract the machine id and duration of every op in the specified job. | ||
|
||
In the above example, for job 0, this will return | ||
- machine_ids [2, 0, 1, 2, 0, 2, -1, -1, -1, -1] | ||
- durations [1, 1, 1, 3, 1, 2, -1, -1, -1, -1] | ||
""" | ||
|
||
def get_op_info( | ||
mask: chex.Array, | ||
_: Any, | ||
) -> Tuple[chex.Array, Tuple[chex.Array, chex.Array]]: | ||
"""Extract the machine id and duration for a given operation. | ||
|
||
In the above example, for job 0 and operation 0, the machine id is 2 and | ||
the duration is 1. | ||
|
||
Args: | ||
mask: array which keeps track of which operations have been registered. | ||
""" | ||
prev_mask = mask | ||
|
||
# Flatten by column | ||
mask_flat = jnp.ravel(mask, order="F") | ||
|
||
# Find index of the next operation | ||
idx = jnp.argmax(mask_flat) | ||
t_start, machine_id = jnp.divmod(idx, self.num_machines) | ||
|
||
# Update the mask -> the op is registered | ||
mask = mask.at[machine_id, t_start].set(False) | ||
|
||
# While loop in case op has duration > 1 | ||
init_val = (mask, machine_id, t_start + 1) | ||
|
||
def next_is_same_op(val: Tuple) -> chex.Array: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are 4 levels of nested functions. Would it be possible to refactor the code with less nesting? |
||
m, mid, t = val | ||
return m[mid, t] | ||
|
||
def update_mask(val: Tuple) -> Tuple: | ||
m, mid, t = val | ||
m = m.at[mid, t].set(False) | ||
return m, mid, t + 1 | ||
|
||
(mask, machine_id, time) = jax.lax.while_loop( | ||
next_is_same_op, update_mask, init_val | ||
) | ||
|
||
duration = time - t_start | ||
|
||
# If all ops for this job are registered, return -1 for padding | ||
all_ops_registered = ~jnp.any(prev_mask) | ||
machine_id = jax.lax.select(all_ops_registered, -1, machine_id) | ||
duration = jax.lax.select(all_ops_registered, -1, duration) | ||
|
||
return mask, (machine_id, duration) | ||
|
||
# Carry the mask | ||
init_mask = jnp.array(schedule == job_id) | ||
_, (mids, durs) = jax.lax.scan( | ||
get_op_info, init_mask, xs=None, length=self.makespan | ||
) | ||
|
||
return job_id + 1, (mids, durs) | ||
|
||
# Carry the job id | ||
init_job_id = 0 | ||
job_id, (ops_mids, ops_durs) = jax.lax.scan( | ||
get_job_info, init_job_id, xs=None, length=self.num_jobs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be possible to use |
||
) | ||
return ops_mids, ops_durs | ||
|
||
def insert_operation(self, carry, _): | ||
_inner_key, _job_mask, prev_col, machine_id = carry | ||
_inner_key, reuse_key, job_key = jax.random.split(_inner_key, num=3) | ||
|
||
all_job_ids = jnp.arange(self.num_jobs) | ||
|
||
# Use the previous job on the machine with some probability and | ||
# if the job hasn't already been scheduled in this timestep | ||
prev_job_id = prev_col[machine_id] | ||
reuse = jax.random.uniform(reuse_key, shape=()) >= self.prob_reuse_threshold | ||
is_available = _job_mask[prev_job_id] | ||
job_id = jax.lax.cond( | ||
reuse & is_available, | ||
lambda _: prev_job_id, | ||
lambda _: jax.random.choice(job_key, all_job_ids, (), p=_job_mask), | ||
None, | ||
) | ||
|
||
# Update the job mask to reflect that the chosen job can | ||
# no longer be scheduled on any of the remaining machines | ||
_job_mask = _job_mask.at[job_id].set(False) | ||
|
||
return (_inner_key, _job_mask, prev_col, machine_id + 1), job_id | ||
|
||
|
||
if __name__ == "__main__": | ||
gen = DenseGenerator(10, 5, 14, 14, 14) | ||
key = jax.random.PRNGKey(0) | ||
state = gen(key) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unless we test the new generator with a full training, we should actually keep the previous generator as default (and keep it to version v0).