From 3329ccc59749e10132c617ae98bebe819854a7c9 Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sat, 11 Jan 2025 21:40:31 +0000 Subject: [PATCH 1/5] Added streaming call for bedrock API --- .../components/model_client/bedrock_client.py | 72 ++++++++++++++----- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index b10098bb..479b43ef 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -1,7 +1,13 @@ """AWS Bedrock ModelClient integration.""" - +import json import os -from typing import Dict, Optional, Any, Callable +from typing import ( + Dict, + Optional, + Any, + Callable, + Generator as GeneratorType +) import backoff import logging @@ -15,6 +21,7 @@ from botocore.config import Config log = logging.getLogger(__name__) +log.level = logging.DEBUG bedrock_runtime_exceptions = boto3.client( service_name="bedrock-runtime", @@ -26,7 +33,6 @@ def get_first_message_content(completion: Dict) -> str: r"""When we only need the content of the first message. It is the default parser for chat completion.""" return completion["output"]["message"]["content"][0]["text"] - return completion["output"]["message"]["content"][0]["text"] __all__ = [ @@ -117,6 +123,7 @@ def __init__( self._aws_connection_timeout = aws_connection_timeout self._aws_read_timeout = aws_read_timeout + self._client = None self.session = None self.sync_client = self.init_sync_client() self.chat_completion_parser = ( @@ -158,16 +165,34 @@ def init_sync_client(self): def init_async_client(self): raise NotImplementedError("Async call not implemented yet.") - def parse_chat_completion(self, completion): - log.debug(f"completion: {completion}") + @staticmethod + def parse_stream_response(completion: dict) -> str: + if "contentBlockDelta" in completion: + if delta_chunk := completion["contentBlockDelta"]["delta"]: + return delta_chunk["text"] + return '' + + def handle_stream_response(self, stream: dict) -> GeneratorType: + try: + for chunk in stream["stream"]: + log.debug(f"Raw chunk: {chunk}") + parsed_content = self.parse_stream_response(chunk) + yield parsed_content + except Exception as e: + print(f"Error in handle_stream_response: {e}") # Debug print + raise + + def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": + """Parse the completion, and put it into the raw_response.""" try: - data = completion["output"]["message"]["content"][0]["text"] - usage = self.track_completion_usage(completion) - return GeneratorOutput(data=None, usage=usage, raw_response=data) + data = self.handle_stream_response(completion) + return GeneratorOutput( + data=None, error=None, raw_response=data + ) except Exception as e: - log.error(f"Error parsing completion: {e}") + log.error(f"Error parsing the completion: {e}") return GeneratorOutput( - data=None, error=str(e), raw_response=str(completion) + data=None, error=str(e), raw_response=json.dumps(completion) ) def track_completion_usage(self, completion: Dict) -> CompletionUsage: @@ -184,12 +209,13 @@ def list_models(self): try: response = self._client.list_foundation_models() - models = response.get("models", []) + models = response.get("modelSummaries", []) for model in models: print(f"Model ID: {model['modelId']}") - print(f" Name: {model['name']}") - print(f" Description: {model['description']}") - print(f" Provider: {model['provider']}") + print(f" Name: 
{model['modelName']}") + print(f" Input Modalities: {model['inputModalities']}") + print(f" Output Modalities: {model['outputModalities']}") + print(f" Provider: {model['providerName']}") print("") except Exception as e: print(f"Error listing models: {e}") @@ -222,14 +248,26 @@ def convert_inputs_to_api_kwargs( bedrock_runtime_exceptions.ModelErrorException, bedrock_runtime_exceptions.ValidationException, ), - max_time=5, + max_time=2, ) - def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED): + def call( + self, + api_kwargs: Dict = {}, + model_type: ModelType = ModelType.UNDEFINED, + stream: bool = False + ) -> dict: """ kwargs is the combined input and model_kwargs """ if model_type == ModelType.LLM: - return self.sync_client.converse(**api_kwargs) + if "stream" in api_kwargs and api_kwargs.get("stream", False): + log.debug("Streaming call") + api_kwargs.pop("stream") # stream is not a valid parameter for bedrock + self.chat_completion_parser = self.handle_stream_response + return self.sync_client.converse_stream(**api_kwargs) + else: + api_kwargs.pop("stream") + return self.sync_client.converse(**api_kwargs) else: raise ValueError(f"model_type {model_type} is not supported") From 1abea9e821c693f868277b7f6cdaa5969e178767 Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sun, 12 Jan 2025 12:22:08 +0000 Subject: [PATCH 2/5] Removed stream parser, favouring return of raw stream from generator --- .../components/model_client/bedrock_client.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 479b43ef..8184540d 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -13,6 +13,7 @@ from adalflow.core.model_client import ModelClient from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput +from adalflow.utils import printc from adalflow.utils.lazy_import import safe_import, OptionalPackages @@ -165,19 +166,12 @@ def init_sync_client(self): def init_async_client(self): raise NotImplementedError("Async call not implemented yet.") - @staticmethod - def parse_stream_response(completion: dict) -> str: - if "contentBlockDelta" in completion: - if delta_chunk := completion["contentBlockDelta"]["delta"]: - return delta_chunk["text"] - return '' - def handle_stream_response(self, stream: dict) -> GeneratorType: try: - for chunk in stream["stream"]: + stream: GeneratorType = stream["stream"] + for chunk in stream: log.debug(f"Raw chunk: {chunk}") - parsed_content = self.parse_stream_response(chunk) - yield parsed_content + yield chunk except Exception as e: print(f"Error in handle_stream_response: {e}") # Debug print raise @@ -185,7 +179,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType: def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": """Parse the completion, and put it into the raw_response.""" try: - data = self.handle_stream_response(completion) + data = self.chat_completion_parser(completion) return GeneratorOutput( data=None, error=None, raw_response=data ) @@ -254,7 +248,6 @@ def call( self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED, - stream: bool = False ) -> dict: """ kwargs is the combined input and model_kwargs @@ -262,11 +255,12 @@ def call( if model_type == ModelType.LLM: if "stream" in api_kwargs and api_kwargs.get("stream", 
False): log.debug("Streaming call") + printc("Streaming") api_kwargs.pop("stream") # stream is not a valid parameter for bedrock self.chat_completion_parser = self.handle_stream_response return self.sync_client.converse_stream(**api_kwargs) else: - api_kwargs.pop("stream") + api_kwargs.pop("stream", None) return self.sync_client.converse(**api_kwargs) else: raise ValueError(f"model_type {model_type} is not supported") From f9d46e335f12ef545b49f2ff3fd1870871d37740 Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sun, 12 Jan 2025 12:29:52 +0000 Subject: [PATCH 3/5] Revert list model changes --- .../components/model_client/bedrock_client.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 8184540d..5a3e198a 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -22,7 +22,6 @@ from botocore.config import Config log = logging.getLogger(__name__) -log.level = logging.DEBUG bedrock_runtime_exceptions = boto3.client( service_name="bedrock-runtime", @@ -174,7 +173,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType: yield chunk except Exception as e: print(f"Error in handle_stream_response: {e}") # Debug print - raise + raise from e def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": """Parse the completion, and put it into the raw_response.""" @@ -203,14 +202,14 @@ def list_models(self): try: response = self._client.list_foundation_models() - models = response.get("modelSummaries", []) + models = response.get("models", []) for model in models: print(f"Model ID: {model['modelId']}") - print(f" Name: {model['modelName']}") - print(f" Input Modalities: {model['inputModalities']}") - print(f" Output Modalities: {model['outputModalities']}") - print(f" Provider: {model['providerName']}") + print(f" Name: {model['name']}") + print(f" Description: {model['description']}") + print(f" Provider: {model['provider']}") print("") + except Exception as e: print(f"Error listing models: {e}") @@ -255,7 +254,6 @@ def call( if model_type == ModelType.LLM: if "stream" in api_kwargs and api_kwargs.get("stream", False): log.debug("Streaming call") - printc("Streaming") api_kwargs.pop("stream") # stream is not a valid parameter for bedrock self.chat_completion_parser = self.handle_stream_response return self.sync_client.converse_stream(**api_kwargs) From 61ef29d7a1d138b749a517108d21295d5744d7b3 Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sun, 12 Jan 2025 15:12:50 +0000 Subject: [PATCH 4/5] test: add tests for bedrock client --- .../components/model_client/bedrock_client.py | 45 ++++++++--- adalflow/poetry.lock | 8 +- adalflow/pyproject.toml | 1 + adalflow/tests/test_aws_bedrock_client.py | 80 +++++++++++++++++++ 4 files changed, 118 insertions(+), 16 deletions(-) create mode 100644 adalflow/tests/test_aws_bedrock_client.py diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 5a3e198a..715c2ead 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -1,19 +1,13 @@ """AWS Bedrock ModelClient integration.""" + import json import os -from typing import ( - Dict, - Optional, - Any, - Callable, - Generator as GeneratorType -) +from typing import Dict, Optional, 
Any, Callable, Generator as GeneratorType import backoff import logging from adalflow.core.model_client import ModelClient from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput -from adalflow.utils import printc from adalflow.utils.lazy_import import safe_import, OptionalPackages @@ -166,6 +160,14 @@ def init_async_client(self): raise NotImplementedError("Async call not implemented yet.") def handle_stream_response(self, stream: dict) -> GeneratorType: + r"""Handle the stream response from bedrock. Yield the chunks. + + Args: + stream (dict): The stream response generator from bedrock. + + Returns: + GeneratorType: A generator that yields the chunks from bedrock stream. + """ try: stream: GeneratorType = stream["stream"] for chunk in stream: @@ -173,14 +175,31 @@ def handle_stream_response(self, stream: dict) -> GeneratorType: yield chunk except Exception as e: print(f"Error in handle_stream_response: {e}") # Debug print - raise from e + raise def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": - """Parse the completion, and put it into the raw_response.""" + r"""Parse the completion, and assign it into the raw_response attribute. + + If the completion is a stream, it will be handled by the handle_stream_response + method that returns a Generator. Otherwise, the completion will be parsed using + the get_first_message_content method. + + Args: + completion (dict): The completion response from bedrock API call. + + Returns: + GeneratorOutput: A generator output object with the parsed completion. May + return a generator if the completion is a stream. + """ try: + usage = None + print(completion) data = self.chat_completion_parser(completion) + if not isinstance(data, GeneratorType): + # Streaming completion usage tracking is not implemented. + usage = self.track_completion_usage(completion) return GeneratorOutput( - data=None, error=None, raw_response=data + data=None, error=None, raw_response=data, usage=usage ) except Exception as e: log.error(f"Error parsing the completion: {e}") @@ -254,7 +273,9 @@ def call( if model_type == ModelType.LLM: if "stream" in api_kwargs and api_kwargs.get("stream", False): log.debug("Streaming call") - api_kwargs.pop("stream") # stream is not a valid parameter for bedrock + api_kwargs.pop( + "stream", None + ) # stream is not a valid parameter for bedrock self.chat_completion_parser = self.handle_stream_response return self.sync_client.converse_stream(**api_kwargs) else: diff --git a/adalflow/poetry.lock b/adalflow/poetry.lock index 92d3c9cb..6fd97657 100644 --- a/adalflow/poetry.lock +++ b/adalflow/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -289,7 +289,7 @@ files = [ name = "boto3" version = "1.35.80" description = "The AWS SDK for Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "boto3-1.35.80-py3-none-any.whl", hash = "sha256:21a3b18c3a7fd20e463708fe3fa035983105dc7f3a1c274e1903e1583ab91159"}, @@ -3963,7 +3963,7 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.10.4" description = "An Amazon S3 Transfer Manager" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e"}, @@ -4889,4 +4889,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <4.0" -content-hash = "86d5f192585121c048dae33edff42527cf40f8a0398e1c3e5b60c1c8ab0af363" +content-hash = "31d11e116d89e9120ef07674aa16bff4282c84887584ab121bd6249e06710384" diff --git a/adalflow/pyproject.toml b/adalflow/pyproject.toml index 08947d81..5c36e275 100644 --- a/adalflow/pyproject.toml +++ b/adalflow/pyproject.toml @@ -80,6 +80,7 @@ groq = "^0.9.0" google-generativeai = "^0.7.2" anthropic = "^0.31.1" lancedb = "^0.5.2" +boto3 = "^1.35.19" diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py new file mode 100644 index 00000000..77dc9fba --- /dev/null +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -0,0 +1,80 @@ +import unittest +from unittest.mock import patch, Mock + +# use the openai for mocking standard data types + +from adalflow.core.types import ModelType, GeneratorOutput +from adalflow.components.model_client import BedrockAPIClient + + +def getenv_side_effect(key): + # This dictionary can hold more keys and values as needed + env_vars = { + "AWS_ACCESS_KEY_ID": "fake_api_key", + "AWS_SECRET_ACCESS_KEY": "fake_api_key", + "AWS_REGION_NAME": "fake_api_key", + } + return env_vars.get(key, None) # Returns None if key is not found + + +# modified from test_openai_client.py +class TestBedrockClient(unittest.TestCase): + def setUp(self): + self.client = BedrockAPIClient() + self.mock_response = { + "ResponseMetadata": { + "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Sat, 30 Nov 2024 14:27:44 GMT", + "content-type": "application/json", + "content-length": "273", + "connection": "keep-alive", + "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569", + }, + "RetryAttempts": 0, + }, + "output": { + "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]} + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + "metrics": {"latencyMs": 430}, + } + + self.api_kwargs = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "gpt-3.5-turbo", + } + + @patch.object(BedrockAPIClient, "init_sync_client") + @patch("adalflow.components.model_client.bedrock_client.boto3") + def test_call(self, MockBedrock, mock_init_sync_client): + mock_sync_client = Mock() + MockBedrock.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the client's api: converse + mock_sync_client.converse = Mock(return_value=self.mock_response) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method + result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) + self.assertEqual(result, 
self.mock_response) + + # test parse_chat_completion + output = self.client.parse_chat_completion(completion=self.mock_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.raw_response, "Hello, world!") + self.assertEqual(output.usage.prompt_tokens, 20) + self.assertEqual(output.usage.completion_tokens, 10) + self.assertEqual(output.usage.total_tokens, 30) + + +if __name__ == "__main__": + unittest.main() From 49ac890b93c03d089b378b5dca9b5f0540c01f3e Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sun, 12 Jan 2025 17:08:24 +0000 Subject: [PATCH 5/5] test: add tests for parse_chat_completion method --- .../components/model_client/bedrock_client.py | 3 +- adalflow/tests/test_aws_bedrock_client.py | 95 +++++++++++++++---- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 715c2ead..493147e6 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -174,7 +174,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType: log.debug(f"Raw chunk: {chunk}") yield chunk except Exception as e: - print(f"Error in handle_stream_response: {e}") # Debug print + log.debug(f"Error in handle_stream_response: {e}") # Debug print raise def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": @@ -193,7 +193,6 @@ def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": """ try: usage = None - print(completion) data = self.chat_completion_parser(completion) if not isinstance(data, GeneratorType): # Streaming completion usage tracking is not implemented. diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py index 77dc9fba..1beaa4ee 100644 --- a/adalflow/tests/test_aws_bedrock_client.py +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import Mock # use the openai for mocking standard data types @@ -7,19 +7,8 @@ from adalflow.components.model_client import BedrockAPIClient -def getenv_side_effect(key): - # This dictionary can hold more keys and values as needed - env_vars = { - "AWS_ACCESS_KEY_ID": "fake_api_key", - "AWS_SECRET_ACCESS_KEY": "fake_api_key", - "AWS_REGION_NAME": "fake_api_key", - } - return env_vars.get(key, None) # Returns None if key is not found - - -# modified from test_openai_client.py class TestBedrockClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.client = BedrockAPIClient() self.mock_response = { "ResponseMetadata": { @@ -41,21 +30,34 @@ def setUp(self): "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, "metrics": {"latencyMs": 430}, } + self.mock_stream_response = { + "ResponseMetadata": { + "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Sun, 12 Jan 2025 15:10:00 GMT", + "content-type": "application/vnd.amazon.eventstream", + "transfer-encoding": "chunked", + "connection": "keep-alive", + "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56", + }, + "RetryAttempts": 0, + }, + "stream": iter(()), + } self.api_kwargs = { "messages": [{"role": "user", "content": "Hello"}], "model": "gpt-3.5-turbo", } - @patch.object(BedrockAPIClient, "init_sync_client") - @patch("adalflow.components.model_client.bedrock_client.boto3") - def test_call(self, 
MockBedrock, mock_init_sync_client):
+    def test_call(self) -> None:
+        """Test that the converse method is called correctly."""
         mock_sync_client = Mock()
-        MockBedrock.return_value = mock_sync_client
-        mock_init_sync_client.return_value = mock_sync_client
 
-        # Mock the client's api: converse
+        # Mock the converse API calls.
         mock_sync_client.converse = Mock(return_value=self.mock_response)
+        mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
 
         # Set the sync client
         self.client.sync_client = mock_sync_client
@@ -65,6 +67,7 @@ def test_call(self, MockBedrock, mock_init_sync_client):
 
         # Assertions
         mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
+        mock_sync_client.converse_stream.assert_not_called()
         self.assertEqual(result, self.mock_response)
 
         # test parse_chat_completion
@@ -75,6 +78,60 @@ def test_call(self, MockBedrock, mock_init_sync_client):
         self.assertEqual(output.usage.completion_tokens, 10)
         self.assertEqual(output.usage.total_tokens, 30)
 
+    def test_streaming_call(self) -> None:
+        """Test that a streaming call calls the converse_stream method."""
+        mock_sync_client = Mock()
+
+        # Mock the converse API calls.
+        mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
+        mock_sync_client.converse = Mock(return_value=self.mock_response)
+
+        # Set the sync client.
+        self.client.sync_client = mock_sync_client
+
+        # Call the call method.
+        stream_kwargs = self.api_kwargs | {"stream": True}
+        self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)
+
+        # Assert the streaming call was made.
+        mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
+        mock_sync_client.converse.assert_not_called()
+
+    def test_call_value_error(self) -> None:
+        """Test that a ValueError is raised when an invalid model_type is passed."""
+        mock_sync_client = Mock()
+
+        # Set the sync client
+        self.client.sync_client = mock_sync_client
+
+        # Test that ValueError is raised
+        with self.assertRaises(ValueError):
+            self.client.call(
+                api_kwargs={},
+                model_type=ModelType.UNDEFINED,  # This should trigger ValueError
+            )
+
+    def test_parse_chat_completion(self) -> None:
+        """Test that parse_chat_completion does not track completion usage when
+        the completion is a stream."""
+        mock_track_completion_usage = Mock()
+        self.client.track_completion_usage = mock_track_completion_usage
+
+        self.client.chat_completion_parser = self.client.handle_stream_response
+        generator_output = self.client.parse_chat_completion(self.mock_stream_response)
+
+        mock_track_completion_usage.assert_not_called()
+        assert isinstance(generator_output, GeneratorOutput)
+
+    def test_parse_chat_completion_call_usage(self) -> None:
+        """Test that parse_chat_completion tracks completion usage when not streaming."""
+        mock_track_completion_usage = Mock()
+        self.client.track_completion_usage = mock_track_completion_usage
+        generator_output = self.client.parse_chat_completion(self.mock_response)
+
+        mock_track_completion_usage.assert_called_once()
+        assert isinstance(generator_output, GeneratorOutput)
+
 
 if __name__ == "__main__":
     unittest.main()
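
Taken together, the series leaves the client with this behaviour: passing "stream": True in the api kwargs routes call() to converse_stream, and parse_chat_completion hands back the raw Bedrock event generator in raw_response. Below is a minimal end-to-end sketch of that path; the model id and prompt are illustrative, and it assumes AWS credentials with access to a Bedrock model are configured in the environment:

    from adalflow.components.model_client import BedrockAPIClient
    from adalflow.core.types import ModelType

    client = BedrockAPIClient()  # credentials/region resolved from the environment
    api_kwargs = {
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # illustrative model id
        "messages": [{"role": "user", "content": [{"text": "Hello"}]}],
        "stream": True,  # popped inside call(); selects converse_stream
    }
    response = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)
    output = client.parse_chat_completion(response)

    # raw_response is a generator over Converse stream events; text deltas
    # arrive in contentBlockDelta events.
    for event in output.raw_response:
        if "contentBlockDelta" in event:
            print(event["contentBlockDelta"]["delta"].get("text", ""), end="")

Yielding raw events rather than pre-extracted text (the change in PATCH 2/5) keeps stop reasons and usage metadata available to callers, at the cost of each caller unpacking contentBlockDelta itself.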