feat: Support Native Tools in Claude V3 Model (#116)
* Run Claude V3 in native mode

* dumb commit to restart pipeline

* rollback poetry.lock

* rollback poetry one more time

* Change version of anthropic without updates

* add support for function emulation in Claude 3 [send part]

* add support for function emulation in Claude 3 [receive part]

* minor adjust of error message

* Stylistic refactors

* Refactor due to PR comments

* PR comments fix

* typo fix

* refactor due to pr comments

* add explicit match/case for all types of events

* add text content to tool and function calls

* make linter happy

* make linter happy one more time

* Consistent processing of tools messages, more validation

* remove dead code

* make linter happy
roman-romanov-o authored Jun 21, 2024
1 parent 9e4b4a6 commit 5eb782a
Showing 12 changed files with 450 additions and 93 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -136,4 +136,4 @@ To remove the virtual environment and build artifacts:

 ```sh
 make clean
-```
+```
12 changes: 9 additions & 3 deletions aidial_adapter_bedrock/chat_completion.py
@@ -26,7 +26,10 @@
 from aidial_adapter_bedrock.deployments import ChatCompletionDeployment
 from aidial_adapter_bedrock.dial_api.request import ModelParameters
 from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
-from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
+from aidial_adapter_bedrock.llm.chat_model import (
+    ChatCompletionAdapter,
+    TextCompletionAdapter,
+)
 from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
 from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
 from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
@@ -64,8 +67,11 @@ async def generate_response(usage: TokenUsage) -> None:
         nonlocal discarded_messages

         with response.create_choice() as choice:
-            tools_emulator = model.tools_emulator(params.tool_config)
-            consumer = ChoiceConsumer(tools_emulator, choice)
+            consumer = ChoiceConsumer(choice=choice)
+            if isinstance(model, TextCompletionAdapter):
+                consumer.set_tools_emulator(
+                    model.tools_emulator(params.tool_config)
+                )

             try:
                 await model.chat(consumer, params, request.messages)
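In effect, tool emulation becomes an opt-in detail of text-completion models instead of a constructor argument every adapter must supply. A minimal sketch of the new wiring from the caller's side; `make_consumer` is a hypothetical helper, all other names come from this diff:

```python
from aidial_sdk.chat_completion import Choice

from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.llm.chat_model import (
    ChatCompletionAdapter,
    TextCompletionAdapter,
)
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer


def make_consumer(
    model: ChatCompletionAdapter, params: ModelParameters, choice: Choice
) -> ChoiceConsumer:
    # Native adapters (e.g. Claude 3) receive tools through ModelParameters
    # inside chat(); only prompt-based models still need a text-level emulator.
    consumer = ChoiceConsumer(choice=choice)
    if isinstance(model, TextCompletionAdapter):
        consumer.set_tools_emulator(model.tools_emulator(params.tool_config))
    return consumer
```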
3 changes: 1 addition & 2 deletions aidial_adapter_bedrock/llm/chat_model.py
@@ -30,8 +30,6 @@ def _is_empty_system_message(msg: Message) -> bool:


 class ChatCompletionAdapter(ABC, BaseModel):
-    tools_emulator: Callable[[Optional[ToolsConfig]], ToolsEmulator]
-
     class Config:
         arbitrary_types_allowed = True

@@ -65,6 +63,7 @@ class TextCompletionPrompt(BaseModel):


 class TextCompletionAdapter(ChatCompletionAdapter):
+    tools_emulator: Callable[[Optional[ToolsConfig]], ToolsEmulator]

     @abstractmethod
     async def predict(
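With the field moved down the hierarchy, the split reads roughly as below. Both subclasses are hypothetical illustrations, not code from this PR:

```python
from aidial_adapter_bedrock.llm.chat_model import (
    ChatCompletionAdapter,
    TextCompletionAdapter,
)


class NativeAdapter(ChatCompletionAdapter):
    """Hypothetical: targets a provider API with first-class tool support,
    so no tools_emulator field exists on it at all."""


class PromptBasedAdapter(TextCompletionAdapter):
    """Hypothetical: must still be constructed with a tools_emulator factory,
    which rewrites prompts and recognizes calls in generated text."""
```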
45 changes: 36 additions & 9 deletions aidial_adapter_bedrock/llm/consumer.py
@@ -1,7 +1,12 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional, assert_never

-from aidial_sdk.chat_completion import Choice, FinishReason
+from aidial_sdk.chat_completion import (
+    Choice,
+    FinishReason,
+    FunctionCall,
+    ToolCall,
+)
 from pydantic import BaseModel

 from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
@@ -42,23 +47,37 @@ def add_usage(self, usage: TokenUsage):
     def set_discarded_messages(self, discarded_messages: List[int]):
         pass

+    @abstractmethod
+    def create_function_tool_call(self, tool_call: ToolCall):
+        pass
+
+    @abstractmethod
+    def create_function_call(self, function_call: FunctionCall):
+        pass
+

 class ChoiceConsumer(Consumer):
     usage: TokenUsage
     choice: Choice
     discarded_messages: Optional[List[int]]
-    tools_emulator: ToolsEmulator
+    tools_emulator: Optional[ToolsEmulator]

-    def __init__(self, tools_emulator: ToolsEmulator, choice: Choice):
+    def __init__(self, choice: Choice):
         self.choice = choice
         self.usage = TokenUsage()
         self.discarded_messages = None
+        self.tools_emulator = None
+
+    def set_tools_emulator(self, tools_emulator: ToolsEmulator):
+        self.tools_emulator = tools_emulator

     def _process_content(
         self, content: str | None, finish_reason: FinishReason | None = None
     ):
-        res = self.tools_emulator.recognize_call(content)
+        if self.tools_emulator is not None:
+            res = self.tools_emulator.recognize_call(content)
+        else:
+            res = content

         if res is None:
             # Choice.close(finish_reason: Optional[FinishReason]) can be called only once
@@ -72,11 +91,7 @@ def _process_content(

         if isinstance(res, AIToolCallMessage):
             for call in res.calls:
-                self.choice.create_function_tool_call(
-                    id=call.id,
-                    name=call.function.name,
-                    arguments=call.function.arguments,
-                )
+                self.create_function_tool_call(call)
             return

         if isinstance(res, AIFunctionCallMessage):
@@ -102,3 +117,15 @@ def add_usage(self, usage: TokenUsage):

     def set_discarded_messages(self, discarded_messages: List[int]):
         self.discarded_messages = discarded_messages
+
+    def create_function_tool_call(self, tool_call: ToolCall):
+        self.choice.create_function_tool_call(
+            id=tool_call.id,
+            name=tool_call.function.name,
+            arguments=tool_call.function.arguments,
+        )
+
+    def create_function_call(self, function_call: FunctionCall):
+        self.choice.create_function_call(
+            name=function_call.name, arguments=function_call.arguments
+        )
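Both delivery paths now converge on the same pair of hooks: the emulated path reaches them from `_process_content` once a call is recognized in raw text, and the native Claude 3 path reaches them from `process_tools_block`. A minimal sketch of direct use; the `FunctionCall` payload is illustrative:

```python
from aidial_sdk.chat_completion import Choice, FunctionCall

from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer


def emit_function_call(choice: Choice) -> None:
    # No emulator attached: a native adapter pushes the structured call
    # straight through, and the consumer forwards it to the DIAL choice.
    consumer = ChoiceConsumer(choice=choice)
    consumer.create_function_call(
        FunctionCall(name="get_weather", arguments='{"city": "Berlin"}')
    )
```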
17 changes: 7 additions & 10 deletions aidial_adapter_bedrock/llm/message.py
@@ -38,10 +38,12 @@ class AIRegularMessage(BaseModel):

 class AIToolCallMessage(BaseModel):
     calls: List[ToolCall]
+    content: Optional[str] = None


 class AIFunctionCallMessage(BaseModel):
     call: FunctionCall
+    content: Optional[str] = None


 BaseMessage = Union[SystemMessage, HumanRegularMessage, AIRegularMessage]
@@ -63,18 +65,13 @@ def _parse_assistant_message(
     if content is not None and function_call is None and tool_calls is None:
         return AIRegularMessage(content=content, custom_content=custom_content)

-    if content is None and function_call is not None and tool_calls is None:
-        return AIFunctionCallMessage(call=function_call)
+    if function_call is not None and tool_calls is None:
+        return AIFunctionCallMessage(call=function_call, content=content)

-    if content is None and function_call is None and tool_calls is not None:
-        return AIToolCallMessage(calls=tool_calls)
+    if function_call is None and tool_calls is not None:
+        return AIToolCallMessage(calls=tool_calls, content=content)

-    raise ValidationError(
-        "Assistant message must have one and only one of the following fields not-none: "
-        f"content (is none: {content is None}), "
-        f"function_call (is none: {function_call is None}), "
-        f"tool_calls (is none: {tool_calls is None})"
-    )
+    raise ValidationError("Unknown type of assistant message")


 def parse_dial_message(msg: Message) -> BaseMessage | ToolMessage:
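The practical effect: an assistant message may now carry text alongside its calls instead of failing validation. A sketch of the relaxed parsing; the `Role` enum and exact `ToolCall` constructor fields follow aidial_sdk's OpenAI-style models and may vary by SDK version:

```python
from aidial_sdk.chat_completion import FunctionCall, Message, Role, ToolCall

from aidial_adapter_bedrock.llm.message import (
    AIToolCallMessage,
    parse_dial_message,
)

msg = Message(
    role=Role.ASSISTANT,
    content="Let me check the forecast.",  # previously forced a ValidationError
    tool_calls=[
        ToolCall(
            id="call_1",
            type="function",
            function=FunctionCall(
                name="get_weather", arguments='{"city": "Berlin"}'
            ),
        )
    ],
)

parsed = parse_dial_message(msg)
assert isinstance(parsed, AIToolCallMessage)
assert parsed.content == "Let me check the forecast."
```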
116 changes: 87 additions & 29 deletions aidial_adapter_bedrock/llm/model/claude/v3/adapter.py
@@ -1,14 +1,24 @@
-from typing import List, Mapping, Optional, TypedDict, Union
+from typing import List, Mapping, Optional, TypedDict, Union, assert_never

 from aidial_sdk.chat_completion import Message
-from anthropic import NOT_GIVEN, NotGiven
+from anthropic import NOT_GIVEN, MessageStopEvent, NotGiven
 from anthropic.lib.bedrock import AsyncAnthropicBedrock
-from anthropic.lib.streaming import AsyncMessageStream
+from anthropic.lib.streaming import (
+    AsyncMessageStream,
+    InputJsonEvent,
+    TextEvent,
+)
 from anthropic.types import (
+    ContentBlockDeltaEvent,
+    ContentBlockStartEvent,
+    ContentBlockStopEvent,
     MessageDeltaEvent,
     MessageParam,
+    MessageStartEvent,
     MessageStreamEvent,
+    TextBlock,
+    ToolParam,
+    ToolUseBlock,
 )

 from aidial_adapter_bedrock.dial_api.request import ModelParameters
@@ -20,15 +30,19 @@
 from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter
 from aidial_adapter_bedrock.llm.consumer import Consumer
 from aidial_adapter_bedrock.llm.errors import ValidationError
+from aidial_adapter_bedrock.llm.message import parse_dial_message
 from aidial_adapter_bedrock.llm.model.claude.v3.converters import (
     ClaudeFinishReason,
     to_claude_messages,
+    to_claude_tool_config,
     to_dial_finish_reason,
 )
-from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC
-from aidial_adapter_bedrock.llm.tools.claude_emulator import (
-    legacy_tools_emulator,
+from aidial_adapter_bedrock.llm.model.claude.v3.tools import (
+    process_tools_block,
+    process_with_tools,
 )
+from aidial_adapter_bedrock.llm.model.conf import DEFAULT_MAX_TOKENS_ANTHROPIC
+from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode
 from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log


@@ -51,6 +65,7 @@ class ChatParams(TypedDict):
     system: Union[str, NotGiven]
     temperature: Union[float, NotGiven]
     top_p: Union[float, NotGiven]
+    tools: Union[List[ToolParam], NotGiven]


 class Adapter(ChatCompletionAdapter):
@@ -67,58 +82,99 @@ async def chat(
         if len(messages) == 0:
             raise ValidationError("List of messages must not be empty")

-        tools_emulator = self.tools_emulator(params.tool_config)
-        base_messages = tools_emulator.parse_dial_messages(messages)
-        tool_stop_sequences = tools_emulator.get_stop_sequences()
+        tools = NOT_GIVEN
+        tools_mode = None
+        if params.tool_config is not None:
+            tools = [
+                to_claude_tool_config(tool_function)
+                for tool_function in params.tool_config.functions
+            ]
+            tools_mode = params.tool_config.tools_mode
+
+        parsed_messages = [
+            process_with_tools(parse_dial_message(m), tools_mode)
+            for m in messages
+        ]

         prompt, claude_messages = await to_claude_messages(
-            base_messages, self.storage
+            parsed_messages, self.storage
         )

         completion_params = ChatParams(
             max_tokens=params.max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC,
-            stop_sequences=[*params.stop, *tool_stop_sequences],
+            stop_sequences=params.stop,
             system=prompt or NOT_GIVEN,
             temperature=(
                 NOT_GIVEN
                 if params.temperature is None
                 else params.temperature / 2
             ),
             top_p=params.top_p or NOT_GIVEN,
+            tools=tools,
         )

         if params.stream:
             await self.invoke_streaming(
-                consumer, claude_messages, completion_params
+                consumer, claude_messages, completion_params, tools_mode
             )
         else:
             await self.invoke_non_streaming(
-                consumer, claude_messages, completion_params
+                consumer, claude_messages, completion_params, tools_mode
             )

     async def invoke_streaming(
         self,
         consumer: Consumer,
         messages: List[MessageParam],
         params: ChatParams,
+        tools_mode: ToolsMode | None,
     ):
         log.debug(
             f"Streaming request: messages={messages}, model={self.model}, params={params}"
         )
         async with self.client.messages.stream(
             messages=messages,
             model=self.model,
-            event_handler=UsageEventHandler,
             **params,
         ) as stream:
-            async for text in stream.text_stream:
-                consumer.append_content(text)
-            consumer.close_content(to_dial_finish_reason(stream.stop_reason))
+            prompt_tokens = 0
+            completion_tokens = 0
+            stop_reason = None
+            async for event in stream:
+                match event:
+                    case MessageStartEvent():
+                        prompt_tokens += event.message.usage.input_tokens
+                    case TextEvent():
+                        consumer.append_content(event.text)
+                    case MessageDeltaEvent():
+                        completion_tokens += event.usage.output_tokens
+                    case ContentBlockStopEvent():
+                        if isinstance(event.content_block, ToolUseBlock):
+                            process_tools_block(
+                                consumer, event.content_block, tools_mode
+                            )
+                    case MessageStopEvent():
+                        completion_tokens += event.message.usage.output_tokens
+                        stop_reason = event.message.stop_reason
+                    case (
+                        InputJsonEvent()
+                        | ContentBlockStartEvent()
+                        | ContentBlockDeltaEvent()
+                    ):
+                        pass
+                    case _:
+                        raise ValueError(
+                            f"Unsupported event type! {type(event)}"
+                        )

+            consumer.close_content(
+                to_dial_finish_reason(stop_reason, tools_mode)
+            )

         consumer.add_usage(
             TokenUsage(
-                prompt_tokens=stream.prompt_tokens,
-                completion_tokens=stream.completion_tokens,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
             )
         )

@@ -127,26 +183,29 @@ async def invoke_non_streaming(
         self,
         consumer: Consumer,
         messages: List[MessageParam],
         params: ChatParams,
+        tools_mode: ToolsMode | None,
     ):
         log.debug(
             f"Request: messages={messages}, model={self.model}, params={params}"
         )
         message = await self.client.messages.create(
             messages=messages, model=self.model, **params, stream=False
         )
-        prompt_tokens = 0
-        completion_tokens = 0
         for content in message.content:
-            usage = message.usage
-            prompt_tokens = usage.input_tokens
-            completion_tokens += usage.output_tokens
-            consumer.append_content(content.text)
-        consumer.close_content(to_dial_finish_reason(message.stop_reason))
+            if isinstance(content, TextBlock):
+                consumer.append_content(content.text)
+            elif isinstance(content, ToolUseBlock):
+                process_tools_block(consumer, content, tools_mode)
+            else:
+                assert_never(content)
+        consumer.close_content(
+            to_dial_finish_reason(message.stop_reason, tools_mode)
+        )

         consumer.add_usage(
             TokenUsage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
+                prompt_tokens=message.usage.input_tokens,
+                completion_tokens=message.usage.output_tokens,
            )
        )

@@ -155,7 +214,6 @@ def create(cls, model: str, region: str, headers: Mapping[str, str]):
         storage: Optional[FileStorage] = create_file_storage(headers)
         return cls(
             model=model,
-            tools_emulator=legacy_tools_emulator,
             storage=storage,
             client=AsyncAnthropicBedrock(aws_region=region),
         )
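`process_tools_block` itself lives in `aidial_adapter_bedrock/llm/model/claude/v3/tools.py`, one of the changed files not shown above. A hypothetical sketch of what it plausibly does, assuming `ToolsMode` distinguishes a tools mode from a legacy functions mode (the variant name and `ToolCall` fields are assumptions, not code from this PR):

```python
import json

from aidial_sdk.chat_completion import FunctionCall, ToolCall
from anthropic.types import ToolUseBlock

from aidial_adapter_bedrock.llm.consumer import Consumer
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsMode


def process_tools_block_sketch(
    consumer: Consumer, block: ToolUseBlock, tools_mode: ToolsMode | None
) -> None:
    # Claude returns tool input as an already-parsed object; DIAL transports
    # arguments as a JSON string, so serialize before forwarding.
    arguments = json.dumps(block.input)
    if tools_mode == ToolsMode.TOOLS:  # assumed variant name
        consumer.create_function_tool_call(
            ToolCall(
                id=block.id,
                type="function",
                function=FunctionCall(name=block.name, arguments=arguments),
            )
        )
    else:  # legacy functions mode
        consumer.create_function_call(
            FunctionCall(name=block.name, arguments=arguments)
        )
```

The real implementation may differ; the diff only guarantees the consumer-facing hooks it calls.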