Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG FIX Tool Argument Parsing Error #163

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

Expand All @@ -34,6 +36,8 @@
LocalPythonInterpreter,
fix_final_answer_code,
)


from .models import MessageRole
from .monitoring import Monitor
from .prompts import (
Expand Down Expand Up @@ -64,7 +68,10 @@
console,
parse_code_blobs,
parse_json_tool_call,
make_json_serializable,
parse_tool_call_arguments,
truncate_content,
parse_tool_call_arguments,
)


Expand Down Expand Up @@ -197,6 +204,9 @@ def __init__(
max_steps: int = 6,
tool_parser: Optional[Callable] = None,
add_base_tools: bool = False,
remove_final_answer_tool: bool = False,
stream_json_logs: bool = False,
json_logs_path: str = "logs.json",
verbosity_level: int = 1,
grammar: Optional[Dict[str, str]] = None,
managed_agents: Optional[List] = None,
Expand Down Expand Up @@ -233,8 +243,16 @@ def __init__(
or self.__class__.__name__ == "ToolCallingAgent"
):
self.tools[tool_name] = tool_class()
self.tools["final_answer"] = FinalAnswerTool()

if not remove_final_answer_tool:
self.tools["final_answer"] = FinalAnswerTool()

self.stream_json_logs = stream_json_logs
self.json_logs_path = json_logs_path
# Create path if it doesn't exist
if self.stream_json_logs:
Path(self.json_logs_path).parent.mkdir(parents=True, exist_ok=True)


self.system_prompt = self.initialize_system_prompt()
self.input_messages = None
self.logs = []
Expand Down Expand Up @@ -394,7 +412,7 @@ def provide_final_answer(self, task) -> str:
}
]
try:
return self.model(self.input_messages)
return {"answer": self.model(self.input_messages).content}
except Exception as e:
return f"Error in generating final LLM output:\n{e}"

Expand Down Expand Up @@ -560,6 +578,7 @@ def stream_run(self, task: str):

# Run one step!
final_answer = self.step(step_log)

except AgentError as e:
step_log.error = e
finally:
Expand All @@ -569,6 +588,8 @@ def stream_run(self, task: str):
for callback in self.step_callbacks:
callback(step_log)
self.step_number += 1
if self.stream_json_logs:
self.to_json_stream()
yield step_log

if final_answer is None and self.step_number == self.max_steps:
Expand All @@ -582,6 +603,8 @@ def stream_run(self, task: str):
final_step_log.duration = step_log.end_time - step_start_time
for callback in self.step_callbacks:
callback(final_step_log)
if self.stream_json_logs:
self.to_json_stream()
yield final_step_log

yield handle_agent_output_types(final_answer)
Expand Down Expand Up @@ -616,6 +639,9 @@ def direct_run(self, task: str):

# Run one step!
final_answer = self.step(step_log)

if self.stream_json_logs:
self.to_json_stream()

except AgentError as e:
step_log.error = e
Expand All @@ -627,6 +653,8 @@ def direct_run(self, task: str):
for callback in self.step_callbacks:
callback(step_log)
self.step_number += 1
if self.stream_json_logs:
self.to_json_stream()

if final_answer is None and self.step_number == self.max_steps:
error_message = "Reached max steps."
Expand All @@ -638,7 +666,9 @@ def direct_run(self, task: str):
final_step_log.duration = 0
for callback in self.step_callbacks:
callback(final_step_log)

if self.stream_json_logs:
self.to_json_stream()

return handle_agent_output_types(final_answer)

def planning_step(self, task, is_first_step: bool, step: int):
Expand Down Expand Up @@ -686,7 +716,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
answer_plan = self.model(
[message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=["<end_plan>"],
)
).content

final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
Expand Down Expand Up @@ -720,7 +750,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
}
facts_update = self.model(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
)
).content

# Redact updated plan
plan_update_message = {
Expand All @@ -744,7 +774,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
plan_update = self.model(
[plan_update_message] + agent_memory + [plan_update_message_user],
stop_sequences=["<end_plan>"],
)
).content

# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
Expand All @@ -763,6 +793,63 @@ def planning_step(self, task, is_first_step: bool, step: int):
level=LogLevel.INFO,
)

def to_json(self) -> List[Dict]:
"""
Convert the agent's logs into a JSON-serializable format.
Returns a list of dictionaries containing the step information.
"""
json_logs = []
for step_log in self.logs:
if isinstance(step_log, SystemPromptStep):
json_log = {
"type": "system_prompt",
"system_prompt": step_log.system_prompt
}
elif isinstance(step_log, PlanningStep):
json_log = {
"type": "planning",
"plan": step_log.plan,
"facts": step_log.facts
}
elif isinstance(step_log, TaskStep):
json_log = {
"type": "task",
"task": step_log.task
}
elif isinstance(step_log, ActionStep):
json_log = {
"type": "action",
"start_time": step_log.start_time,
"end_time": step_log.end_time,
"step": step_log.step,
"duration": step_log.duration,
"llm_output": step_log.llm_output,
"observations": step_log.observations,
"action_output": make_json_serializable(step_log.action_output),
}

if step_log.tool_call:
json_log["tool_call"] = {
"name": step_log.tool_call.name,
"arguments": make_json_serializable(step_log.tool_call.arguments),
"id": step_log.tool_call.id
}

if step_log.error:
json_log["error"] = {
"type": step_log.error.__class__.__name__,
"message": str(step_log.error)
}

json_logs.append(json_log)

return json_logs

def to_json_stream(self):
# rewrite json logs to a file
with open(self.json_logs_path, "w") as f:
json.dump(self.to_json(), f)


class ToolCallingAgent(MultiStepAgent):
"""
Expand Down Expand Up @@ -805,8 +892,9 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)

tool_calls = model_message.tool_calls[0]
tool_arguments = tool_calls.function.arguments
tool_arguments = parse_tool_call_arguments(tool_calls.function.arguments)
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id

except Exception as e:
Expand Down
50 changes: 49 additions & 1 deletion src/smolagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import re
import types
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, Any

from rich.console import Console
from transformers.utils.import_utils import _is_package_available
Expand Down Expand Up @@ -160,6 +160,54 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
error_msg = "No tool name key found in tool call!" + f" Tool call: {json_blob}"
raise AgentParsingError(error_msg)

def parse_tool_call_arguments(arguments: Union[str, Dict[str, str]]) -> Dict[str, str]:
if isinstance(arguments, str):
try:
parsed = json.loads(arguments)
except json.JSONDecodeError:
parsed = arguments
else:
parsed = arguments
return parsed

def make_json_serializable(obj: Any) -> Any:
if obj is None:
return None
elif isinstance(obj, (str, int, float, bool)):
# Try to parse string as JSON if it looks like a JSON object/array
if isinstance(obj, str):
try:
if (obj.startswith('{') and obj.endswith('}')) or (obj.startswith('[') and obj.endswith(']')):
parsed = json.loads(obj)
return make_json_serializable(parsed)
except json.JSONDecodeError:
pass
return obj
elif isinstance(obj, (list, tuple)):
return [make_json_serializable(item) for item in obj]
elif isinstance(obj, dict):
return {str(k): make_json_serializable(v) for k, v in obj.items()}
elif hasattr(obj, '__dict__'):
# For custom objects, convert their __dict__ to a serializable format
return {
'_type': obj.__class__.__name__,
**{k: make_json_serializable(v) for k, v in obj.__dict__.items()}
}
else:
# For any other type, convert to string
return str(obj)


def parse_tool_call_arguments(arguments: Union[str, Dict[str, str]]) -> Dict[str, str]:
if isinstance(arguments, str):
try:
parsed = json.loads(arguments)
except json.JSONDecodeError:
parsed = arguments
else:
parsed = arguments
return parsed


MAX_LENGTH_TRUNCATE_CONTENT = 20000

Expand Down