diff --git a/msgspec/_core.c b/msgspec/_core.c index 1f5d5230..59a08d08 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -11092,6 +11092,22 @@ parse_number_nonfinite( return ms_post_decode_float(val, type, path, strict, true); } +static MS_NOINLINE PyObject * +json_float_hook( + const char *buf, Py_ssize_t size, PathNode *path, PyObject *float_hook +) { + PyObject *str = PyUnicode_New(size, 127); + if (str == NULL) return NULL; + memcpy(ascii_get_buffer(str), buf, size); + PyObject *out = CALL_ONE_ARG(float_hook, str); + Py_DECREF(str); + if (out == NULL) { + ms_maybe_wrap_validation_error(path); + return NULL; + } + return out; +} + static MS_INLINE PyObject * parse_number_inline( const unsigned char *p, @@ -11101,6 +11117,7 @@ parse_number_inline( TypeNode *type, PathNode *path, bool strict, + PyObject *float_hook, bool from_str ) { uint64_t mantissa = 0; @@ -11286,6 +11303,9 @@ parse_number_inline( (char *)start, p - start, true, path, NULL ); } + else if (MS_UNLIKELY(float_hook != NULL && type->types & MS_TYPE_ANY)) { + return json_float_hook((char *)start, p - start, path, float_hook); + } else { if (MS_UNLIKELY(exponent > 288 || exponent < -307)) { /* Exponent is out of bounds */ @@ -11363,6 +11383,7 @@ maybe_parse_number( type, path, strict, + NULL, true ); return (*out != NULL || errmsg == NULL); @@ -15403,6 +15424,7 @@ typedef struct JSONDecoderState { /* Configuration */ TypeNode *type; PyObject *dec_hook; + PyObject *float_hook; bool strict; /* Temporary scratch space */ @@ -15425,10 +15447,11 @@ typedef struct JSONDecoder { TypeNode *type; char strict; PyObject *dec_hook; + PyObject *float_hook; } JSONDecoder; PyDoc_STRVAR(JSONDecoder__doc__, -"Decoder(type='Any', *, strict=True, dec_hook=None)\n" +"Decoder(type='Any', *, strict=True, dec_hook=None, float_hook=None)\n" "--\n" "\n" "A JSON decoder.\n" @@ -15449,19 +15472,28 @@ PyDoc_STRVAR(JSONDecoder__doc__, " signature ``dec_hook(type: Type, obj: Any) -> Any``, where ``type`` is the\n" " expected message type, and ``obj`` is the decoded representation composed\n" " of only basic JSON types. This hook should transform ``obj`` into type\n" -" ``type``, or raise a ``NotImplementedError`` if unsupported." +" ``type``, or raise a ``NotImplementedError`` if unsupported.\n" +"float_hook : callable, optional\n" +" An optional callback for handling decoding untyped float literals. Should\n" +" have the signature ``float_hook(val: str) -> Any``, where ``val`` is the\n" +" raw string value of the JSON float. This hook is called to decode any\n" +" \"untyped\" float value (e.g. ``typing.Any`` typed). The default is\n" +" equivalent to ``float_hook=float``, where all untyped JSON floats are\n" +" decoded as python floats. Specifying ``float_hook=decimal.Decimal``\n" +" will decode all untyped JSON floats as decimals instead." ); static int JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds) { - char *kwlist[] = {"type", "strict", "dec_hook", NULL}; + char *kwlist[] = {"type", "strict", "dec_hook", "float_hook", NULL}; MsgspecState *st = msgspec_get_global_state(); PyObject *type = st->typing_any; PyObject *dec_hook = NULL; + PyObject *float_hook = NULL; int strict = 1; if (!PyArg_ParseTupleAndKeywords( - args, kwds, "|O$pO", kwlist, &type, &strict, &dec_hook) + args, kwds, "|O$pOO", kwlist, &type, &strict, &dec_hook, &float_hook) ) { return -1; } @@ -15479,6 +15511,19 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds) } self->dec_hook = dec_hook; + /* Handle float_hook */ + if (float_hook == Py_None) { + float_hook = NULL; + } + if (float_hook != NULL) { + if (!PyCallable_Check(float_hook)) { + PyErr_SetString(PyExc_TypeError, "float_hook must be callable"); + return -1; + } + Py_INCREF(float_hook); + } + self->float_hook = float_hook; + /* Handle strict */ self->strict = strict; @@ -15498,6 +15543,7 @@ JSONDecoder_traverse(JSONDecoder *self, visitproc visit, void *arg) if (out != 0) return out; Py_VISIT(self->orig_type); Py_VISIT(self->dec_hook); + Py_VISIT(self->float_hook); return 0; } @@ -15508,6 +15554,7 @@ JSONDecoder_dealloc(JSONDecoder *self) TypeNode_Free(self->type); Py_XDECREF(self->orig_type); Py_XDECREF(self->dec_hook); + Py_XDECREF(self->float_hook); Py_TYPE(self)->tp_free((PyObject *)self); } @@ -17551,7 +17598,7 @@ json_maybe_decode_number(JSONDecoderState *self, TypeNode *type, PathNode *path) PyObject *out = parse_number_inline( self->input_pos, self->input_end, &pout, &errmsg, - type, path, self->strict, false + type, path, self->strict, self->float_hook, false ); self->input_pos = (unsigned char *)pout; @@ -18014,6 +18061,7 @@ msgspec_json_format(PyObject *self, PyObject *args, PyObject *kwargs) /* Init decoder */ dec.dec_hook = NULL; + dec.float_hook = NULL; dec.type = NULL; dec.scratch = NULL; dec.scratch_capacity = 0; @@ -18095,6 +18143,7 @@ JSONDecoder_decode(JSONDecoder *self, PyObject *const *args, Py_ssize_t nargs) .type = self->type, .strict = self->strict, .dec_hook = self->dec_hook, + .float_hook = self->float_hook, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 @@ -18161,6 +18210,7 @@ JSONDecoder_decode_lines(JSONDecoder *self, PyObject *const *args, Py_ssize_t na .type = self->type, .strict = self->strict, .dec_hook = self->dec_hook, + .float_hook = self->float_hook, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 @@ -18237,6 +18287,7 @@ static PyMemberDef JSONDecoder_members[] = { {"type", T_OBJECT_EX, offsetof(JSONDecoder, orig_type), READONLY, "The Decoder type"}, {"strict", T_BOOL, offsetof(JSONDecoder, strict), READONLY, "The Decoder strict setting"}, {"dec_hook", T_OBJECT, offsetof(JSONDecoder, dec_hook), READONLY, "The Decoder dec_hook"}, + {"float_hook", T_OBJECT, offsetof(JSONDecoder, float_hook), READONLY, "The Decoder float_hook"}, {NULL}, }; @@ -18334,6 +18385,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO JSONDecoderState state = { .strict = strict, .dec_hook = dec_hook, + .float_hook = NULL, .scratch = NULL, .scratch_capacity = 0, .scratch_len = 0 diff --git a/msgspec/json.pyi b/msgspec/json.pyi index 53f9253a..30e87a92 100644 --- a/msgspec/json.pyi +++ b/msgspec/json.pyi @@ -18,6 +18,7 @@ T = TypeVar("T") enc_hook_sig = Optional[Callable[[Any], Any]] dec_hook_sig = Optional[Callable[[type, Any], Any]] +float_hook_sig = Optional[Callable[[str], Any]] class Encoder: enc_hook: enc_hook_sig @@ -41,6 +42,7 @@ class Decoder(Generic[T]): type: Type[T] strict: bool dec_hook: dec_hook_sig + float_hook: float_hook_sig @overload def __init__( @@ -48,6 +50,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... @overload def __init__( @@ -56,6 +59,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... @overload def __init__( @@ -64,6 +68,7 @@ class Decoder(Generic[T]): *, strict: bool = True, dec_hook: dec_hook_sig = None, + float_hook: float_hook_sig = None, ) -> None: ... def decode(self, data: Union[bytes, str]) -> T: ... def decode_lines(self, data: Union[bytes, str]) -> list[T]: ... diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index f560ec8d..21cd3fa4 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -2,6 +2,7 @@ from __future__ import annotations import datetime +import decimal import pickle from typing import Any, Dict, Final, List, Type, Union @@ -826,6 +827,14 @@ def dec_hook(typ: Type, obj: Any) -> Any: msgspec.json.Decoder(dec_hook=dec_hook) +def check_json_Decoder_float_hook() -> None: + msgspec.json.Decoder(float_hook=None) + msgspec.json.Decoder(float_hook=float) + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + if dec.float_hook is not None: + dec.float_hook("1.5") + + def check_json_Decoder_strict() -> None: dec = msgspec.json.Decoder(List[int], strict=False) reveal_type(dec.strict) # assert "bool" in typ diff --git a/tests/test_json.py b/tests/test_json.py index 7ffa0b03..898c455b 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -2,6 +2,7 @@ import base64 import datetime +import decimal import enum import gc import itertools @@ -524,6 +525,19 @@ def test_decode_lines_bad_call(self): with pytest.raises(TypeError): dec.decode(1) + def test_decoder_init_float_hook(self): + dec = msgspec.json.Decoder() + assert dec.float_hook is None + + dec = msgspec.json.Decoder(float_hook=None) + assert dec.float_hook is None + + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + assert dec.float_hook is decimal.Decimal + + with pytest.raises(TypeError): + dec = msgspec.json.Decoder(float_hook=1) + class TestBoolAndNone: def test_encode_none(self): @@ -1567,6 +1581,42 @@ def test_decode_float_err_expected_int(self, s): ): msgspec.json.decode(s, type=int) + def test_float_hook_untyped(self): + dec = msgspec.json.Decoder(float_hook=decimal.Decimal) + res = dec.decode(b"1.33") + assert res == decimal.Decimal("1.33") + assert type(res) is decimal.Decimal + + def test_float_hook_typed(self): + class Ex(msgspec.Struct): + a: float + b: decimal.Decimal + c: Any + d: Any + + class MyFloat(NamedTuple): + x: str + + dec = msgspec.json.Decoder(Ex, float_hook=MyFloat) + res = dec.decode(b'{"a": 1.5, "b": 1.3, "c": 1.3, "d": 123}') + sol = Ex(1.5, decimal.Decimal("1.3"), MyFloat("1.3"), 123) + assert res == sol + + def test_float_hook_error(self): + def float_hook(val): + raise ValueError("Oh no!") + + class Ex(msgspec.Struct): + a: float + b: Any + + dec = msgspec.json.Decoder(Ex, float_hook=float_hook) + assert dec.decode(b'{"a": 1.5, "b": 2}') == Ex(a=1.5, b=2) + with pytest.raises(msgspec.ValidationError) as rec: + dec.decode(b'{"a": 1.5, "b": 2.5}') + assert "Oh no!" in str(rec.value) + assert "at `$.b`" in str(rec.value) + class TestDecimal: """Most decimal tests are in test_common.py, the ones here are for json