Skip to content

Commit

Permalink
test: add tests for bedrock client
Browse files Browse the repository at this point in the history
  • Loading branch information
Lloyd Hamilton committed Jan 12, 2025
1 parent f9d46e3 commit 61ef29d
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 16 deletions.
45 changes: 33 additions & 12 deletions adalflow/adalflow/components/model_client/bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
"""AWS Bedrock ModelClient integration."""

import json
import os
from typing import (
Dict,
Optional,
Any,
Callable,
Generator as GeneratorType
)
from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
import backoff
import logging

from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
from adalflow.utils import printc

from adalflow.utils.lazy_import import safe_import, OptionalPackages

Expand Down Expand Up @@ -166,21 +160,46 @@ def init_async_client(self):
raise NotImplementedError("Async call not implemented yet.")

def handle_stream_response(self, stream: dict) -> GeneratorType:
r"""Handle the stream response from bedrock. Yield the chunks.
Args:
stream (dict): The stream response generator from bedrock.
Returns:
GeneratorType: A generator that yields the chunks from bedrock stream.
"""
try:
stream: GeneratorType = stream["stream"]
for chunk in stream:
log.debug(f"Raw chunk: {chunk}")
yield chunk
except Exception as e:
print(f"Error in handle_stream_response: {e}") # Debug print
raise from e
raise

def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
"""Parse the completion, and put it into the raw_response."""
r"""Parse the completion, and assign it into the raw_response attribute.
If the completion is a stream, it will be handled by the handle_stream_response
method that returns a Generator. Otherwise, the completion will be parsed using
the get_first_message_content method.
Args:
completion (dict): The completion response from bedrock API call.
Returns:
GeneratorOutput: A generator output object with the parsed completion. May
return a generator if the completion is a stream.
"""
try:
usage = None
print(completion)
data = self.chat_completion_parser(completion)
if not isinstance(data, GeneratorType):
# Streaming completion usage tracking is not implemented.
usage = self.track_completion_usage(completion)
return GeneratorOutput(
data=None, error=None, raw_response=data
data=None, error=None, raw_response=data, usage=usage
)
except Exception as e:
log.error(f"Error parsing the completion: {e}")
Expand Down Expand Up @@ -254,7 +273,9 @@ def call(
if model_type == ModelType.LLM:
if "stream" in api_kwargs and api_kwargs.get("stream", False):
log.debug("Streaming call")
api_kwargs.pop("stream") # stream is not a valid parameter for bedrock
api_kwargs.pop(
"stream", None
) # stream is not a valid parameter for bedrock
self.chat_completion_parser = self.handle_stream_response
return self.sync_client.converse_stream(**api_kwargs)
else:
Expand Down
8 changes: 4 additions & 4 deletions adalflow/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions adalflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ groq = "^0.9.0"
google-generativeai = "^0.7.2"
anthropic = "^0.31.1"
lancedb = "^0.5.2"
boto3 = "^1.35.19"



Expand Down
80 changes: 80 additions & 0 deletions adalflow/tests/test_aws_bedrock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import unittest
from unittest.mock import patch, Mock

# use the openai for mocking standard data types

from adalflow.core.types import ModelType, GeneratorOutput
from adalflow.components.model_client import BedrockAPIClient


def getenv_side_effect(key):
# This dictionary can hold more keys and values as needed
env_vars = {
"AWS_ACCESS_KEY_ID": "fake_api_key",
"AWS_SECRET_ACCESS_KEY": "fake_api_key",
"AWS_REGION_NAME": "fake_api_key",
}
return env_vars.get(key, None) # Returns None if key is not found


# modified from test_openai_client.py
class TestBedrockClient(unittest.TestCase):
def setUp(self):
self.client = BedrockAPIClient()
self.mock_response = {
"ResponseMetadata": {
"RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
"HTTPStatusCode": 200,
"HTTPHeaders": {
"date": "Sat, 30 Nov 2024 14:27:44 GMT",
"content-type": "application/json",
"content-length": "273",
"connection": "keep-alive",
"x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
},
"RetryAttempts": 0,
},
"output": {
"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
"metrics": {"latencyMs": 430},
}

self.api_kwargs = {
"messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-3.5-turbo",
}

@patch.object(BedrockAPIClient, "init_sync_client")
@patch("adalflow.components.model_client.bedrock_client.boto3")
def test_call(self, MockBedrock, mock_init_sync_client):
mock_sync_client = Mock()
MockBedrock.return_value = mock_sync_client
mock_init_sync_client.return_value = mock_sync_client

# Mock the client's api: converse
mock_sync_client.converse = Mock(return_value=self.mock_response)

# Set the sync client
self.client.sync_client = mock_sync_client

# Call the call method
result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)

# Assertions
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
self.assertEqual(result, self.mock_response)

# test parse_chat_completion
output = self.client.parse_chat_completion(completion=self.mock_response)
self.assertTrue(isinstance(output, GeneratorOutput))
self.assertEqual(output.raw_response, "Hello, world!")
self.assertEqual(output.usage.prompt_tokens, 20)
self.assertEqual(output.usage.completion_tokens, 10)
self.assertEqual(output.usage.total_tokens, 30)


if __name__ == "__main__":
unittest.main()

0 comments on commit 61ef29d

Please sign in to comment.