v0.4.0

@simon-bachhuber released this 27 Dec 20:29

Controller training now supports multiple models.

Example:

ModelControllerTrainer(
    model={
        "model1": fitted_model1,
        "model2": fitted_model2,
    },
    controller=controller,
    # Note the `model1` name
    trackers=[Tracker("model1", "train_mse")],
)
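Note that the tracker's first argument refers to the model by its key in the `model` dict (hence `model1` above), and the second argument names the metric to track.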

To reduce the per-model losses to a single scalar loss value, `TrainingOptionsController` has gained a new option called `loss_fn_reduce_along_models`. By default it is:

import jax.numpy as jnp

NAME_AND_VALUE = dict[str, float]

def default_loss_fn_reduce_along_models(
    log_of_loss_values: dict[str, NAME_AND_VALUE]
) -> NAME_AND_VALUE:
    # `batch_concat` (a helper of this library) flattens the nested
    # {model_name: {metric_name: value}} dict into one flat array
    flat_logs = batch_concat(log_of_loss_values, 0)
    return {"loss_without_regu": jnp.mean(flat_logs)}

We can provide a custom implementation (extending the example above):

def our_loss_fn_reduce_along_models(log_of_loss_values):
    # average the train MSE of both models into a single scalar
    train_mse1 = log_of_loss_values["model1"]["train_mse"]
    train_mse2 = log_of_loss_values["model2"]["train_mse"]
    return {"my_loss": (train_mse1 + train_mse2) / 2}

Then, we can also track this metric by using:

ModelControllerTrainer(
    model={
        "model1": fitted_model1,
        "model2": fitted_model2,
    },
    controller=controller,
    controller_train_options=my_options,
    trackers=[Tracker("my_loss")],
)