Skip to content

Commit

Permalink
fix flaky tests (#975)
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin authored Nov 7, 2023
1 parent b451d5f commit e935ac1
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions keras_tuner/distribute/oracle_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,30 +79,28 @@ def test_random_search(tmp_path):
# TensorFlow model building and execution is not thread-safe.
num_workers = 1

x = np.random.uniform(-1, 1, size=(2, 5))
y = np.ones((2,))

def _test_random_search():
def build_model(hp):
model = keras.Sequential()
model.add(keras.layers.Dense(3, input_shape=(5,)))
for i in range(hp.Int("num_layers", 1, 3)):
model.add(
keras.layers.Dense(
hp.Int("num_units_%i" % i, 1, 3), activation="relu"
)
)
model.add(keras.layers.Dense(1, activation="sigmoid"))
model.compile("sgd", "binary_crossentropy")
model.add(
keras.layers.Dense(hp.Int("num_units", 1, 3), input_shape=(5,))
)
model.add(keras.layers.Dense(1))
model.compile(loss="mse")
return model

x = np.random.uniform(-1, 1, size=(2, 5))
y = np.ones((2, 1))

tuner = keras_tuner.tuners.RandomSearch(
hypermodel=build_model,
objective="val_loss",
max_trials=10,
directory=tmp_path,
)
tuner.search(x, y, validation_data=(x, y), epochs=1, batch_size=2)
tuner.search(
x, y, validation_data=(x, y), epochs=1, batch_size=2, verbose=0
)

# Suppress warnings about optimizer state not being restored by
# tf.keras.
Expand All @@ -112,7 +110,9 @@ def build_model(hp):
assert trials[0].score <= trials[1].score

models = tuner.get_best_models(2)
assert models[0].evaluate(x, y) <= models[1].evaluate(x, y)
assert models[0].evaluate(x, y, verbose=0) <= models[1].evaluate(
x, y, verbose=0
)

mock_distribute.mock_distribute(
_test_random_search, num_workers, wait_for_chief=True
Expand Down

0 comments on commit e935ac1

Please sign in to comment.