Skip to content

Commit

Permalink
Added support for nested callables
Browse files Browse the repository at this point in the history
  • Loading branch information
timohl committed Jan 9, 2025
1 parent 8718c50 commit 89de7a0
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
4 changes: 2 additions & 2 deletions include/pybind11/detail/descr.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,12 @@ constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {

template <size_t N, typename... Ts>
constexpr descr<N + 4, Ts...> arg_descr(const descr<N, Ts...> &descr) {
return const_name("@^") + descr + const_name("@^");
return const_name("@^") + descr + const_name("@!");
}

template <size_t N, typename... Ts>
constexpr descr<N + 4, Ts...> return_descr(const descr<N, Ts...> &descr) {
return const_name("@$") + descr + const_name("@$");
return const_name("@$") + descr + const_name("@!");
}

PYBIND11_NAMESPACE_END(detail)
Expand Down
30 changes: 17 additions & 13 deletions include/pybind11/pybind11.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,10 +440,10 @@ class cpp_function : public function {
std::string signature;
size_t type_index = 0, arg_index = 0;
bool is_starred = false;
// `is_return_value` is true if we are currently inside the return type of the signature.
// The same is true for `use_return_value`, except for forced usage of arg/return type
// using @^/@$.
bool is_return_value = false;
// `is_return_value.top()` is true if we are currently inside the return type of the
// signature. Using `@^`/`@$` we can force types to be arg/return types while `@!` pops
// back to the previous state.
std::stack<bool> is_return_value = {false};
bool use_return_value = false;
for (const auto *pc = text; *pc != '\0'; ++pc) {
const auto c = *pc;
Expand Down Expand Up @@ -499,21 +499,26 @@ class cpp_function : public function {
signature += detail::quote_cpp_type_name(detail::clean_type_id(t->name()));
}
} else if (c == '@') {
// `@^ ... @^` and `@$ ... @$` are used to force arg/return value type (see
// `@^ ... @!` and `@$ ... @!` are used to force arg/return value type (see
// typing::Callable/detail::arg_descr/detail::return_descr)
if ((*(pc + 1) == '^' && is_return_value)
|| (*(pc + 1) == '$' && !is_return_value)) {
use_return_value = !use_return_value;
}
if (*(pc + 1) == '^' || *(pc + 1) == '$') {
if (*(pc + 1) == '^') {
is_return_value.emplace(false);
++pc;
continue;
} else if (*(pc + 1) == '$') {
is_return_value.emplace(true);
++pc;
continue;
} else if (*(pc + 1) == '!') {
is_return_value.pop();
++pc;
continue;
}
// Handle types that differ depending on whether they appear
// in an argument or a return value position (see io_name<text1, text2>).
// For named arguments (py::arg()) with noconvert set, return value type is used.
++pc;
if (!use_return_value
if (!is_return_value.top()
&& !(arg_index < rec->args.size() && !rec->args[arg_index].convert)) {
while (*pc != '\0' && *pc != '@') {
signature += *pc++;
Expand All @@ -537,8 +542,7 @@ class cpp_function : public function {
}
} else {
if (c == '-' && *(pc + 1) == '>') {
is_return_value = true;
use_return_value = true;
is_return_value.emplace(true);
}
signature += c;
}
Expand Down
4 changes: 4 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,10 @@ TEST_SUBMODULE(pytypes, m) {
// Callable<R(...)> identity
m.def("identity_callable_ellipsis",
[](const py::typing::Callable<RealNumber(py::ellipsis)> &x) { return x; });
// Nested Callable<R(A)> identity
m.def("identity_nested_callable",
[](const py::typing::Callable<py::typing::Callable<RealNumber(const RealNumber &)>(
py::typing::Callable<RealNumber(const RealNumber &)>)> &x) { return x; });
// Callable<R(A)>
m.def("apply_callable",
[](const RealNumber &x, const py::typing::Callable<RealNumber(const RealNumber &)> &f) {
Expand Down
5 changes: 5 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,11 @@ def test_arg_return_type_hints(doc):
doc(m.identity_callable_ellipsis)
== "identity_callable_ellipsis(arg0: Callable[..., float]) -> Callable[..., float]"
)
# Nested Callable<R(A)> identity
assert (
doc(m.identity_nested_callable)
== "identity_nested_callable(arg0: Callable[[Callable[[Union[float, int]], float]], Callable[[Union[float, int]], float]]) -> Callable[[Callable[[Union[float, int]], float]], Callable[[Union[float, int]], float]]"
)
# Callable<R(A)>
assert (
doc(m.apply_callable)
Expand Down

0 comments on commit 89de7a0

Please sign in to comment.