From 242e2180f03e8dfa72a08cd0d52126951b471909 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Thu, 19 Jan 2023 15:37:06 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 503283329 --- learned_optimization/population/mutators/fixed_schedule.py | 2 +- .../population/mutators/single_worker_explore.py | 2 +- .../population/mutators/winner_take_all_genetic.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/learned_optimization/population/mutators/fixed_schedule.py b/learned_optimization/population/mutators/fixed_schedule.py index ba1f67c..675eb15 100644 --- a/learned_optimization/population/mutators/fixed_schedule.py +++ b/learned_optimization/population/mutators/fixed_schedule.py @@ -46,7 +46,7 @@ def update(self, state: Any, steps = cache[worker.generation_id] # grab the last checkpoint here. - last_checkpoint = steps.values()[-1] + last_checkpoint = steps.values()[-1] # pytype: disable=unsupported-operands logging.info("Active worker: %s", str(worker)) logging.info("last checkpoint : %s", str(last_checkpoint)) diff --git a/learned_optimization/population/mutators/single_worker_explore.py b/learned_optimization/population/mutators/single_worker_explore.py index be03a0a..d46bfcc 100644 --- a/learned_optimization/population/mutators/single_worker_explore.py +++ b/learned_optimization/population/mutators/single_worker_explore.py @@ -105,7 +105,7 @@ def add_worker_to_cache(from_checkpoint: population.Checkpoint, state["branch_checkpoint"] = steps[0] state["center"] = steps[0].generation_id - last_checkpoint = steps.values()[-1] + last_checkpoint = steps.values()[-1] # pytype: disable=unsupported-operands if state["phase"] == "exploit": # switch to center. diff --git a/learned_optimization/population/mutators/winner_take_all_genetic.py b/learned_optimization/population/mutators/winner_take_all_genetic.py index 9517bff..41f79a1 100644 --- a/learned_optimization/population/mutators/winner_take_all_genetic.py +++ b/learned_optimization/population/mutators/winner_take_all_genetic.py @@ -62,11 +62,12 @@ def update( return None, current_workers if self._steps_per_exploit: - if (cache[genid].keys()[-1] - - cache[genid].keys()[0]) < self._steps_per_exploit: + if ( + cache[genid].keys()[-1] - cache[genid].keys()[0] # pytype: disable=unsupported-operands + ) < self._steps_per_exploit: return None, current_workers - to_test = cache[genid].keys()[0] + self._steps_per_exploit + to_test = cache[genid].keys()[0] + self._steps_per_exploit # pytype: disable=unsupported-operands valid_values = [ x.value for (s, x) in cache[genid].items()