Skip to content

Commit

Permalink
Addding save,load and modifying fit
Browse files Browse the repository at this point in the history
  • Loading branch information
The-Blitz committed Oct 4, 2024
1 parent b17a5de commit f6b0386
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions sklearn_genetic/genetic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from .utils.random import weighted_bool_individual
from .utils.tools import cxUniform, mutFlipBit, novelty_scorer

import pickle
import os
from .callbacks.model_checkpoint import ModelCheckpoint


class GASearchCV(BaseSearchCV):
"""
Expand Down Expand Up @@ -524,6 +528,16 @@ def fit(self, X, y, callbacks=None):
# Make sure the callbacks are valid
self.callbacks = check_callback(callbacks)

# Load state if a checkpoint exists
for callback in self.callbacks:
if isinstance(callback, ModelCheckpoint):
if os.path.exists(callback.checkpoint_path):
checkpoint_data = callback.load()
if checkpoint_data:
self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa
self.logbook = checkpoint_data['logbook']
break

if callable(self.scoring):
self.scorer_ = self.scoring
self.metrics_list = [self.refit_metric]
Expand Down Expand Up @@ -601,6 +615,28 @@ def fit(self, X, y, callbacks=None):

return self

def save(self, filepath):
"""Save the current state of the GASearchCV instance to a file."""
try:
checkpoint_data = self.__dict__
with open(filepath, 'wb') as f:
pickle.dump(checkpoint_data, f)
print(f"GASearchCV model successfully saved to {filepath}")
except Exception as e:
print(f"Error saving GASearchCV: {e}")

@staticmethod
def load(filepath):
"""Load a GASearchCV instance from a file."""
try:
with open(filepath, 'rb') as f:
checkpoint_data = pickle.load(f)
model = GASearchCV(**checkpoint_data)
print(f"GASearchCV model successfully loaded from {filepath}")
return model
except Exception as e:
print(f"Error loading GASearchCV: {e}")

def _select_algorithm(self, pop, stats, hof):
"""
It selects the algorithm to run from the sklearn_genetic.algorithms module
Expand Down Expand Up @@ -1131,6 +1167,16 @@ def fit(self, X, y, callbacks=None):
# Make sure the callbacks are valid
self.callbacks = check_callback(callbacks)

# Load state if a checkpoint exists
for callback in self.callbacks:
if isinstance(callback, ModelCheckpoint):
if os.path.exists(callback.checkpoint_path):
checkpoint_data = callback.load()
if checkpoint_data:
self.estimator.__dict__.update(checkpoint_data['estimator_state']) # noqa
self.logbook = checkpoint_data['logbook']
break

if callable(self.scoring):
self.scorer_ = self.scoring
self.metrics_list = [self.refit_metric]
Expand Down Expand Up @@ -1192,6 +1238,28 @@ def fit(self, X, y, callbacks=None):

return self

def save(self, filepath):
"""Save the current state of the GAFeatureSelectionCV instance to a file."""
try:
checkpoint_data = self.__dict__
with open(filepath, 'wb') as f:
pickle.dump(checkpoint_data, f)
print(f"GAFeatureSelectionCV model successfully saved to {filepath}")
except Exception as e:
print(f"Error saving GAFeatureSelectionCV: {e}")

@staticmethod
def load(filepath):
"""Load a GAFeatureSelectionCV instance from a file."""
try:
with open(filepath, 'rb') as f:
checkpoint_data = pickle.load(f)
model = GAFeatureSelectionCV(**checkpoint_data)
print(f"GAFeatureSelectionCV model successfully loaded from {filepath}") # noqa
return model
except Exception as e:
print(f"Error loading GAFeatureSelectionCV: {e}")

def _select_algorithm(self, pop, stats, hof):
"""
It selects the algorithm to run from the sklearn_genetic.algorithms module
Expand Down

0 comments on commit f6b0386

Please sign in to comment.