Skip to content

Commit

Permalink
Allow passing seed to VizierGPBandit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 527694405
  • Loading branch information
sagipe authored and copybara-github committed Apr 27, 2023
1 parent 40042af commit e120ae2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
14 changes: 9 additions & 5 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import datetime
import json
import random
from typing import Sequence
from typing import Optional, Sequence

from absl import logging
import attr
Expand Down Expand Up @@ -88,16 +88,15 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
factory=acquisitions.GPBanditAcquisitionBuilder, kw_only=True
)
_use_trust_region: bool = attr.field(default=True, kw_only=True)
_rng: jax.random.KeyArray = attr.field(
factory=lambda: jax.random.PRNGKey(random.getrandbits(32)), kw_only=True
)
_seed: Optional[int] = attr.field(default=None, kw_only=True)
_metadata_ns: str = attr.field(
default='oss_gp_bandit', kw_only=True, init=False
)

# ------------------------------------------------------------------
# Internal attributes which should not be set by callers.
# ------------------------------------------------------------------
_rng: jax.random.KeyArray = attr.field(init=False, kw_only=True)
_trials: list[vz.Trial] = attr.field(factory=list, init=False)
# The number of trials that have been incorporated
# into the designer state (Cholesky decomposition, ARD).
Expand Down Expand Up @@ -130,7 +129,8 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
random_restarts=4, best_n=1
)
default_acquisition_optimizer = vb.VectorizedOptimizer(
strategy_factory=es.VectorizedEagleStrategyFactory())
strategy_factory=es.VectorizedEagleStrategyFactory()
)

def __attrs_post_init__(self):
# Extra validations
Expand All @@ -139,6 +139,9 @@ def __attrs_post_init__(self):
elif len(self._problem.metric_information) != 1:
raise ValueError(f'{type(self)} works with exactly one metric.')
# Extra initializations.
if self._seed is None:
self._seed = random.getrandbits(32)
self._rng = jax.random.PRNGKey(self._seed)
# Discrete parameters are continuified to account for their actual values.
self._converter = converters.TrialToArrayConverter.from_study_config(
self._problem,
Expand Down Expand Up @@ -311,6 +314,7 @@ def precompute_cholesky(params):
method=self._model.precompute_predictive,
mutable='predictive',
)

if self._use_vmap:
precompute_cholesky = jax.vmap(precompute_cholesky)
# `pp_state` contains intermediates that are expensive to compute, depend
Expand Down
26 changes: 25 additions & 1 deletion vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def test_on_flat_mixed_space(
batch_size=batch_size,
verbose=1,
validate_parameters=True,
).run_designer(designer), iters * batch_size)
).run_designer(designer),
iters * batch_size,
)

quasi_random_sampler = quasi_random.QuasiRandomDesigner(
problem.search_space
Expand Down Expand Up @@ -195,6 +197,28 @@ def test_prediction_accuracy(self):
pred = gp_designer.predict([pred_trial], rng=jax.random.PRNGKey(0))
self.assertLess(np.abs(pred.mean[0] - f(0.0)), 1e-2)

def test_seed_specified(self):
problem = vz.ProblemStatement(
test_studies.flat_continuous_space_with_scaling()
)
designer = gp_bandit.VizierGPBandit(problem=problem, seed=346778)
completed1 = test_runners.RandomMetricsRunner(
problem,
iters=3,
batch_size=1,
verbose=1,
validate_parameters=True,
).run_designer(designer)
designer2 = gp_bandit.VizierGPBandit(problem=problem, seed=346778)
completed2 = test_runners.RandomMetricsRunner(
problem,
iters=3,
batch_size=1,
verbose=1,
validate_parameters=True,
).run_designer(designer2)
self.assertEqual(completed1, completed2)


if __name__ == '__main__':
absltest.main()

0 comments on commit e120ae2

Please sign in to comment.