Skip to content

Commit

Permalink
Small bugfix in tests. (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmahsereci authored May 7, 2024
1 parent 80275bb commit 2de6ef7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
17 changes: 7 additions & 10 deletions tests/emukit/quadrature/test_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,27 +75,24 @@ def gauss_measure():
return DataGaussMeasure()


measure_test_list = [
DataLebesgueMeasure(),
DataLebesgueNormalizedMeasure(),
DataGaussIsoMeasure(),
DataGaussMeasure(),
]
measure_test_list = ["lebesgue_measure", "lebesgue_measure_normalized", "gauss_iso_measure", "gauss_measure"]


# === tests shared by all measures start here


@pytest.mark.parametrize("measure", measure_test_list)
def test_measure_gradient_values(measure):
@pytest.mark.parametrize("measure_name", measure_test_list)
def test_measure_gradient_values(measure_name, request):
measure = request.getfixturevalue(measure_name)
D, measure, dat_bounds = measure.D, measure.measure, measure.dat_bounds
func = lambda x: measure.compute_density(x)
dfunc = lambda x: measure.compute_density_gradient(x).T
check_grad(func, dfunc, in_shape=(3, D), bounds=dat_bounds)


@pytest.mark.parametrize("measure", measure_test_list)
def test_measure_shapes(measure):
@pytest.mark.parametrize("measure_name", measure_test_list)
def test_measure_shapes(measure_name, request):
measure = request.getfixturevalue(measure_name)
D, measure = measure.D, measure.measure

# box bounds
Expand Down
24 changes: 12 additions & 12 deletions tests/emukit/quadrature/test_warpings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ def identity_warping():


@pytest.fixture
def squarerroot_warping():
def square_root_warping():
offset = 1.0
return SquareRootWarping(offset=offset)


@pytest.fixture
def inverted_squarerroot_warping():
def inverted_square_root_warping():
offset = 1.0
return SquareRootWarping(offset=offset, is_inverted=True)


warpings = [
"identity_warping",
"squarerroot_warping",
"inverted_squarerroot_warping",
"square_root_warping",
"inverted_square_root_warping",
]


Expand All @@ -56,16 +56,16 @@ def test_warping_values(warping_name, request):
assert_allclose(warping.inverse_transform(warping.transform(Y)), Y, rtol=RTOL, atol=ATOL)


def test_squarerroot_warping_update_parameters(squarerroot_warping, inverted_squarerroot_warping):
def test_square_root_warping_update_parameters(square_root_warping, inverted_square_root_warping):
new_offset = 10.0

squarerroot_warping.update_parameters(offset=new_offset)
assert squarerroot_warping.offset == new_offset
square_root_warping.update_parameters(offset=new_offset)
assert square_root_warping.offset == new_offset

inverted_squarerroot_warping.update_parameters(offset=new_offset)
assert inverted_squarerroot_warping.offset == new_offset
inverted_square_root_warping.update_parameters(offset=new_offset)
assert inverted_square_root_warping.offset == new_offset


def test_squarerroot_warping_inverted_flag(squarerroot_warping, inverted_squarerroot_warping):
assert not squarerroot_warping.is_inverted
assert inverted_squarerroot_warping.is_inverted
def test_square_root_warping_inverted_flag(square_root_warping, inverted_square_root_warping):
assert not square_root_warping.is_inverted
assert inverted_square_root_warping.is_inverted

0 comments on commit 2de6ef7

Please sign in to comment.