From 5f66fd16f348a348e06bb83ca0ad93cdb25be7ee Mon Sep 17 00:00:00 2001 From: AlexanderMueller Date: Tue, 20 Aug 2024 11:41:53 +0200 Subject: [PATCH] changes from review --- include/pybind11/functional.h | 33 ++++++++++++++++++--------------- tests/test_callbacks.py | 11 +++++++++++ 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/include/pybind11/functional.h b/include/pybind11/functional.h index 0f87e29477..3484aab8d7 100644 --- a/include/pybind11/functional.h +++ b/include/pybind11/functional.h @@ -59,8 +59,8 @@ struct type_caster> { rec = c.get_pointer(); } while (rec != nullptr) { - const int correctingSelfArgument = rec->is_method ? 1 : 0; - if (rec->nargs - correctingSelfArgument != sizeof...(Args)) { + const int self_offset = rec->is_method ? 1 : 0; + if (rec->nargs != sizeof...(Args) + self_offset) { rec = rec->next; // if the overload is not feasible in terms of number of arguments, we // continue to the next one. If there is no next one, we return false. @@ -86,20 +86,24 @@ struct type_caster> { // See PR #1413 for full details } else { // Check number of arguments of Python function - auto getArgCount = [&](PyObject *obj) { - // This is faster then doing import inspect and inspect.signature(obj).parameters - auto *t = PyObject_GetAttrString(obj, "__code__"); - auto *argCount = PyObject_GetAttrString(t, "co_argcount"); - return PyLong_AsLong(argCount); + auto argCountFromFuncCode = [&](handle &obj) { + // This is faster then doing import inspect and + // inspect.signature(obj).parameters + + object argCount = obj.attr("co_argcount"); + return argCount.template cast(); }; long argCount = -1; - if (static_cast(PyObject_HasAttrString(src.ptr(), "__code__"))) { - argCount = getArgCount(src.ptr()); + handle codeAttr = PyObject_GetAttrString(src.ptr(), "__code__"); + if (codeAttr) { + argCount = argCountFromFuncCode(codeAttr); } else { - if (static_cast(PyObject_HasAttrString(src.ptr(), "__call__"))) { - auto *t2 = PyObject_GetAttrString(src.ptr(), "__call__"); - argCount = getArgCount(t2) - 1; // we have to remove the self argument + handle callAttr = PyObject_GetAttrString(src.ptr(), "__call__"); + if (callAttr) { + handle codeAttr2 = callAttr.attr("__code__"); + argCount = argCountFromFuncCode(codeAttr2) + - 1; // we have to remove the self argument } else { // No __code__ or __call__ attribute, this is not a proper Python function return false; @@ -107,10 +111,9 @@ struct type_caster> { } // if we are a method, we have to correct the argument count since we are not counting // the self argument - const int correctingSelfArgument - = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; + const int self_offset = static_cast(PyMethod_Check(src.ptr())) ? 1 : 0; - argCount -= correctingSelfArgument; + argCount -= self_offset; if (argCount != sizeof...(Args)) { return false; } diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 82b03fac1f..bb818c52d7 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -107,13 +107,24 @@ def test_cpp_correct_overload_resolution(): def f(a): return a + class A: + def __call__(self, a): + return a + assert m.dummy_function_overloaded_std_func_arg(f) == 9 + assert m.dummy_function_overloaded_std_func_arg(A()) == 9 assert m.dummy_function_overloaded_std_func_arg(lambda i: i) == 9 + def f2(a, b): return a + b + class B: + def __call__(self, a, b): + return a + b + assert m.dummy_function_overloaded_std_func_arg(f2) == 14 + assert m.dummy_function_overloaded_std_func_arg(B()) == 14 assert m.dummy_function_overloaded_std_func_arg(lambda i, j: i + j) == 14