diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e5b8c8f31..3955caad9e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,7 +180,8 @@ set(PYBIND11_HEADERS include/pybind11/stl_bind.h include/pybind11/stl/filesystem.h include/pybind11/type_caster_pyobject_ptr.h - include/pybind11/typing.h) + include/pybind11/typing.h + include/pybind11/warnings.h) # Compare with grep and warn if mismatched if(PYBIND11_MASTER_PROJECT AND NOT CMAKE_VERSION VERSION_LESS 3.12) diff --git a/include/pybind11/warnings.h b/include/pybind11/warnings.h new file mode 100644 index 0000000000..81e1ce7bbe --- /dev/null +++ b/include/pybind11/warnings.h @@ -0,0 +1,70 @@ +/* + pybind11/warnings.h: Python warnings wrappers. + + Copyright (c) 2024 Jan Iwaszkiewicz + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include "detail/common.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) + +PYBIND11_NAMESPACE_BEGIN(detail) + +inline bool PyWarning_Check(PyObject *obj) { + int result = PyObject_IsSubclass(obj, PyExc_Warning); + if (result == 1) { + return true; + } + if (result == -1) { + raise_from(PyExc_SystemError, + "PyWarning_Check(): internal error of Python C API while " + "checking a subclass of the object!"); + throw error_already_set(); + } + return false; +} + +PYBIND11_NAMESPACE_END(detail) + +PYBIND11_NAMESPACE_BEGIN(warnings) + +inline object +new_warning_type(handle scope, const char *name, handle base = PyExc_RuntimeWarning) { + if (!detail::PyWarning_Check(base.ptr())) { + pybind11_fail("warning(): cannot create custom warning, base must be a subclass of " + "PyExc_Warning!"); + } + if (hasattr(scope, "__dict__") && scope.attr("__dict__").contains(name)) { + pybind11_fail("Error during initialization: multiple incompatible " + "definitions with name \"" + + std::string(name) + "\""); + } + std::string full_name = scope.attr("__name__").cast() + std::string(".") + name; + handle h(PyErr_NewException(const_cast(full_name.c_str()), base.ptr(), nullptr)); + object obj = reinterpret_steal(h); + scope.attr(name) = obj; + return obj; +} + +// Similar to Python `warnings.warn()` +inline void +warn(const char *message, handle category = PyExc_RuntimeWarning, int stack_level = 2) { + if (!detail::PyWarning_Check(category.ptr())) { + pybind11_fail("raise_warning(): cannot raise warning, category must be a subclass of " + "PyExc_Warning!"); + } + + if (PyErr_WarnEx(category.ptr(), message, stack_level) == -1) { + throw error_already_set(); + } +} + +PYBIND11_NAMESPACE_END(warnings) + +PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index aae9be720b..3bcb278e5e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -162,7 +162,8 @@ set(PYBIND11_TEST_FILES test_unnamed_namespace_a test_unnamed_namespace_b test_vector_unique_ptr_member - test_virtual_functions) + test_virtual_functions + test_warnings) # Invoking cmake with something like: # cmake -DPYBIND11_TEST_OVERRIDE="test_callbacks.cpp;test_pickling.cpp" .. diff --git a/tests/test_warnings.cpp b/tests/test_warnings.cpp new file mode 100644 index 0000000000..383e88835b --- /dev/null +++ b/tests/test_warnings.cpp @@ -0,0 +1,76 @@ +/* + tests/test_warnings.cpp -- usage of warnings::warn() and warnings categories. + + Copyright (c) 2024 Jan Iwaszkiewicz + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include + +#include "pybind11_tests.h" + +#include + +namespace warning_helpers { +void warn_function(py::module &m, const char *name, py::handle category, const char *message) { + m.def(name, [category, message]() { py::warnings::warn(message, category); }); +} +} // namespace warning_helpers + +class CustomWarning {}; + +TEST_SUBMODULE(warnings_, m) { + + // Test warning mechanism base + m.def("raise_and_return", []() { + std::string message = "Warning was raised!"; + py::warnings::warn(message.c_str(), PyExc_Warning); + return 21; + }); + + m.def("raise_default", []() { py::warnings::warn("RuntimeWarning is raised!"); }); + + m.def("raise_from_cpython", + []() { py::warnings::warn("UnicodeWarning is raised!", PyExc_UnicodeWarning); }); + + m.def("raise_and_fail", + []() { py::warnings::warn("RuntimeError should be raised!", PyExc_Exception); }); + + // Test custom warnings + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store ex_storage; + ex_storage.call_once_and_store_result([&]() { + return py::warnings::new_warning_type(m, "CustomWarning", PyExc_DeprecationWarning); + }); + + m.def("raise_custom", []() { + py::warnings::warn("CustomWarning was raised!", ex_storage.get_stored()); + return 37; + }); + + // Bind warning categories + warning_helpers::warn_function(m, "raise_base_warning", PyExc_Warning, "This is Warning!"); + warning_helpers::warn_function( + m, "raise_bytes_warning", PyExc_BytesWarning, "This is BytesWarning!"); + warning_helpers::warn_function( + m, "raise_deprecation_warning", PyExc_DeprecationWarning, "This is DeprecationWarning!"); + warning_helpers::warn_function( + m, "raise_future_warning", PyExc_FutureWarning, "This is FutureWarning!"); + warning_helpers::warn_function( + m, "raise_import_warning", PyExc_ImportWarning, "This is ImportWarning!"); + warning_helpers::warn_function(m, + "raise_pending_deprecation_warning", + PyExc_PendingDeprecationWarning, + "This is PendingDeprecationWarning!"); + warning_helpers::warn_function( + m, "raise_resource_warning", PyExc_ResourceWarning, "This is ResourceWarning!"); + warning_helpers::warn_function( + m, "raise_runtime_warning", PyExc_RuntimeWarning, "This is RuntimeWarning!"); + warning_helpers::warn_function( + m, "raise_syntax_warning", PyExc_SyntaxWarning, "This is SyntaxWarning!"); + warning_helpers::warn_function( + m, "raise_unicode_warning", PyExc_UnicodeWarning, "This is UnicodeWarning!"); + warning_helpers::warn_function( + m, "raise_user_warning", PyExc_UserWarning, "This is UserWarning!"); +} diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 0000000000..ac3ae399b9 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import warnings + +import pytest + +import pybind11_tests # noqa: F401 +from pybind11_tests import warnings_ as m + + +@pytest.mark.parametrize( + ("expected_category", "expected_message", "expected_value", "module_function"), + [ + (Warning, "Warning was raised!", 21, m.raise_and_return), + (RuntimeWarning, "RuntimeWarning is raised!", None, m.raise_default), + (UnicodeWarning, "UnicodeWarning is raised!", None, m.raise_from_cpython), + ], +) +def test_warning_simple( + expected_category, expected_message, expected_value, module_function +): + with pytest.warns(Warning) as excinfo: + value = module_function() + + assert issubclass(excinfo[0].category, expected_category) + assert str(excinfo[0].message) == expected_message + assert value == expected_value + + +def test_warning_fail(): + with pytest.raises(Exception) as excinfo: + m.raise_and_fail() + + assert issubclass(excinfo.type, RuntimeError) + assert ( + str(excinfo.value) + == "raise_warning(): cannot raise warning, category must be a subclass of PyExc_Warning!" + ) + + +def test_warning_register(): + assert m.CustomWarning is not None + assert issubclass(m.CustomWarning, DeprecationWarning) + + with pytest.warns(m.CustomWarning) as excinfo: + warnings.warn("This is warning from Python!", m.CustomWarning, stacklevel=1) + + assert issubclass(excinfo[0].category, DeprecationWarning) + assert issubclass(excinfo[0].category, m.CustomWarning) + assert str(excinfo[0].message) == "This is warning from Python!" + + +@pytest.mark.parametrize( + ( + "expected_category", + "expected_base", + "expected_message", + "expected_value", + "module_function", + ), + [ + ( + m.CustomWarning, + DeprecationWarning, + "CustomWarning was raised!", + 37, + m.raise_custom, + ), + ], +) +def test_warning_custom( + expected_category, expected_base, expected_message, expected_value, module_function +): + with pytest.warns(expected_category) as excinfo: + value = module_function() + + assert issubclass(excinfo[0].category, expected_base) + assert issubclass(excinfo[0].category, expected_category) + assert str(excinfo[0].message) == expected_message + assert value == expected_value + + +@pytest.mark.parametrize( + ("expected_category", "module_function"), + [ + (Warning, m.raise_base_warning), + (BytesWarning, m.raise_bytes_warning), + (DeprecationWarning, m.raise_deprecation_warning), + (FutureWarning, m.raise_future_warning), + (ImportWarning, m.raise_import_warning), + (PendingDeprecationWarning, m.raise_pending_deprecation_warning), + (ResourceWarning, m.raise_resource_warning), + (RuntimeWarning, m.raise_runtime_warning), + (SyntaxWarning, m.raise_syntax_warning), + (UnicodeWarning, m.raise_unicode_warning), + (UserWarning, m.raise_user_warning), + ], +) +def test_warning_categories(expected_category, module_function): + with pytest.warns(Warning) as excinfo: + module_function() + + assert issubclass(excinfo[0].category, expected_category) + assert str(excinfo[0].message) == f"This is {expected_category.__name__}!"