diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py
index b10098bb..493147e6 100644
--- a/adalflow/adalflow/components/model_client/bedrock_client.py
+++ b/adalflow/adalflow/components/model_client/bedrock_client.py
@@ -1,7 +1,8 @@
 """AWS Bedrock ModelClient integration."""
 
+import json
 import os
-from typing import Dict, Optional, Any, Callable
+from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
 
 import backoff
 import logging
@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
     r"""When we only need the content of the first message.
     It is the default parser for chat completion."""
     return completion["output"]["message"]["content"][0]["text"]
-    return completion["output"]["message"]["content"][0]["text"]
 
 
 __all__ = [
@@ -117,6 +117,7 @@ def __init__(
         self._aws_connection_timeout = aws_connection_timeout
         self._aws_read_timeout = aws_read_timeout
 
+        self._client = None
         self.session = None
         self.sync_client = self.init_sync_client()
         self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
 
     def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")
 
-    def parse_chat_completion(self, completion):
-        log.debug(f"completion: {completion}")
+    def handle_stream_response(self, stream: dict) -> GeneratorType:
+        r"""Handle the stream response from Bedrock. Yield the chunks.
+
+        Args:
+            stream (dict): The stream response generator from Bedrock.
+
+        Returns:
+            GeneratorType: A generator that yields the chunks from the Bedrock stream.
+        """
+        try:
+            stream: GeneratorType = stream["stream"]
+            for chunk in stream:
+                log.debug(f"Raw chunk: {chunk}")
+                yield chunk
+        except Exception as e:
+            log.debug(f"Error in handle_stream_response: {e}")
+            raise
+
+    def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
+        r"""Parse the completion and assign it to the raw_response attribute.
+
+        If the completion is a stream, it is handled by the handle_stream_response
+        method, which returns a generator. Otherwise, the completion is parsed using
+        the get_first_message_content method.
+
+        Args:
+            completion (dict): The completion response from the Bedrock API call.
+
+        Returns:
+            GeneratorOutput: A generator output object with the parsed completion.
+                raw_response holds a generator if the completion is a stream.
+        """
         try:
-            data = completion["output"]["message"]["content"][0]["text"]
-            usage = self.track_completion_usage(completion)
-            return GeneratorOutput(data=None, usage=usage, raw_response=data)
+            usage = None
+            data = self.chat_completion_parser(completion)
+            if not isinstance(data, GeneratorType):
+                # Streaming completion usage tracking is not implemented.
+                usage = self.track_completion_usage(completion)
+            return GeneratorOutput(
+                data=None, error=None, raw_response=data, usage=usage
+            )
         except Exception as e:
-            log.error(f"Error parsing completion: {e}")
+            log.error(f"Error parsing the completion: {e}")
             return GeneratorOutput(
-                data=None, error=str(e), raw_response=str(completion)
+                data=None, error=str(e), raw_response=json.dumps(completion)
             )
 
     def track_completion_usage(self, completion: Dict) -> CompletionUsage:
@@ -191,6 +227,7 @@ def list_models(self):
                 print(f"  Description: {model['description']}")
                 print(f"  Provider: {model['provider']}")
                 print("")
+
         except Exception as e:
             print(f"Error listing models: {e}")
 
@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
             bedrock_runtime_exceptions.ModelErrorException,
             bedrock_runtime_exceptions.ValidationException,
         ),
-        max_time=5,
+        max_time=2,
     )
-    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
+    def call(
+        self,
+        api_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
+    ) -> dict:
         """
         kwargs is the combined input and model_kwargs
         """
         if model_type == ModelType.LLM:
-            return self.sync_client.converse(**api_kwargs)
+            if "stream" in api_kwargs and api_kwargs.get("stream", False):
+                log.debug("Streaming call")
+                api_kwargs.pop(
+                    "stream", None
+                )  # stream is not a valid parameter for bedrock
+                self.chat_completion_parser = self.handle_stream_response
+                return self.sync_client.converse_stream(**api_kwargs)
+            else:
+                api_kwargs.pop("stream", None)
+                return self.sync_client.converse(**api_kwargs)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
diff --git a/adalflow/poetry.lock b/adalflow/poetry.lock
index 0cd8b67f..a9e93bd0 100644
--- a/adalflow/poetry.lock
+++ b/adalflow/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -289,7 +289,7 @@ files = [
 name = "boto3"
 version = "1.36.0"
 description = "The AWS SDK for Python"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "boto3-1.36.0-py3-none-any.whl", hash = "sha256:d0ca7a58ce25701a52232cc8df9d87854824f1f2964b929305722ebc7959d5a9"},
@@ -308,7 +308,7 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
 name = "botocore"
 version = "1.36.0"
 description = "Low-level, data-driven core of boto 3."
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "botocore-1.36.0-py3-none-any.whl", hash = "sha256:b54b11f0cfc47fc1243ada0f7f461266c279968487616720fa8ebb02183917d7"},
@@ -1912,7 +1912,7 @@ files = [
 name = "jmespath"
 version = "1.0.1"
 description = "JSON Matching Expressions"
-optional = true
+optional = false
 python-versions = ">=3.7"
 files = [
     {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"},
@@ -3958,7 +3958,7 @@ pyasn1 = ">=0.1.3"
 name = "s3transfer"
 version = "0.11.0"
 description = "An Amazon S3 Transfer Manager"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "s3transfer-0.11.0-py3-none-any.whl", hash = "sha256:f43b03931c198743569bbfb6a328a53f4b2b4ec723cd7c01fab68e3119db3f8b"},
@@ -4879,4 +4879,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9, <4.0"
-content-hash = "0a2347ff8b273139d7f2e32a1dc4da7596fb1d657bbb8b4b52aab42c57bbaad1"
+content-hash = "d824db4561b03ced4ab2c5d584a945b85ef26917d78d423e41cbe82b9fcb6e71"
diff --git a/adalflow/pyproject.toml b/adalflow/pyproject.toml
index 5c8340a2..bbe78133 100644
--- a/adalflow/pyproject.toml
+++ b/adalflow/pyproject.toml
@@ -75,6 +75,8 @@ groq = "^0.9.0"
 google-generativeai = "^0.7.2"
 anthropic = "^0.31.1"
 lancedb = "^0.5.2"
+boto3 = "^1.35.19"
+
 # TODO: cant make qdrant work here
 # qdrant_client = [
 #     { version = ">=1.12.2,<2.0.0", optional = true, markers = "python_version >= '3.10'" },
diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py
new file mode 100644
index 00000000..9f21682a
--- /dev/null
+++ b/adalflow/tests/test_aws_bedrock_client.py
@@ -0,0 +1,141 @@
+import unittest
+from unittest.mock import Mock, patch
+from adalflow.core.types import ModelType, GeneratorOutput
+from adalflow.components.model_client import BedrockAPIClient
+
+
+class TestBedrockClient(unittest.TestCase):
+    def setUp(self) -> None:
+        """Set up mocks and test data.
+
+        Mocks the boto3 session and the init_sync_client method. The mocks create a
+        mock Bedrock client and mock responses that can be reused across tests.
+        """
+        self.session_patcher = patch(
+            "adalflow.components.model_client.bedrock_client.boto3.Session"
+        )
+        self.mock_session = self.session_patcher.start()
+        self.mock_boto3_client = Mock()
+        self.mock_session.return_value.client.return_value = self.mock_boto3_client
+        self.init_sync_patcher = patch.object(BedrockAPIClient, "init_sync_client")
+        self.mock_init_sync_client = self.init_sync_patcher.start()
+        self.mock_sync_client = Mock()
+        self.mock_init_sync_client.return_value = self.mock_sync_client
+        self.mock_sync_client.converse = Mock()
+        self.mock_sync_client.converse_stream = Mock()
+        self.client = BedrockAPIClient()
+        self.client.sync_client = self.mock_sync_client
+
+        self.mock_response = {
+            "ResponseMetadata": {
+                "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
+                "HTTPStatusCode": 200,
+                "HTTPHeaders": {
+                    "date": "Sat, 30 Nov 2024 14:27:44 GMT",
+                    "content-type": "application/json",
+                    "content-length": "273",
+                    "connection": "keep-alive",
+                    "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
+                },
+                "RetryAttempts": 0,
+            },
+            "output": {
+                "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
+            },
+            "stopReason": "end_turn",
+            "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
+            "metrics": {"latencyMs": 430},
+        }
+        self.mock_stream_response = {
+            "ResponseMetadata": {
+                "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
+                "HTTPStatusCode": 200,
+                "HTTPHeaders": {
+                    "date": "Sun, 12 Jan 2025 15:10:00 GMT",
+                    "content-type": "application/vnd.amazon.eventstream",
+                    "transfer-encoding": "chunked",
+                    "connection": "keep-alive",
+                    "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
+                },
+                "RetryAttempts": 0,
+            },
+            "stream": iter(()),
+        }
+        self.api_kwargs = {
+            "messages": [{"role": "user", "content": "Hello"}],
+            "model": "gpt-3.5-turbo",
+        }
+
+    def tearDown(self) -> None:
+        """Stop the patchers."""
+        self.session_patcher.stop()
+        self.init_sync_patcher.stop()
+
+    def test_call(self) -> None:
+        """Tests that the call method calls the converse method correctly."""
+        self.mock_sync_client.converse = Mock(return_value=self.mock_response)
+        self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
+
+        result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
+
+        # Assertions: converse is called once and stream is not called
+        self.mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
+        self.mock_sync_client.converse_stream.assert_not_called()
+        self.assertEqual(result, self.mock_response)
+
+    def test_parse_chat_completion(self) -> None:
+        """Tests that the parse_chat_completion method returns the expected object."""
+        output = self.client.parse_chat_completion(completion=self.mock_response)
+        self.assertTrue(isinstance(output, GeneratorOutput))
+        self.assertEqual(output.raw_response, "Hello, world!")
+        self.assertEqual(output.usage.prompt_tokens, 20)
+        self.assertEqual(output.usage.completion_tokens, 10)
+        self.assertEqual(output.usage.total_tokens, 30)
+
+    def test_parse_chat_completion_call_usage(self) -> None:
+        """Tests that parse_chat_completion tracks completion usage when not
+        streaming."""
+        mock_track_completion_usage = Mock()
+        self.client.track_completion_usage = mock_track_completion_usage
+        generator_output = self.client.parse_chat_completion(self.mock_response)
+
+        mock_track_completion_usage.assert_called_once()
+        assert isinstance(generator_output, GeneratorOutput)
+
+    def test_streaming_call(self) -> None:
+        """Test that a streaming call calls the converse_stream method."""
+        self.mock_sync_client.converse = Mock(return_value=self.mock_response)
+        self.mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
+
+        # Call the call method.
+        stream_kwargs = self.api_kwargs | {"stream": True}
+        self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
+
+        # Assertions: the streaming method is called. call() pops "stream" in place,
+        # so stream_kwargs here matches the kwargs converse_stream actually received.
+        self.mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
+        self.mock_sync_client.converse.assert_not_called()
+
+    def test_call_value_error(self) -> None:
+        """Test that a ValueError is raised when an invalid model_type is passed."""
+        with self.assertRaises(ValueError):
+            self.client.call(
+                api_kwargs={},
+                model_type=ModelType.UNDEFINED,  # This should trigger ValueError
+            )
+
+    def test_parse_streaming_chat_completion(self) -> None:
+        """Test that parse_chat_completion does not track completion usage when
+        streaming."""
+        mock_track_completion_usage = Mock()
+        self.client.track_completion_usage = mock_track_completion_usage
+
+        self.client.chat_completion_parser = self.client.handle_stream_response
+        generator_output = self.client.parse_chat_completion(self.mock_stream_response)
+
+        mock_track_completion_usage.assert_not_called()
+        assert isinstance(generator_output, GeneratorOutput)
+
+
+if __name__ == "__main__":
+    unittest.main()
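
For reviewers, a minimal usage sketch of the streaming path this diff adds. It is illustrative, not part of the change: it assumes valid AWS credentials are configured, the model id is a placeholder to replace with one enabled in your account, and it assumes convert_inputs_to_api_kwargs merges input and model_kwargs as it does for the other AdalFlow model clients. The event keys follow boto3's converse_stream response format.

from adalflow.core.types import ModelType
from adalflow.components.model_client import BedrockAPIClient

client = BedrockAPIClient()

# "stream": True routes call() to converse_stream; the key itself is popped
# from api_kwargs before the Bedrock API is invoked.
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="Hello",
    model_kwargs={
        "model": "anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model id
        "stream": True,
    },
    model_type=ModelType.LLM,
)
completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)

# When streaming, raw_response holds a generator over Bedrock event dicts.
output = client.parse_chat_completion(completion)
for event in output.raw_response:
    delta = event.get("contentBlockDelta", {}).get("delta", {})
    if "text" in delta:
        print(delta["text"], end="", flush=True)

Note that call() mutates the passed api_kwargs by popping "stream", and that parse_chat_completion is normally invoked by AdalFlow's Generator rather than called directly.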