feat/added bedrock streaming #314

Open · wants to merge 5 commits into base: main
74 changes: 62 additions & 12 deletions adalflow/adalflow/components/model_client/bedrock_client.py
@@ -1,7 +1,8 @@
"""AWS Bedrock ModelClient integration."""

import json
import os
from typing import Dict, Optional, Any, Callable
from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
import backoff
import logging

@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
r"""When we only need the content of the first message.
It is the default parser for chat completion."""
return completion["output"]["message"]["content"][0]["text"]
return completion["output"]["message"]["content"][0]["text"]


__all__ = [
@@ -117,6 +117,7 @@ def __init__(
         self._aws_connection_timeout = aws_connection_timeout
         self._aws_read_timeout = aws_read_timeout

+        self._client = None
         self.session = None
         self.sync_client = self.init_sync_client()
         self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
     def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")

-    def parse_chat_completion(self, completion):
-        log.debug(f"completion: {completion}")
+    def handle_stream_response(self, stream: dict) -> GeneratorType:
+        r"""Handle the stream response from bedrock. Yield the chunks.
+
+        Args:
+            stream (dict): The stream response generator from bedrock.
+
+        Returns:
+            GeneratorType: A generator that yields the chunks from the bedrock stream.
+        """
+        try:
+            stream: GeneratorType = stream["stream"]
+            for chunk in stream:
+                log.debug(f"Raw chunk: {chunk}")
+                yield chunk
+        except Exception as e:
+            log.debug(f"Error in handle_stream_response: {e}")
+            raise
+
+    def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
+        r"""Parse the completion and assign it to the raw_response attribute.
+
+        If the completion is a stream, it is handled by the handle_stream_response
+        method, which returns a generator. Otherwise, the completion is parsed using
+        the get_first_message_content method.
+
+        Args:
+            completion (dict): The completion response from the bedrock API call.
+
+        Returns:
+            GeneratorOutput: A generator output object with the parsed completion. Its
+                raw_response may be a generator if the completion is a stream.
+        """
         try:
-            data = completion["output"]["message"]["content"][0]["text"]
-            usage = self.track_completion_usage(completion)
-            return GeneratorOutput(data=None, usage=usage, raw_response=data)
+            usage = None
+            data = self.chat_completion_parser(completion)
+            if not isinstance(data, GeneratorType):
+                # Streaming completion usage tracking is not implemented.
+                usage = self.track_completion_usage(completion)
+            return GeneratorOutput(
+                data=None, error=None, raw_response=data, usage=usage
+            )
         except Exception as e:
-            log.error(f"Error parsing completion: {e}")
+            log.error(f"Error parsing the completion: {e}")
             return GeneratorOutput(
-                data=None, error=str(e), raw_response=str(completion)
+                data=None, error=str(e), raw_response=json.dumps(completion)
             )

     def track_completion_usage(self, completion: Dict) -> CompletionUsage:
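
For reference, a minimal sketch of how the two parse paths behave. The names here are hypothetical: client is a configured BedrockAPIClient, and completion / stream_completion are raw returns of call() without and with stream=True.

# Non-streaming: raw_response holds the message text and usage is tracked.
output = client.parse_chat_completion(completion)
print(output.raw_response)        # e.g. "Hello, world!"
print(output.usage.total_tokens)  # e.g. 30

# Streaming: chat_completion_parser was swapped to handle_stream_response,
# so raw_response is a generator of raw Bedrock event dicts and usage is None.
stream_output = client.parse_chat_completion(stream_completion)
for chunk in stream_output.raw_response:
    print(chunk)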
@@ -191,6 +227,7 @@ def list_models(self):
print(f" Description: {model['description']}")
print(f" Provider: {model['provider']}")
print("")

except Exception as e:
print(f"Error listing models: {e}")

@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
             bedrock_runtime_exceptions.ModelErrorException,
             bedrock_runtime_exceptions.ValidationException,
         ),
-        max_time=5,
+        max_time=2,
     )
-    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+    def call(
+        self,
+        api_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> dict:
         """
         kwargs is the combined input and model_kwargs
         """
         if model_type == ModelType.LLM:
-            return self.sync_client.converse(**api_kwargs)
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("Streaming call")
+                # stream is not a valid parameter for bedrock; remove it.
+                api_kwargs.pop("stream", None)
+                self.chat_completion_parser = self.handle_stream_response
+                return self.sync_client.converse_stream(**api_kwargs)
+            else:
+                api_kwargs.pop("stream", None)
+                return self.sync_client.converse(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
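End to end, the streaming path would be driven roughly as below. This is a sketch, assuming AWS credentials are configured, the model id shown is enabled in the account, and convert_inputs_to_api_kwargs passes model_kwargs (including "stream") through into api_kwargs.

from adalflow.core.types import ModelType
from adalflow.components.model_client import BedrockAPIClient

client = BedrockAPIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="What is the capital of France?",
    model_kwargs={
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # assumed available
        "stream": True,  # routes call() to converse_stream()
    },
    model_type=ModelType.LLM,
)
completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
output = client.parse_chat_completion(completion)
for event in output.raw_response:
    # ConverseStream emits event dicts such as {"contentBlockDelta": ...};
    # extract and print only the text deltas.
    delta = event.get("contentBlockDelta", {}).get("delta", {})
    if "text" in delta:
        print(delta["text"], end="", flush=True)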
8 changes: 4 additions & 4 deletions adalflow/poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions adalflow/pyproject.toml
@@ -80,6 +80,7 @@ groq = "^0.9.0"
 google-generativeai = "^0.7.2"
 anthropic = "^0.31.1"
 lancedb = "^0.5.2"
+boto3 = "^1.35.19"



137 changes: 137 additions & 0 deletions adalflow/tests/test_aws_bedrock_client.py
@@ -0,0 +1,137 @@
import unittest
from unittest.mock import Mock

from adalflow.core.types import ModelType, GeneratorOutput
from adalflow.components.model_client import BedrockAPIClient
class TestBedrockClient(unittest.TestCase):
    def setUp(self) -> None:
        self.client = BedrockAPIClient()
        self.mock_response = {
            "ResponseMetadata": {
                "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
                "HTTPStatusCode": 200,
                "HTTPHeaders": {
                    "date": "Sat, 30 Nov 2024 14:27:44 GMT",
                    "content-type": "application/json",
                    "content-length": "273",
                    "connection": "keep-alive",
                    "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
                },
                "RetryAttempts": 0,
            },
            "output": {
                "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
            },
            "stopReason": "end_turn",
            "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
            "metrics": {"latencyMs": 430},
        }
        self.mock_stream_response = {
            "ResponseMetadata": {
                "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
                "HTTPStatusCode": 200,
                "HTTPHeaders": {
                    "date": "Sun, 12 Jan 2025 15:10:00 GMT",
                    "content-type": "application/vnd.amazon.eventstream",
                    "transfer-encoding": "chunked",
                    "connection": "keep-alive",
                    "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
                },
                "RetryAttempts": 0,
            },
            "stream": iter(()),
        }

        self.api_kwargs = {
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "gpt-3.5-turbo",
        }

    def test_call(self) -> None:
        """Test that the converse method is called correctly."""
        mock_sync_client = Mock()

        # Mock the converse API calls.
        mock_sync_client.converse = Mock(return_value=self.mock_response)
        mock_sync_client.converse_stream = Mock(return_value=self.mock_response)

        # Set the sync client.
        self.client.sync_client = mock_sync_client

        # Call the call method.
        result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)

        # Assertions: the non-streaming converse endpoint is used.
        mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
        mock_sync_client.converse_stream.assert_not_called()
        self.assertEqual(result, self.mock_response)

        # Test parse_chat_completion on the non-streaming response.
        output = self.client.parse_chat_completion(completion=self.mock_response)
        self.assertTrue(isinstance(output, GeneratorOutput))
        self.assertEqual(output.raw_response, "Hello, world!")
        self.assertEqual(output.usage.prompt_tokens, 20)
        self.assertEqual(output.usage.completion_tokens, 10)
        self.assertEqual(output.usage.total_tokens, 30)

    def test_streaming_call(self) -> None:
        """Test that a streaming call uses the converse_stream method."""
        mock_sync_client = Mock()

        # Mock the converse API calls.
        mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
        mock_sync_client.converse = Mock(return_value=self.mock_response)

        # Set the sync client.
        self.client.sync_client = mock_sync_client

        # Call the call method with stream=True.
        stream_kwargs = self.api_kwargs | {"stream": True}
        self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)

        # Assert the streaming endpoint was used. Note that call() pops "stream"
        # from the dict in place, so stream_kwargs matches the actual call args.
        mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
        mock_sync_client.converse.assert_not_called()

    def test_call_value_error(self) -> None:
        """Test that a ValueError is raised when an invalid model_type is passed."""
        mock_sync_client = Mock()

        # Set the sync client.
        self.client.sync_client = mock_sync_client

        # Test that ValueError is raised.
        with self.assertRaises(ValueError):
            self.client.call(
                api_kwargs={},
                model_type=ModelType.UNDEFINED,  # This should trigger ValueError
            )

    def test_parse_chat_completion(self) -> None:
        """Test that parse_chat_completion does not track completion usage when
        streaming."""
        mock_track_completion_usage = Mock()
        self.client.track_completion_usage = mock_track_completion_usage

        self.client.chat_completion_parser = self.client.handle_stream_response
        generator_output = self.client.parse_chat_completion(self.mock_stream_response)

        mock_track_completion_usage.assert_not_called()
        assert isinstance(generator_output, GeneratorOutput)

    def test_parse_chat_completion_call_usage(self) -> None:
        """Test that parse_chat_completion tracks completion usage when not
        streaming."""
        mock_track_completion_usage = Mock()
        self.client.track_completion_usage = mock_track_completion_usage
        generator_output = self.client.parse_chat_completion(self.mock_response)

        mock_track_completion_usage.assert_called_once()
        assert isinstance(generator_output, GeneratorOutput)


if __name__ == "__main__":
    unittest.main()
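
A possible follow-up test, not part of this PR: drive handle_stream_response with fake ConverseStream events and assert they are yielded unchanged. The chunk payloads are illustrative, not captured from a real call.

import unittest
from adalflow.components.model_client import BedrockAPIClient


class TestBedrockStreamChunks(unittest.TestCase):
    def test_chunks_are_yielded_unchanged(self) -> None:
        client = BedrockAPIClient()
        chunks = [
            {"messageStart": {"role": "assistant"}},
            {"contentBlockDelta": {"delta": {"text": "Hello"}, "contentBlockIndex": 0}},
            {"messageStop": {"stopReason": "end_turn"}},
        ]
        # handle_stream_response reads the "stream" key and yields each chunk.
        yielded = list(client.handle_stream_response({"stream": iter(chunks)}))
        self.assertEqual(yielded, chunks)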