From 7e418f49243bb7d13fa92cf2634af1eeac386465 Mon Sep 17 00:00:00 2001 From: gentlegiantJGC Date: Tue, 24 Sep 2024 18:28:22 +0100 Subject: [PATCH] Allow subclasses of py::args and py::kwargs (#5381) * Allow subclasses of py::args and py::kwargs The current implementation does not allow subclasses of args or kwargs. This change allows subclasses to be used. * Added test case * style: pre-commit fixes * Added missing semi-colons * style: pre-commit fixes * Added handle_type_name * Moved classes outside of function * Added namespaces * style: pre-commit fixes * Refactored tests Added more tests and moved tests to more appropriate locations. * style: pre-commit fixes --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- include/pybind11/cast.h | 4 ++-- tests/test_kwargs_and_defaults.cpp | 26 ++++++++++++++++++++++++++ tests/test_kwargs_and_defaults.py | 11 +++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/include/pybind11/cast.h b/include/pybind11/cast.h index 0c862e4bec..e3c8b9f659 100644 --- a/include/pybind11/cast.h +++ b/include/pybind11/cast.h @@ -1570,9 +1570,9 @@ class argument_loader { using indices = make_index_sequence; template - using argument_is_args = std::is_same, args>; + using argument_is_args = std::is_base_of>; template - using argument_is_kwargs = std::is_same, kwargs>; + using argument_is_kwargs = std::is_base_of>; // Get kwargs argument position, or -1 if not present: static constexpr auto kwargs_pos = constexpr_last(); diff --git a/tests/test_kwargs_and_defaults.cpp b/tests/test_kwargs_and_defaults.cpp index bc76ec7c2d..09036ccd51 100644 --- a/tests/test_kwargs_and_defaults.cpp +++ b/tests/test_kwargs_and_defaults.cpp @@ -14,6 +14,26 @@ #include +// Classes needed for subclass test. +class ArgsSubclass : public py::args { + using py::args::args; +}; +class KWArgsSubclass : public py::kwargs { + using py::kwargs::kwargs; +}; +namespace pybind11 { +namespace detail { +template <> +struct handle_type_name { + static constexpr auto name = const_name("*Args"); +}; +template <> +struct handle_type_name { + static constexpr auto name = const_name("**KWArgs"); +}; +} // namespace detail +} // namespace pybind11 + TEST_SUBMODULE(kwargs_and_defaults, m) { auto kw_func = [](int x, int y) { return "x=" + std::to_string(x) + ", y=" + std::to_string(y); }; @@ -322,4 +342,10 @@ TEST_SUBMODULE(kwargs_and_defaults, m) { py::pos_only{}, py::arg("i"), py::arg("j")); + + // Test support for args and kwargs subclasses + m.def("args_kwargs_subclass_function", + [](const ArgsSubclass &args, const KWArgsSubclass &kwargs) { + return py::make_tuple(args, kwargs); + }); } diff --git a/tests/test_kwargs_and_defaults.py b/tests/test_kwargs_and_defaults.py index b9b1a7ea87..e3f758165c 100644 --- a/tests/test_kwargs_and_defaults.py +++ b/tests/test_kwargs_and_defaults.py @@ -17,6 +17,10 @@ def test_function_signatures(doc): assert ( doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple" ) + assert ( + doc(m.args_kwargs_subclass_function) + == "args_kwargs_subclass_function(*Args, **KWArgs) -> tuple" + ) assert ( doc(m.KWClass.foo0) == "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" @@ -98,6 +102,7 @@ def test_arg_and_kwargs(): args = "a1", "a2" kwargs = {"arg3": "a3", "arg4": 4} assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs) + assert m.args_kwargs_subclass_function(*args, **kwargs) == (args, kwargs) def test_mixed_args_and_kwargs(msg): @@ -413,6 +418,12 @@ def test_args_refcount(): ) assert refcount(myval) == expected + assert m.args_kwargs_subclass_function(7, 8, myval, a=1, b=myval) == ( + (7, 8, myval), + {"a": 1, "b": myval}, + ) + assert refcount(myval) == expected + exp3 = refcount(myval, myval, myval) assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3) assert refcount(myval) == expected