From f6b038694b22fb8d933d03623676a9c86273fa32 Mon Sep 17 00:00:00 2001 From: Rodrigo Lopez Date: Fri, 4 Oct 2024 13:52:25 -0500 Subject: [PATCH] Addding save,load and modifying fit --- sklearn_genetic/genetic_search.py | 68 +++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/sklearn_genetic/genetic_search.py b/sklearn_genetic/genetic_search.py index 8decc0e..6c41152 100644 --- a/sklearn_genetic/genetic_search.py +++ b/sklearn_genetic/genetic_search.py @@ -30,6 +30,10 @@ from .utils.random import weighted_bool_individual from .utils.tools import cxUniform, mutFlipBit, novelty_scorer +import pickle +import os +from .callbacks.model_checkpoint import ModelCheckpoint + class GASearchCV(BaseSearchCV): """ @@ -524,6 +528,16 @@ def fit(self, X, y, callbacks=None): # Make sure the callbacks are valid self.callbacks = check_callback(callbacks) + # Load state if a checkpoint exists + for callback in self.callbacks: + if isinstance(callback, ModelCheckpoint): + if os.path.exists(callback.checkpoint_path): + checkpoint_data = callback.load() + if checkpoint_data: + self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa + self.logbook = checkpoint_data['logbook'] + break + if callable(self.scoring): self.scorer_ = self.scoring self.metrics_list = [self.refit_metric] @@ -601,6 +615,28 @@ def fit(self, X, y, callbacks=None): return self + def save(self, filepath): + """Save the current state of the GASearchCV instance to a file.""" + try: + checkpoint_data = self.__dict__ + with open(filepath, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"GASearchCV model successfully saved to {filepath}") + except Exception as e: + print(f"Error saving GASearchCV: {e}") + + @staticmethod + def load(filepath): + """Load a GASearchCV instance from a file.""" + try: + with open(filepath, 'rb') as f: + checkpoint_data = pickle.load(f) + model = GASearchCV(**checkpoint_data) + print(f"GASearchCV model successfully loaded from {filepath}") + return model + except Exception as e: + print(f"Error loading GASearchCV: {e}") + def _select_algorithm(self, pop, stats, hof): """ It selects the algorithm to run from the sklearn_genetic.algorithms module @@ -1131,6 +1167,16 @@ def fit(self, X, y, callbacks=None): # Make sure the callbacks are valid self.callbacks = check_callback(callbacks) + # Load state if a checkpoint exists + for callback in self.callbacks: + if isinstance(callback, ModelCheckpoint): + if os.path.exists(callback.checkpoint_path): + checkpoint_data = callback.load() + if checkpoint_data: + self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa + self.logbook = checkpoint_data['logbook'] + break + if callable(self.scoring): self.scorer_ = self.scoring self.metrics_list = [self.refit_metric] @@ -1192,6 +1238,28 @@ def fit(self, X, y, callbacks=None): return self + def save(self, filepath): + """Save the current state of the GAFeatureSelectionCV instance to a file.""" + try: + checkpoint_data = self.__dict__ + with open(filepath, 'wb') as f: + pickle.dump(checkpoint_data, f) + print(f"GAFeatureSelectionCV model successfully saved to {filepath}") + except Exception as e: + print(f"Error saving GAFeatureSelectionCV: {e}") + + @staticmethod + def load(filepath): + """Load a GAFeatureSelectionCV instance from a file.""" + try: + with open(filepath, 'rb') as f: + checkpoint_data = pickle.load(f) + model = GAFeatureSelectionCV(**checkpoint_data) + print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa + return model + except Exception as e: + print(f"Error loading GAFeatureSelectionCV: {e}") + def _select_algorithm(self, pop, stats, hof): """ It selects the algorithm to run from the sklearn_genetic.algorithms module