From 8b6bf04fcc15dd293743eebbb4f9d813ffb2dde4 Mon Sep 17 00:00:00 2001 From: "rodrigo.arenas" <31422766+rodrigo-arenas@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:25:29 -0500 Subject: [PATCH] model cache for faster evaluation --- dev-requirements.txt | 2 +- pytest.ini | 2 ++ sklearn_genetic/genetic_search.py | 44 +++++++++++++++++++++++++++++-- 3 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 pytest.ini diff --git a/dev-requirements.txt b/dev-requirements.txt index 53fced5..a9e964e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -scikit-learn>=1.1.0 +scikit-learn>=1.3.0 deap>=1.3.3 numpy>=1.19.0 pytest==7.4.0 diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..11d9f4c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --verbose --color=yes --assert=plain --cov-fail-under=95 --cov-config=.coveragerc --cov=./ -p no:warnings --tb=short --cov-report=term-missing:skip-covered diff --git a/sklearn_genetic/genetic_search.py b/sklearn_genetic/genetic_search.py index 2ea6e14..e2f8da8 100644 --- a/sklearn_genetic/genetic_search.py +++ b/sklearn_genetic/genetic_search.py @@ -259,6 +259,7 @@ def __init__( self.return_train_score = return_train_score self.creator = creator self.log_config = log_config + self.fitness_cache = {} # Check that the estimator is compatible with scikit-learn if not is_classifier(self.estimator) and not is_regressor(self.estimator): @@ -392,6 +393,17 @@ def evaluate(self, individual): key: individual[n] for n, key in enumerate(self.space.parameters) } + # Convert hyperparameters to a tuple to use as a key in the cache + individual_key = tuple(sorted(current_generation_params.items())) + + # Check if the individual has already been evaluated + if individual_key in self.fitness_cache: + # Retrieve cached result + cached_result = self.fitness_cache[individual_key] + # Ensure the logbook is updated even if the individual is cached + self.logbook.record(parameters=cached_result["current_generation_params"]) + return cached_result["fitness"] + local_estimator = clone(self.estimator) local_estimator.set_params(**current_generation_params) @@ -437,7 +449,15 @@ def evaluate(self, individual): # Log the hyperparameters and the cv-score self.logbook.record(parameters=current_generation_params) - return [score] + fitness_result = [score] + + # Store the fitness result and the current generation parameters in the cache + self.fitness_cache[individual_key] = { + "fitness": fitness_result, + "current_generation_params": current_generation_params + } + + return fitness_result def fit(self, X, y, callbacks=None): """ @@ -880,6 +900,7 @@ def __init__( self.return_train_score = return_train_score self.creator = creator self.log_config = log_config + self.fitness_cache = {} # Check that the estimator is compatible with scikit-learn if not is_classifier(self.estimator) and not is_regressor(self.estimator): @@ -965,6 +986,16 @@ def evaluate(self, individual): local_estimator = clone(self.estimator) n_selected_features = np.sum(individual) + # Convert the individual to a tuple to use as a key in the cache + individual_key = tuple(individual) + + # Check if the individual has already been evaluated + if individual_key in self.fitness_cache: + cached_result = self.fitness_cache[individual_key] + # Ensure the logbook is updated even if the individual is cached + self.logbook.record(parameters=cached_result["current_generation_features"]) + return cached_result["fitness"] + # Compute the cv-metrics using only the selected features cv_results = cross_validate( local_estimator, @@ -1014,7 +1045,16 @@ def evaluate(self, individual): ): score = -self.criteria_sign * 100000 - return [score, n_selected_features] + # Prepare the fitness result + fitness_result = [score, n_selected_features] + + # Store the fitness result and the current generation features in the cache + self.fitness_cache[individual_key] = { + "fitness": fitness_result, + "current_generation_features": current_generation_features + } + + return fitness_result def fit(self, X, y, callbacks=None): """