Skip to content

Commit

Permalink
Merge pull request #4207 from tybug/float-clamper
Browse files Browse the repository at this point in the history
Use a single clamper for floats
  • Loading branch information
tybug authored Dec 21, 2024
2 parents ac3d793 + 787e9bf commit cd7dd9c
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 83 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch cleans up some internal code around clamping floats.
39 changes: 13 additions & 26 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,8 +1576,7 @@ def draw_float(
(
sampler,
forced_sign_bit,
neg_clamper,
pos_clamper,
clamper,
nasty_floats,
) = self._draw_float_init_logic(
min_value=min_value,
Expand Down Expand Up @@ -1608,12 +1607,8 @@ def draw_float(
)
if allow_nan and math.isnan(result):
clamped = result
elif math.copysign(1.0, result) == -1:
assert neg_clamper is not None
clamped = -neg_clamper(-result)
else:
assert pos_clamper is not None
clamped = pos_clamper(result)
clamped = clamper(result)
if clamped != result and not (math.isnan(result) and allow_nan):
self._draw_float(forced=clamped, fake_forced=fake_forced)
result = clamped
Expand Down Expand Up @@ -1895,8 +1890,7 @@ def _draw_float_init_logic(
) -> tuple[
Optional[Sampler],
Optional[Literal[0, 1]],
Optional[Callable[[float], float]],
Optional[Callable[[float], float]],
Callable[[float], float],
list[float],
]:
"""
Expand Down Expand Up @@ -1933,8 +1927,7 @@ def _compute_draw_float_init_logic(
) -> tuple[
Optional[Sampler],
Optional[Literal[0, 1]],
Optional[Callable[[float], float]],
Optional[Callable[[float], float]],
Callable[[float], float],
list[float],
]:
if smallest_nonzero_magnitude == 0.0: # pragma: no cover
Expand Down Expand Up @@ -1966,23 +1959,17 @@ def permitted(f: float) -> bool:
weights = [0.2 * len(nasty_floats)] + [0.8] * len(nasty_floats)
sampler = Sampler(weights, observe=False) if nasty_floats else None

pos_clamper = neg_clamper = None
if sign_aware_lte(0.0, max_value):
pos_min = max(min_value, smallest_nonzero_magnitude)
allow_zero = sign_aware_lte(min_value, 0.0)
pos_clamper = make_float_clamper(pos_min, max_value, allow_zero=allow_zero)
if sign_aware_lte(min_value, -0.0):
neg_max = min(max_value, -smallest_nonzero_magnitude)
allow_zero = sign_aware_lte(-0.0, max_value)
neg_clamper = make_float_clamper(
-neg_max, -min_value, allow_zero=allow_zero
)

forced_sign_bit: Optional[Literal[0, 1]] = None
if (pos_clamper is None) != (neg_clamper is None):
forced_sign_bit = 1 if neg_clamper else 0
if sign_aware_lte(min_value, -0.0) != sign_aware_lte(0.0, max_value):
forced_sign_bit = 1 if sign_aware_lte(min_value, -0.0) else 0

return (sampler, forced_sign_bit, neg_clamper, pos_clamper, nasty_floats)
clamper = make_float_clamper(
min_value,
max_value,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
allow_nan=allow_nan,
)
return (sampler, forced_sign_bit, clamper, nasty_floats)


# The set of available `PrimitiveProvider`s, by name. Other libraries, such as
Expand Down
76 changes: 52 additions & 24 deletions hypothesis-python/src/hypothesis/internal/floats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import math
import struct
from sys import float_info
from typing import TYPE_CHECKING, Callable, Literal, Optional, SupportsFloat, Union
from typing import TYPE_CHECKING, Callable, Literal, SupportsFloat, Union

from hypothesis.internal.conjecture.junkdrawer import clamp

if TYPE_CHECKING:
from typing import TypeAlias
Expand Down Expand Up @@ -136,38 +138,64 @@ def next_up_normal(value: float, width: int, *, allow_subnormal: bool) -> float:
}
assert width_smallest_normals[64] == float_info.min

mantissa_mask = (1 << 52) - 1


def float_permitted(
    f: float,
    *,
    min_value: float,
    max_value: float,
    allow_nan: bool,
    smallest_nonzero_magnitude: float,
) -> bool:
    """
    Return whether ``f`` satisfies the given float constraints.

    NaN is permitted exactly when ``allow_nan`` is set. Any other value is
    rejected if it is nonzero with magnitude strictly below
    ``smallest_nonzero_magnitude``, and must otherwise lie between
    ``min_value`` and ``max_value`` as judged by ``sign_aware_lte``.
    """
    # NaN is decided by allow_nan alone; the bounds are not consulted for it.
    if math.isnan(f):
        return allow_nan

    # Zero itself is never excluded by the magnitude check, only nonzero
    # values that are too close to it.
    magnitude = abs(f)
    if magnitude > 0 and magnitude < smallest_nonzero_magnitude:
        return False

    if not sign_aware_lte(min_value, f):
        return False
    return sign_aware_lte(f, max_value)


def make_float_clamper(
min_float: float = 0.0,
max_float: float = math.inf,
min_value: float,
max_value: float,
*,
allow_zero: bool = False, # Allows +0.0 (even if minfloat > 0)
) -> Optional[Callable[[float], float]]:
smallest_nonzero_magnitude: float,
allow_nan: bool,
) -> Callable[[float], float]:
"""
Return a function that clamps positive floats into the given bounds.
Returns None when no values are allowed (min > max and zero is not allowed).
"""
if max_float < min_float:
if allow_zero:
min_float = max_float = 0.0
else:
return None

range_size = min(max_float - min_float, float_info.max)
mantissa_mask = (1 << 52) - 1

def float_clamper(float_val: float) -> float:
if min_float <= float_val <= max_float:
return float_val
if float_val == 0.0 and allow_zero:
return float_val
assert sign_aware_lte(min_value, max_value)
range_size = min(max_value - min_value, float_info.max)

def float_clamper(f: float) -> float:
if float_permitted(
f,
min_value=min_value,
max_value=max_value,
allow_nan=allow_nan,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
):
return f
# Outside bounds; pick a new value, sampled from the allowed range,
# using the mantissa bits.
mant = float_to_int(float_val) & mantissa_mask
float_val = min_float + range_size * (mant / mantissa_mask)
mant = float_to_int(abs(f)) & mantissa_mask
f = min_value + range_size * (mant / mantissa_mask)

# if we resampled into the space disallowed by smallest_nonzero_magnitude,
# default to smallest_nonzero_magnitude.
if 0 < abs(f) < smallest_nonzero_magnitude:
f = smallest_nonzero_magnitude
# we must have either -smallest_nonzero_magnitude <= min_value or
# smallest_nonzero_magnitude >= max_value, or no values would be
# possible. If smallest_nonzero_magnitude is not valid (because it's
# larger than max_value), then -smallest_nonzero_magnitude must be valid.
if smallest_nonzero_magnitude > max_value:
f *= -1

# Re-enforce the bounds (just in case of floating point arithmetic error)
return max(min_float, min(max_float, float_val))
return clamp(min_value, f, max_value)

return float_clamper

Expand Down
32 changes: 32 additions & 0 deletions hypothesis-python/tests/conjecture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,3 +428,35 @@ def ir(*values: list[IRType]) -> list[IRNode]:
)

return tuple(nodes)


def make_float_kw(
    min_value,
    max_value,
    *,
    allow_nan=True,
    smallest_nonzero_magnitude=SMALLEST_SUBNORMAL,
):
    """
    Return a dict with the keys ``min_value``, ``max_value``, ``allow_nan``
    and ``smallest_nonzero_magnitude``, using the defaults above for any
    optional entry that is not passed.
    """
    return dict(
        min_value=min_value,
        max_value=max_value,
        allow_nan=allow_nan,
        smallest_nonzero_magnitude=smallest_nonzero_magnitude,
    )


def make_integer_kw(min_value, max_value, *, weights=None, shrink_towards=0):
    """
    Return a dict with the keys ``min_value``, ``max_value``, ``weights``
    and ``shrink_towards``; ``weights`` defaults to ``None`` and
    ``shrink_towards`` to ``0``.
    """
    return dict(
        min_value=min_value,
        max_value=max_value,
        weights=weights,
        shrink_towards=shrink_towards,
    )


def make_string_kw(intervals, *, min_size=0, max_size=COLLECTION_DEFAULT_MAX_SIZE):
    """
    Return a dict with the keys ``intervals``, ``min_size`` and ``max_size``,
    using the defaults above for the two size entries when not passed.
    """
    return dict(intervals=intervals, min_size=min_size, max_size=max_size)


# We could in theory define make_bytes_kw and make_boolean_kw, but without any
# default kw values they wouldn't really be a time saver.
84 changes: 52 additions & 32 deletions hypothesis-python/tests/cover/test_float_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
import pytest

from hypothesis import example, given, strategies as st
from hypothesis.internal.conjecture.data import ir_value_equal
from hypothesis.internal.floats import (
count_between_floats,
float_permitted,
make_float_clamper,
next_down,
next_up,
sign_aware_lte,
)

from tests.conjecture.common import float_kwargs, make_float_kw


def test_can_handle_straddling_zero():
assert count_between_floats(-0.0, 0.0) == 2
Expand All @@ -44,40 +49,55 @@ def test_next_float_equal(func, val):
assert func(val) == val


# invalid order -> clamper is None:
@example(2.0, 1.0, 3.0)
# exponent comparisons:
@example(1, float_info.max, 0)
@example(1, float_info.max, 1)
@example(1, float_info.max, 10)
@example(1, float_info.max, float_info.max)
@example(1, float_info.max, math.inf)
@example(make_float_kw(1, float_info.max), 0)
@example(make_float_kw(1, float_info.max), 1)
@example(make_float_kw(1, float_info.max), 10)
@example(make_float_kw(1, float_info.max), float_info.max)
@example(make_float_kw(1, float_info.max), math.inf)
# mantissa comparisons:
@example(100.0001, 100.0003, 100.0001)
@example(100.0001, 100.0003, 100.0002)
@example(100.0001, 100.0003, 100.0003)
@given(st.floats(min_value=0), st.floats(min_value=0), st.floats(min_value=0))
def test_float_clamper(min_value, max_value, input_value):
clamper = make_float_clamper(min_value, max_value, allow_zero=False)
if max_value < min_value:
assert clamper is None
return
@example(make_float_kw(100.0001, 100.0003), 100.0001)
@example(make_float_kw(100.0001, 100.0003), 100.0002)
@example(make_float_kw(100.0001, 100.0003), 100.0003)
@example(make_float_kw(100.0001, 100.0003, allow_nan=False), math.nan)
@example(make_float_kw(0, 10, allow_nan=False), math.nan)
@example(make_float_kw(0, 10, allow_nan=True), math.nan)
# the branch coverage of resampling in the "out of range of smallest magnitude" case
# relies on randomness from the mantissa. try a few different values.
@example(make_float_kw(-4, -1, smallest_nonzero_magnitude=4), 4)
@example(make_float_kw(-4, -1, smallest_nonzero_magnitude=4), 5)
@example(make_float_kw(-4, -1, smallest_nonzero_magnitude=4), 6)
@example(make_float_kw(1, 4, smallest_nonzero_magnitude=4), -4)
@example(make_float_kw(1, 4, smallest_nonzero_magnitude=4), -5)
@example(make_float_kw(1, 4, smallest_nonzero_magnitude=4), -6)
@given(float_kwargs(), st.floats())
def test_float_clamper(kwargs, input_value):
min_value = kwargs["min_value"]
max_value = kwargs["max_value"]
allow_nan = kwargs["allow_nan"]
smallest_nonzero_magnitude = kwargs["smallest_nonzero_magnitude"]
clamper = make_float_clamper(
min_value,
max_value,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
allow_nan=allow_nan,
)
clamped = clamper(input_value)
if min_value <= input_value <= max_value:
assert input_value == clamped
if math.isnan(clamped):
# we should only clamp to nan if nans are allowed.
assert allow_nan
else:
assert min_value <= clamped <= max_value

# otherwise, we should have clamped to something in the permitted range.
assert sign_aware_lte(min_value, clamped)
assert sign_aware_lte(clamped, max_value)

@example(0.01, math.inf, 0.0)
@given(st.floats(min_value=0), st.floats(min_value=0), st.floats(min_value=0))
def test_float_clamper_with_allowed_zeros(min_value, max_value, input_value):
clamper = make_float_clamper(min_value, max_value, allow_zero=True)
assert clamper is not None
clamped = clamper(input_value)
if input_value == 0.0 or max_value < min_value:
assert clamped == 0.0
elif min_value <= input_value <= max_value:
assert input_value == clamped
else:
assert min_value <= clamped <= max_value
# if input_value was permitted in the first place, then the clamped value should
# be the same as the input value.
if float_permitted(
input_value,
min_value=min_value,
max_value=max_value,
allow_nan=allow_nan,
smallest_nonzero_magnitude=smallest_nonzero_magnitude,
):
assert ir_value_equal("float", input_value, clamped)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ignore = [
"PIE790", # See https://github.com/astral-sh/ruff/issues/10538
"PT001",
"PT003",
"PT004",
"PT006",
"PT007",
"PT009",
Expand Down

0 comments on commit cd7dd9c

Please sign in to comment.