v0.4.0
Controller training now supports multiple models.
Example:
```python
ModelControllerTrainer(
    model={
        "model1": fitted_model1,
        "model2": fitted_model2,
    },
    controller=controller,
    # Note the `model1` name
    trackers=[Tracker("model1", "train_mse")],
)
```
To get a scalar loss value, `TrainingOptionsController` has gained a new option called `loss_fn_reduce_along_models`. By default it is:
```python
NAME_AND_VALUE = dict[str, float]

def default_loss_fn_reduce_along_models(
    log_of_loss_values: dict[str, NAME_AND_VALUE]
) -> NAME_AND_VALUE:
    flat_logs = batch_concat(log_of_loss_values, 0)
    return {"loss_without_regu": jnp.mean(flat_logs)}
```
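Conceptually, the default reduces the per-model loss logs to one scalar by averaging every logged value. A minimal pure-Python sketch of that behavior (a stand-in for the `batch_concat` + `jnp.mean` combination, assuming `batch_concat` flattens the nested dict of scalars into one flat array):

```python
def sketch_default_reduce(
    log_of_loss_values: dict[str, dict[str, float]]
) -> dict[str, float]:
    # Flatten {model_name: {metric_name: value}} into one list of values,
    # then average across all models and metrics.
    flat = [v for per_model in log_of_loss_values.values() for v in per_model.values()]
    return {"loss_without_regu": sum(flat) / len(flat)}

logs = {"model1": {"train_mse": 0.2}, "model2": {"train_mse": 0.4}}
print(sketch_default_reduce(logs))  # averages all logged values
```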
We could provide a custom implementation (extending the above example):

```python
def our_loss_fn_reduce_along_models(log_of_loss_values):
    train_mse1 = log_of_loss_values["model1"]["train_mse"]
    train_mse2 = log_of_loss_values["model2"]["train_mse"]
    return {"my_loss": (train_mse1 + train_mse2) / 2}
```
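The example below passes `controller_train_options=my_options`; one plausible way to build those options is to hand the custom function to `TrainingOptionsController` (passing it as a constructor keyword is an assumption here; check the actual signature in your version):

```python
# Hypothetical wiring; `loss_fn_reduce_along_models` as a constructor
# keyword is assumed, not confirmed by the changelog.
my_options = TrainingOptionsController(
    loss_fn_reduce_along_models=our_loss_fn_reduce_along_models,
)
```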
Then, we could also track this metric:
```python
ModelControllerTrainer(
    model={
        "model1": fitted_model1,
        "model2": fitted_model2,
    },
    controller=controller,
    controller_train_options=my_options,
    trackers=[Tracker("my_loss")],
)
```