diff --git a/mesa/agent.py b/mesa/agent.py index 5816ddf060a..75d79f90563 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -20,7 +20,7 @@ from random import Random # mypy -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, overload if TYPE_CHECKING: # We ensure that these are not imported during runtime to prevent cyclic @@ -348,12 +348,28 @@ def agg(self, attribute: str, func: Callable) -> Any: values = self.get(attribute) return func(values) + @overload def get( self, - attr_names: str | list[str], + attr_names: str, handle_missing: Literal["error", "default"] = "error", default_value: Any = None, - ) -> list[Any] | list[list[Any]]: + ) -> list[Any]: ... + + @overload + def get( + self, + attr_names: list[str], + handle_missing: Literal["error", "default"] = "error", + default_value: Any = None, + ) -> list[list[Any]]: ... + + def get( + self, + attr_names, + handle_missing="error", + default_value=None, + ): """ Retrieve the specified attribute(s) from each agent in the AgentSet.