From ead9cd1cfd0f1e51fd060b814a509507323ef8c2 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Fri, 13 Sep 2024 10:18:21 +0200 Subject: [PATCH] make typing behavior of AgentSet.get explicit (#2293) * make typing behavior of AgentSet.get explicit * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- mesa/agent.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/mesa/agent.py b/mesa/agent.py index 5816ddf060a..75d79f90563 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -20,7 +20,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, overload if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -348,12 +348,28 @@ def agg(self, attribute: str, func: Callable) -> Any: values = self.get(attribute) return func(values) + @overload def get( self, - attr_names: str | list[str], + attr_names: str, handle_missing: Literal["error", "default"] = "error", default_value: Any = None, - ) -> list[Any] | list[list[Any]]: + ) -> list[Any]: ... + + @overload + def get( + self, + attr_names: list[str], + handle_missing: Literal["error", "default"] = "error", + default_value: Any = None, + ) -> list[list[Any]]: ... + + def get( + self, + attr_names, + handle_missing="error", + default_value=None, + ): """ Retrieve the specified attribute(s) from each agent in the AgentSet.