Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make AssistantAgent and Handoff use BaseTool #5193

Merged
merged 5 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
SystemMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, Tool
from autogen_core.tools import FunctionTool, BaseTool
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -57,7 +57,7 @@ class AssistantAgentConfig(BaseModel):

name: str
model_client: ComponentModel
# tools: List[Any] | None = None # TBD
tools: List[ComponentModel] | None
handoffs: List[HandoffBase | str] | None = None
model_context: ComponentModel | None = None
description: str
Expand Down Expand Up @@ -130,7 +130,7 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
tools (List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional): The tools to register with the agent.
handoffs (List[HandoffBase | str] | None, optional): The handoff configurations for the agent,
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
name: str,
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
tools: List[BaseTool[Any, Any] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
model_context: ChatCompletionContext | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
Expand All @@ -288,12 +288,12 @@ def __init__(
self._system_messages = []
else:
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
self._tools: List[BaseTool[Any, Any]] = []
if tools is not None:
if model_client.model_info["function_calling"] is False:
raise ValueError("The model does not support function calling.")
for tool in tools:
if isinstance(tool, Tool):
if isinstance(tool, BaseTool):
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
Expand All @@ -308,7 +308,7 @@ def __init__(
if len(tool_names) != len(set(tool_names)):
raise ValueError(f"Tool names must be unique: {tool_names}")
# Handoff tools.
self._handoff_tools: List[Tool] = []
self._handoff_tools: List[BaseTool[Any, Any]] = []
self._handoffs: Dict[str, HandoffBase] = {}
if handoffs is not None:
if model_client.model_info["function_calling"] is False:
Expand Down Expand Up @@ -528,15 +528,10 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
def _to_config(self) -> AssistantAgentConfig:
"""Convert the assistant agent to a declarative config."""

# raise an error if tools is not empty until it is implemented
# TBD : Implement serializing tools and remove this check.
if self._tools and len(self._tools) > 0:
raise NotImplementedError("Serializing tools is not implemented yet.")

return AssistantAgentConfig(
name=self.name,
model_client=self._model_client.dump_component(),
# tools=[], # TBD
tools=[tool.dump_component() for tool in self._tools],
handoffs=list(self._handoffs.values()),
model_context=self._model_context.dump_component(),
description=self.description,
Expand All @@ -553,7 +548,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
return cls(
name=config.name,
model_client=ChatCompletionClient.load_component(config.model_client),
# tools=[], # TBD
tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
handoffs=config.handoffs,
model_context=None,
description=config.description,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import Any, Dict

from autogen_core.tools import FunctionTool, Tool
from autogen_core.tools import FunctionTool, BaseTool
from pydantic import BaseModel, Field, model_validator

from .. import EVENT_LOGGER_NAME
Expand Down Expand Up @@ -47,7 +47,7 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values

@property
def handoff_tool(self) -> Tool:
def handoff_tool(self) -> BaseTool[BaseModel, BaseModel]:
"""Create a handoff tool from this handoff configuration."""

def _handoff_tool() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -774,5 +774,5 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N
FunctionTool(_echo_function, description="Echo"),
],
)
with pytest.raises(NotImplementedError):
agent3.dump_component()
agent3_config = agent3.dump_component()
assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent"
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from autogen_agentchat.agents import AssistantAgent
from autogen_core.models import ChatCompletionClient
from autogen_core.tools import Tool
from autogen_core.tools import BaseTool
from pydantic import BaseModel

from .tools import (
extract_audio,
Expand Down Expand Up @@ -38,7 +39,7 @@ class VideoSurfer(AssistantAgent):
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client used for generating responses.
tools (List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
tools (List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None, optional):
A list of tools or functions the agent can use. If not provided, defaults to all video tools from the action space.
description (str, optional): A brief description of the agent. Defaults to "An agent that can answer questions about a local video.".
system_message (str | None, optional): The system message guiding the agent's behavior. Defaults to a predefined message.
Expand Down Expand Up @@ -137,7 +138,7 @@ def __init__(
name: str,
model_client: ChatCompletionClient,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
tools: List[BaseTool[BaseModel, BaseModel] | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
description: Optional[str] = None,
system_message: Optional[str] = None,
):
Expand Down
Loading