Skip to content

Commit

Permalink
adding a sample test
Browse files Browse the repository at this point in the history
  • Loading branch information
The-Blitz committed Oct 4, 2024
1 parent c5926d9 commit 2e3930b
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions sklearn_genetic/tests/test_genetic_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.metrics import make_scorer

import numpy as np
import os

from .. import GASearchCV
from ..space import Integer, Categorical, Continuous
Expand All @@ -19,6 +20,7 @@
ConsecutiveStopping,
TimerStopping,
ProgressBar,
ModelCheckpoint
)

from ..schedules import ExponentialAdapter, InverseAdapter
Expand Down Expand Up @@ -659,3 +661,51 @@ def test_expected_ga_schedulers():
assert "params" in cv_result_keys

assert crossover_scheduler.current_value + mutation_scheduler.current_value <= 1


def test_checkpoint_functionality():
clf = SGDClassifier(loss="modified_huber", fit_intercept=True)
gen = 5
evolved_estimator = GASearchCV(
clf,
cv=3,
scoring="accuracy",
population_size=6,
generations=gen,
tournament_size=3,
param_grid={
"l1_ratio": Continuous(0, 1),
"alpha": Continuous(1e-4, 1),
"average": Categorical([True, False]),
},
)
checkpoint_path = 'test_checkpoint.pkl'
checkpoint = ModelCheckpoint(checkpoint_path=checkpoint_path) # noqa
evolved_estimator.fit(X_train, y_train, callbacks=checkpoint)

checkpoint_data = checkpoint.load()

assert 'estimator' in checkpoint_data['estimator_state']
assert 'algorithm' in checkpoint_data['estimator_state']
assert 'logbook' in checkpoint_data

restored_estimator = GASearchCV(**checkpoint_data['estimator_state'])

assert restored_estimator.algorithm == checkpoint_data['estimator_state']['algorithm'] # noqa

assert len(checkpoint_data['logbook']) == gen + 1

test_estimator = GASearchCV(
clf,
param_grid={
"l1_ratio": Continuous(0, 1),
"alpha": Continuous(1e-1, 1),
"average": Categorical([True, False]),
},)

test_estimator.load('checkpoint_path')

assert restored_estimator.algorithm == test_estimator.algorithm # noqa

if os.path.exists(checkpoint_path):
os.remove(checkpoint_path)

0 comments on commit 2e3930b

Please sign in to comment.