Skip to content

Commit

Permalink
Support encoding dicts with float keys
Browse files Browse the repository at this point in the history
We can now roundtrip dicts with float-like keys through JSON.
  • Loading branch information
jcrist committed Aug 10, 2023
1 parent 7bc2092 commit ef7808b
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 26 deletions.
6 changes: 3 additions & 3 deletions docs/source/supported-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,9 @@ Dict subclasses (`collections.OrderedDict`, for example) are also supported for
encoding only. To decode into a ``dict`` subclass you'll need to implement a
``dec_hook`` (see :doc:`extending`).

JSON and TOML only support key types that encode as strings or integers (for
example `str`, `int`, `enum.Enum`, `datetime.datetime`, `uuid.UUID`, ...).
MessagePack and YAML support any hashable for the key type.
JSON and TOML only support key types that encode as strings or numbers (for
example `str`, `int`, `float`, `enum.Enum`, `datetime.datetime`, `uuid.UUID`,
...). MessagePack and YAML support any hashable type as the key type.

An error is raised during decoding if the keys or values don't match their
respective types (if specified).
Expand Down
29 changes: 23 additions & 6 deletions msgspec/_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -10774,7 +10774,7 @@ ms_decode_decimal_from_float(double val, PathNode *path, MsgspecState *mod) {
/* For finite values, render as the nearest IEEE754 double in string
* form, then call decimal.Decimal to parse */
char buf[24];
int n = write_f64(val, buf);
int n = write_f64(val, buf, false);
return ms_decode_decimal(buf, n, true, path, mod);
}
else {
Expand Down Expand Up @@ -11036,7 +11036,7 @@ parse_number_nonfinite(
PathNode *path,
bool strict
) {
ssize_t size = pend - p;
size_t size = pend - p;
double val;
if (size == 3) {
if (
Expand Down Expand Up @@ -12565,10 +12565,24 @@ static MS_NOINLINE int
json_encode_float(EncoderState *self, PyObject *obj) {
char buf[24];
double x = PyFloat_AS_DOUBLE(obj);
int n = write_f64(x, buf);
int n = write_f64(x, buf, false);
return ms_write(self, buf, n);
}

/* Encode a Python float as a double-quoted string, as used for dict keys.
 *
 * The float is rendered with `write_f64` with non-finite output enabled and
 * then wrapped in double quotes directly in the encoder's output buffer.
 *
 * Returns 0 on success, or -1 if `ms_ensure_space` fails. */
static MS_NOINLINE int
json_encode_float_as_str(EncoderState *self, PyObject *obj) {
    char digits[24];
    int ndigits = write_f64(PyFloat_AS_DOUBLE(obj), digits, true);
    /* Two extra bytes for the surrounding quotes */
    int total = ndigits + 2;

    if (ms_ensure_space(self, total) < 0) return -1;

    char *out = self->output_buffer_raw + self->output_len;
    out[0] = '"';
    memcpy(out + 1, digits, ndigits);
    out[ndigits + 1] = '"';

    self->output_len += total;
    return 0;
}

static MS_INLINE int
json_encode_cstr(EncoderState *self, const char *str, Py_ssize_t size) {
if (ms_ensure_space(self, size + 2) < 0) return -1;
Expand Down Expand Up @@ -12940,6 +12954,9 @@ json_encode_dict_key(EncoderState *self, PyObject *obj) {
if (type == &PyLong_Type) {
return json_encode_long_as_str(self, obj);
}
else if (type == &PyFloat_Type) {
return json_encode_float_as_str(self, obj);
}
else if (Py_TYPE(type) == self->mod->EnumMetaType) {
return json_encode_enum(self, obj, true);
}
Expand Down Expand Up @@ -12967,7 +12984,7 @@ json_encode_dict_key(EncoderState *self, PyObject *obj) {
else {
PyErr_SetString(
PyExc_TypeError,
"Only dicts with str-like or int-like keys are supported"
"Only dicts with str-like or number-like keys are supported"
);
return -1;
}
Expand Down Expand Up @@ -18560,7 +18577,7 @@ to_builtins_dict(ToBuiltinsState *self, PyObject *obj) {
new_key = to_builtins(self, key, true);
if (new_key == NULL) goto cleanup;
if (self->str_keys) {
if (PyLong_CheckExact(new_key)) {
if (PyLong_CheckExact(new_key) || PyFloat_CheckExact(new_key)) {
PyObject *temp = PyObject_Str(new_key);
if (temp == NULL) goto cleanup;
Py_DECREF(new_key);
Expand All @@ -18569,7 +18586,7 @@ to_builtins_dict(ToBuiltinsState *self, PyObject *obj) {
else if (!PyUnicode_CheckExact(new_key)) {
PyErr_SetString(
PyExc_TypeError,
"Only dicts with `str` or `int` keys are supported"
"Only dicts with str-like or number-like keys are supported"
);
goto cleanup;
}
Expand Down
22 changes: 18 additions & 4 deletions msgspec/ryu.h
Original file line number Diff line number Diff line change
Expand Up @@ -919,16 +919,30 @@ write_exponent(int32_t k, char* buf) {

/* Write a double to buf, requires 24 bytes of space */
static inline int
write_f64(double f, char* buf) {
write_f64(double f, char* buf, bool allow_nonfinite) {
const uint64_t bits = double_to_bits(f);
const int sign = ((bits >> (DOUBLE_MANTISSA_BITS + DOUBLE_EXPONENT_BITS)) & 1) != 0;
const uint64_t ieee_mantissa = bits & ((1ull << DOUBLE_MANTISSA_BITS) - 1);
const uint32_t ieee_exponent = (uint32_t) ((bits >> DOUBLE_MANTISSA_BITS) & ((1u << DOUBLE_EXPONENT_BITS) - 1));

/* Non-finite numbers serialize as null, unless allow_nonfinite is set, in
 * which case they're written as inf, -inf, or nan */
if (ieee_exponent == ((1 << DOUBLE_EXPONENT_BITS) - 1)) {
memcpy(buf, "null", 4);
return 4;
if (MS_UNLIKELY(ieee_exponent == ((1 << DOUBLE_EXPONENT_BITS) - 1))) {
if (MS_LIKELY(!allow_nonfinite)) {
memcpy(buf, "null", 4);
return 4;
}
else {
if (ieee_mantissa == 0) {
if (sign) {
memcpy(buf, "-inf", 4);
return 4;
}
memcpy(buf, "inf", 3);
return 3;
}
memcpy(buf, "nan", 3);
return 3;
}
}

if (sign) {
Expand Down
11 changes: 7 additions & 4 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3608,12 +3608,15 @@ def test_encode_decimal(self, proto):
s = str(d)
assert proto.encode(d) == proto.encode(s)

def test_decode_decimal_str(self, proto):
d = decimal.Decimal("1.5")
msg = proto.encode(d)
@pytest.mark.parametrize(
"val", ["1.5", "InF", "-iNf", "iNfInItY", "-InFiNiTy", "NaN"]
)
def test_decode_decimal_str(self, val, proto):
sol = decimal.Decimal(val)
msg = proto.encode(sol)
res = proto.decode(msg, type=decimal.Decimal)
assert str(res) == str(sol)
assert type(res) is decimal.Decimal
assert res == d

def test_decode_decimal_str_invalid(self, proto):
msg = proto.encode("1..5")
Expand Down
29 changes: 26 additions & 3 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,11 +1800,12 @@ class Test(NamedTuple):


class TestDict:
def test_encode_dict_raises_non_string_or_int_keys(self):
def test_encode_dict_raises_non_string_or_numeric_keys(self):
with pytest.raises(
TypeError, match="Only dicts with str-like or int-like keys are supported"
TypeError,
match="Only dicts with str-like or number-like keys are supported",
):
msgspec.json.encode({"a": 1, 2.5: "bad"})
msgspec.json.encode({"a": 1, (1, 2): "bad"})

@pytest.mark.parametrize("x", [{}, {"a": 1}, {"a": 1, "b": 2}])
def test_roundtrip_dict(self, x):
Expand Down Expand Up @@ -1896,11 +1897,13 @@ def test_decode_dict_string_cache_ascii_only(self):
"key",
[
1,
1.5,
FruitInt.APPLE,
uuid.uuid4(),
datetime.datetime.now(),
datetime.date.today(),
datetime.datetime.now().time(),
datetime.timedelta(1.5),
b"test",
Decimal("1.5"),
],
Expand Down Expand Up @@ -1973,13 +1976,33 @@ def test_decode_dict_int_literal_key(self):
with pytest.raises(msgspec.ValidationError, match="Invalid enum value 3"):
dec.decode(b'{"-1": 10, "3": 20}')

def test_encode_dict_float_key(self):
    """Float dict keys encode the same as their ``str()`` forms.

    Covers finite values as well as ``inf``, ``-inf`` and ``nan``.
    """
    float_keyed = {
        1.5: 1,
        -1.5: 2,
        0.0: 3,
        float("-inf"): 4,
        float("inf"): 5,
        float("nan"): 6,
    }
    # Reference encoding: the same mapping with every key pre-stringified.
    str_keyed = {str(key): value for key, value in float_keyed.items()}
    expected = msgspec.json.encode(str_keyed)
    assert msgspec.json.encode(float_keyed) == expected

def test_decode_dict_float_key(self):
    """String keys holding numeric text decode to floats for ``Dict[float, int]``."""
    str_keyed = {"1.5": 1, "inf": 2, "-inf": 3, "0": 4, "-1.5e12": 5, "123": 6}
    encoded = msgspec.json.encode(str_keyed)

    # Expected result: the same mapping with each key parsed by ``float``.
    expected = {}
    for text, value in str_keyed.items():
        expected[float(text)] = value

    decoded = msgspec.json.decode(encoded, type=Dict[float, int])
    assert decoded == expected

def test_decode_dict_int_or_float_key(self):
    """With ``Union[int, float]`` keys, an integer-looking key decodes as ``int``."""
    decoded = msgspec.json.decode(
        b'{"1.5": "a", "123": "b"}',
        type=Dict[Union[int, float], str],
    )
    assert decoded == {1.5: "a", 123: "b"}

    # The "123" key comes last in the message; it must be an int, not a float.
    last_key = list(decoded)[-1]
    assert type(last_key) is int

def test_encode_dict_str_subclass_key(self):
class mystr(str):
pass
Expand Down
8 changes: 2 additions & 6 deletions tests/test_to_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,17 +296,13 @@ def test_dict_str_keys(self):
assert to_builtins({FruitInt.BANANA: 1}, str_keys=True) == {"2": 1}
assert to_builtins({2: 1}, str_keys=True) == {"2": 1}

with pytest.raises(
TypeError, match="Only dicts with `str` or `int` keys are supported"
):
to_builtins({(1, 2): 3}, str_keys=True)

def test_dict_sequence_keys(self):
msg = {frozenset([1, 2]): 1}
assert to_builtins(msg) == {(1, 2): 1}

with pytest.raises(
TypeError, match="Only dicts with `str` or `int` keys are supported"
TypeError,
match="Only dicts with str-like or number-like keys are supported",
):
to_builtins(msg, str_keys=True)

Expand Down

0 comments on commit ef7808b

Please sign in to comment.