From b17a5dea1b28c23e9f6916354a9042438b65386f Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:51:09 -0500 Subject: [PATCH] creating new callback --- sklearn_genetic/callbacks/model_checkpoint.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 sklearn_genetic/callbacks/model_checkpoint.py diff --git a/sklearn_genetic/callbacks/model_checkpoint.py b/sklearn_genetic/callbacks/model_checkpoint.py new file mode 100644 index 0000000..4053fe6 --- /dev/null +++ b/sklearn_genetic/callbacks/model_checkpoint.py @@ -0,0 +1,49 @@ +import pickle +from .base import BaseCallback +from .loggers import LogbookSaver +from copy import deepcopy + + +class ModelCheckpoint(BaseCallback): + def __init__(self, checkpoint_path, **dump_options): + self.checkpoint_path = checkpoint_path + self.dump_options = dump_options + + def on_step(self, record=None, logbook=None, estimator=None): + try: + if logbook is not None and len(logbook) > 0: + logbook_saver = LogbookSaver(self.checkpoint_path, **self.dump_options) # noqa + logbook_saver.on_step(record, logbook, estimator) + + estimator_state = { + 'estimator': estimator.estimator, + 'cv': estimator.cv, + 'scoring': estimator.scoring, + 'population_size': estimator.population_size, + 'generations': estimator.generations, + 'crossover_probability': estimator.crossover_probability, + 'mutation_probability': estimator.mutation_probability, + 'param_grid': estimator.param_grid, + 'algorithm': estimator.algorithm, + 'param_grid': estimator.param_grid, + } + checkpoint_data = { + 'estimator_state': estimator_state, + 'logbook': deepcopy(logbook) + } + with open(self.checkpoint_path, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"Checkpoint save in {self.checkpoint_path}") + + except Exception as e: + print(f"Error saving checkpoint: {e}") + + def load(self): + """Load the model state from the checkpoint file.""" + try: + with open(self.checkpoint_path, 'rb') as f: + checkpoint_data = pickle.load(f) + return checkpoint_data + except Exception as e: + print(f"Error loading checkpoint: {e}") + return None