From 75d77c7f91f042cd7218f8bb8b3ae26f9473dfd7 Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Mon, 29 Apr 2024 15:13:16 +0200 Subject: [PATCH 1/3] correct typo --- tests/emukit/quadrature/test_warpings.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/emukit/quadrature/test_warpings.py b/tests/emukit/quadrature/test_warpings.py index db107df7..ef3feea5 100644 --- a/tests/emukit/quadrature/test_warpings.py +++ b/tests/emukit/quadrature/test_warpings.py @@ -17,21 +17,21 @@ def identity_warping(): @pytest.fixture -def squarerroot_warping(): +def square_root_warping(): offset = 1.0 return SquareRootWarping(offset=offset) @pytest.fixture -def inverted_squarerroot_warping(): +def inverted_square_root_warping(): offset = 1.0 return SquareRootWarping(offset=offset, is_inverted=True) warpings = [ "identity_warping", - "squarerroot_warping", - "inverted_squarerroot_warping", + "square_root_warping", + "inverted_square_root_warping", ] @@ -56,16 +56,16 @@ def test_warping_values(warping_name, request): assert_allclose(warping.inverse_transform(warping.transform(Y)), Y, rtol=RTOL, atol=ATOL) -def test_squarerroot_warping_update_parameters(squarerroot_warping, inverted_squarerroot_warping): +def test_square_root_warping_update_parameters(square_root_warping, inverted_square_root_warping): new_offset = 10.0 - squarerroot_warping.update_parameters(offset=new_offset) - assert squarerroot_warping.offset == new_offset + square_root_warping.update_parameters(offset=new_offset) + assert square_root_warping.offset == new_offset - inverted_squarerroot_warping.update_parameters(offset=new_offset) - assert inverted_squarerroot_warping.offset == new_offset + inverted_square_root_warping.update_parameters(offset=new_offset) + assert inverted_square_root_warping.offset == new_offset -def test_squarerroot_warping_inverted_flag(squarerroot_warping, inverted_squarerroot_warping): - assert not squarerroot_warping.is_inverted - assert inverted_squarerroot_warping.is_inverted +def test_square_root_warping_inverted_flag(square_root_warping, inverted_square_root_warping): + assert not square_root_warping.is_inverted + assert inverted_square_root_warping.is_inverted From 794ffb48ab3575409b22e91e87d269332c62d7bb Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 7 May 2024 15:29:28 +0200 Subject: [PATCH 2/3] using fixtures --- tests/emukit/quadrature/test_measures.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/emukit/quadrature/test_measures.py b/tests/emukit/quadrature/test_measures.py index b4b7dff0..7914d009 100644 --- a/tests/emukit/quadrature/test_measures.py +++ b/tests/emukit/quadrature/test_measures.py @@ -76,26 +76,28 @@ def gauss_measure(): measure_test_list = [ - DataLebesgueMeasure(), - DataLebesgueNormalizedMeasure(), - DataGaussIsoMeasure(), - DataGaussMeasure(), + "lebesgue_measure", + "lebesgue_measure_normalized", + "gauss_iso_measure", + "gauss_measure" ] # === tests shared by all measures start here -@pytest.mark.parametrize("measure", measure_test_list) -def test_measure_gradient_values(measure): +@pytest.mark.parametrize("measure_name", measure_test_list) +def test_measure_gradient_values(measure_name, request): + measure = request.getfixturevalue(measure_name) D, measure, dat_bounds = measure.D, measure.measure, measure.dat_bounds func = lambda x: measure.compute_density(x) dfunc = lambda x: measure.compute_density_gradient(x).T check_grad(func, dfunc, in_shape=(3, D), bounds=dat_bounds) -@pytest.mark.parametrize("measure", measure_test_list) -def test_measure_shapes(measure): +@pytest.mark.parametrize("measure_name", measure_test_list) +def test_measure_shapes(measure_name, request): + measure = request.getfixturevalue(measure_name) D, measure = measure.D, measure.measure # box bounds From b2dea4f353067807e56ae36187995409e8574c4d Mon Sep 17 00:00:00 2001 From: Maren Mahsereci Date: Tue, 7 May 2024 15:36:08 +0200 Subject: [PATCH 3/3] black formatting --- tests/emukit/quadrature/test_measures.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/emukit/quadrature/test_measures.py b/tests/emukit/quadrature/test_measures.py index 7914d009..d656de59 100644 --- a/tests/emukit/quadrature/test_measures.py +++ b/tests/emukit/quadrature/test_measures.py @@ -75,12 +75,7 @@ def gauss_measure(): return DataGaussMeasure() -measure_test_list = [ - "lebesgue_measure", - "lebesgue_measure_normalized", - "gauss_iso_measure", - "gauss_measure" -] +measure_test_list = ["lebesgue_measure", "lebesgue_measure_normalized", "gauss_iso_measure", "gauss_measure"] # === tests shared by all measures start here