diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py
index 5a3e198a..715c2ead 100644
--- a/adalflow/adalflow/components/model_client/bedrock_client.py
+++ b/adalflow/adalflow/components/model_client/bedrock_client.py
@@ -1,19 +1,13 @@
 """AWS Bedrock ModelClient integration."""
+
 import json
 import os
-from typing import (
-    Dict,
-    Optional,
-    Any,
-    Callable,
-    Generator as GeneratorType
-)
+from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
 import backoff
 import logging
 
 from adalflow.core.model_client import ModelClient
 from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
-from adalflow.utils import printc
 
 from adalflow.utils.lazy_import import safe_import, OptionalPackages
 
@@ -166,6 +160,14 @@ def init_async_client(self):
         raise NotImplementedError("Async call not implemented yet.")
 
     def handle_stream_response(self, stream: dict) -> GeneratorType:
+        r"""Handle the stream response from bedrock. Yield the chunks.
+
+        Args:
+            stream (dict): The stream response generator from bedrock.
+
+        Returns:
+            GeneratorType: A generator that yields the chunks from bedrock stream.
+        """
         try:
             stream: GeneratorType = stream["stream"]
             for chunk in stream:
@@ -173,14 +175,30 @@ def handle_stream_response(self, stream: dict) -> GeneratorType:
                 yield chunk
         except Exception as e:
-            print(f"Error in handle_stream_response: {e}")  # Debug print
-            raise from e
+            log.error(f"Error in handle_stream_response: {e}")
+            raise
 
     def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
-        """Parse the completion, and put it into the raw_response."""
+        r"""Parse the completion, and assign it into the raw_response attribute.
+
+        If the completion is a stream, it will be handled by the handle_stream_response
+        method that returns a Generator. Otherwise, the completion will be parsed using
+        the get_first_message_content method.
+
+        Args:
+            completion (dict): The completion response from bedrock API call.
+
+        Returns:
+            GeneratorOutput: A generator output object with the parsed completion. May
+                return a generator if the completion is a stream.
+        """
         try:
+            usage = None
             data = self.chat_completion_parser(completion)
+            if not isinstance(data, GeneratorType):
+                # Streaming completion usage tracking is not implemented.
+                usage = self.track_completion_usage(completion)
             return GeneratorOutput(
-                data=None, error=None, raw_response=data
+                data=None, error=None, raw_response=data, usage=usage
             )
         except Exception as e:
             log.error(f"Error parsing the completion: {e}")
@@ -254,7 +272,9 @@ def call(
         if model_type == ModelType.LLM:
             if "stream" in api_kwargs and api_kwargs.get("stream", False):
                 log.debug("Streaming call")
-                api_kwargs.pop("stream")  # stream is not a valid parameter for bedrock
+                api_kwargs.pop(
+                    "stream", None
+                )  # stream is not a valid parameter for bedrock
                 self.chat_completion_parser = self.handle_stream_response
                 return self.sync_client.converse_stream(**api_kwargs)
             else:
diff --git a/adalflow/poetry.lock b/adalflow/poetry.lock
index 92d3c9cb..6fd97657 100644
--- a/adalflow/poetry.lock
+++ b/adalflow/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
[[package]] name = "absl-py" @@ -289,7 +289,7 @@ files = [ name = "boto3" version = "1.35.80" description = "The AWS SDK for Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "boto3-1.35.80-py3-none-any.whl", hash = "sha256:21a3b18c3a7fd20e463708fe3fa035983105dc7f3a1c274e1903e1583ab91159"}, @@ -3963,7 +3963,7 @@ pyasn1 = ">=0.1.3" name = "s3transfer" version = "0.10.4" description = "An Amazon S3 Transfer Manager" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e"}, @@ -4889,4 +4889,4 @@ torch = ["torch"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <4.0" -content-hash = "86d5f192585121c048dae33edff42527cf40f8a0398e1c3e5b60c1c8ab0af363" +content-hash = "31d11e116d89e9120ef07674aa16bff4282c84887584ab121bd6249e06710384" diff --git a/adalflow/pyproject.toml b/adalflow/pyproject.toml index 08947d81..5c36e275 100644 --- a/adalflow/pyproject.toml +++ b/adalflow/pyproject.toml @@ -80,6 +80,7 @@ groq = "^0.9.0" google-generativeai = "^0.7.2" anthropic = "^0.31.1" lancedb = "^0.5.2" +boto3 = "^1.35.19" diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py new file mode 100644 index 00000000..77dc9fba --- /dev/null +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -0,0 +1,80 @@ +import unittest +from unittest.mock import patch, Mock + +# use the openai for mocking standard data types + +from adalflow.core.types import ModelType, GeneratorOutput +from adalflow.components.model_client import BedrockAPIClient + + +def getenv_side_effect(key): + # This dictionary can hold more keys and values as needed + env_vars = { + "AWS_ACCESS_KEY_ID": "fake_api_key", + "AWS_SECRET_ACCESS_KEY": "fake_api_key", + "AWS_REGION_NAME": "fake_api_key", + } + return env_vars.get(key, None) # Returns None if key is not found + + +# modified from 
test_openai_client.py +class TestBedrockClient(unittest.TestCase): + def setUp(self): + self.client = BedrockAPIClient() + self.mock_response = { + "ResponseMetadata": { + "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Sat, 30 Nov 2024 14:27:44 GMT", + "content-type": "application/json", + "content-length": "273", + "connection": "keep-alive", + "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569", + }, + "RetryAttempts": 0, + }, + "output": { + "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]} + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + "metrics": {"latencyMs": 430}, + } + + self.api_kwargs = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "gpt-3.5-turbo", + } + + @patch.object(BedrockAPIClient, "init_sync_client") + @patch("adalflow.components.model_client.bedrock_client.boto3") + def test_call(self, MockBedrock, mock_init_sync_client): + mock_sync_client = Mock() + MockBedrock.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the client's api: converse + mock_sync_client.converse = Mock(return_value=self.mock_response) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method + result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + # test parse_chat_completion + output = self.client.parse_chat_completion(completion=self.mock_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.raw_response, "Hello, world!") + self.assertEqual(output.usage.prompt_tokens, 20) + self.assertEqual(output.usage.completion_tokens, 10) + self.assertEqual(output.usage.total_tokens, 30) + + +if __name__ == "__main__": + unittest.main()