Skip to content

Commit

Permalink
fix[test]: fix a bad bound in decimal fuzzing (vyperlang#3909)
Browse files Browse the repository at this point in the history
the decimal fuzz test `is_valid` condition was based on an ancient
version of decimals which had bounds at `-2**127` and `2**127`. update
the condition to be compatible with the latest version of `decimal`.

also increase the range of decimals produced by the decimal fuzzing
strategy, so that the fuzzer finds overflow issues faster.

an additional issue was found in the fuzz tests, which is that some
decimal operations panic with `decimal.InvalidOperation` instead of a
proper exception. this is a known bug, see GH vyperlang#2241. this fixes the
issue by catching the exception and raising an `OverflowException`.

misc/refactor:
- refactor several uses of quantize into a utility function
  • Loading branch information
charles-cooper authored Apr 4, 2024
1 parent dc62e7a commit ee11e3d
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 18 deletions.
3 changes: 2 additions & 1 deletion tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
checksum_encode,
int_bounds,
is_checksum_encoded,
quantize,
round_towards_zero,
unsigned_to_signed,
)
Expand Down Expand Up @@ -414,7 +415,7 @@ def _vyper_literal(val, typ):
return "0x" + val.hex()
if isinstance(typ, DecimalT):
tmp = val
val = val.quantize(DECIMAL_EPSILON)
val = quantize(val)
assert tmp == val
return str(val)

Expand Down
8 changes: 2 additions & 6 deletions tests/functional/codegen/types/numbers/test_decimals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from decimal import ROUND_DOWN, Decimal, getcontext
from decimal import Decimal, getcontext

import pytest

Expand All @@ -10,7 +10,7 @@
OverflowException,
TypeMismatch,
)
from vyper.utils import DECIMAL_EPSILON, SizeLimits
from vyper.utils import DECIMAL_EPSILON, SizeLimits, quantize


def test_decimal_override():
Expand Down Expand Up @@ -51,10 +51,6 @@ def foo(x: decimal) -> decimal:
compile_code(code)


def quantize(x: Decimal) -> Decimal:
return x.quantize(DECIMAL_EPSILON, rounding=ROUND_DOWN)


def test_decimal_test(get_contract_with_gas_estimation):
decimal_test = """
@external
Expand Down
24 changes: 18 additions & 6 deletions tests/unit/ast/nodes/test_fold_binop_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@

from tests.utils import parse_and_fold
from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException
from vyper.semantics.analysis.local import ExprVisitor
from vyper.semantics.types import DecimalT

DECIMAL_T = DecimalT()

st_decimals = st.decimals(
min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10
min_value=DECIMAL_T.decimal_bounds[0],
max_value=DECIMAL_T.decimal_bounds[1],
allow_nan=False,
allow_infinity=False,
places=DECIMAL_T._decimal_places,
)


Expand All @@ -30,10 +38,11 @@ def foo(a: decimal, b: decimal) -> decimal:

try:
vyper_ast = parse_and_fold(f"{left} {op} {right}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
expr = vyper_ast.body[0].value
ExprVisitor().visit(expr, DecimalT())
new_node = expr.get_folded_value()
is_valid = True
except ZeroDivisionException:
except (OverflowException, ZeroDivisionException):
is_valid = False

if is_valid:
Expand Down Expand Up @@ -71,9 +80,12 @@ def foo({input_value}) -> decimal:
literal_op = literal_op.rsplit(maxsplit=1)[0]
try:
vyper_ast = parse_and_fold(literal_op)
new_node = vyper_ast.body[0].value.get_folded_value()
expr = vyper_ast.body[0].value
ExprVisitor().visit(expr, DecimalT())
new_node = expr.get_folded_value()
expected = new_node.value
is_valid = -(2**127) <= expected < 2**127
lo, hi = DecimalT().decimal_bounds
is_valid = lo <= expected < hi
except (OverflowException, ZeroDivisionException):
# for overflow or division/modulus by 0, expect the contract call to revert
is_valid = False
Expand Down
29 changes: 24 additions & 5 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
VyperException,
ZeroDivisionException,
)
from vyper.utils import MAX_DECIMAL_PLACES, SizeLimits, annotate_source_code, evm_div, sha256sum
from vyper.utils import (
MAX_DECIMAL_PLACES,
SizeLimits,
annotate_source_code,
evm_div,
quantize,
sha256sum,
)

NODE_BASE_ATTRIBUTES = (
"_children",
Expand Down Expand Up @@ -824,6 +831,7 @@ def to_dict(self):
return ast_dict

def validate(self):
# note: maybe use self.value == quantize(self.value) for this check
if self.value.as_tuple().exponent < -MAX_DECIMAL_PLACES:
raise InvalidLiteral("Vyper supports a maximum of ten decimal points", self)
if self.value < SizeLimits.MIN_AST_DECIMAL:
Expand Down Expand Up @@ -1010,9 +1018,15 @@ def _op(self, left, right):
value = left * right
if isinstance(left, decimal.Decimal):
# ensure that the result is truncated to MAX_DECIMAL_PLACES
return value.quantize(
decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN
)
try:
# if the intermediate result requires too many decimal places,
# decimal will puke - catch the error and raise an
# OverflowException
return quantize(value)
except decimal.InvalidOperation:
msg = f"{self._description} requires too many decimal places:"
msg += f"\n {left} * {right} => {value}"
raise OverflowException(msg, self) from None
else:
return value

Expand All @@ -1036,7 +1050,12 @@ def _op(self, left, right):
# the EVM always truncates toward zero
value = -(-left / right)
# ensure that the result is truncated to MAX_DECIMAL_PLACES
return value.quantize(decimal.Decimal(f"{1:0.{MAX_DECIMAL_PLACES}f}"), decimal.ROUND_DOWN)
try:
return quantize(value)
except decimal.InvalidOperation:
msg = f"{self._description} requires too many decimal places:"
msg += f"\n {left} {self._pretty} {right} => {value}"
raise OverflowException(msg, self) from None


class FloorDiv(VyperNode):
Expand Down
5 changes: 5 additions & 0 deletions vyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,11 @@ class SizeLimits:
MAX_UINT256 = 2**256 - 1


def quantize(d: decimal.Decimal, places=MAX_DECIMAL_PLACES, rounding_mode=decimal.ROUND_DOWN):
quantizer = decimal.Decimal(f"{1:0.{places}f}")
return d.quantize(quantizer, rounding_mode)


# List of valid IR macros.
# TODO move this somewhere else, like ir_node.py
VALID_IR_MACROS = {
Expand Down

0 comments on commit ee11e3d

Please sign in to comment.