Skip to content

Commit

Permalink
test: add tests for parse_chat_completion method
Browse files Browse the repository at this point in the history
  • Loading branch information
Lloyd Hamilton committed Jan 12, 2025
1 parent 61ef29d commit 49ac890
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 21 deletions.
3 changes: 1 addition & 2 deletions adalflow/adalflow/components/model_client/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType:
log.debug(f"Raw chunk: {chunk}")
yield chunk
except Exception as e:
print(f"Error in handle_stream_response: {e}") # Debug print
log.debug(f"Error in handle_stream_response: {e}") # Debug print
raise

def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
Expand All @@ -193,7 +193,6 @@ def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
"""
try:
usage = None
print(completion)
data = self.chat_completion_parser(completion)
if not isinstance(data, GeneratorType):
# Streaming completion usage tracking is not implemented.
Expand Down
95 changes: 76 additions & 19 deletions adalflow/tests/test_aws_bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
import unittest
from unittest.mock import patch, Mock
from unittest.mock import Mock

# use the openai for mocking standard data types

from adalflow.core.types import ModelType, GeneratorOutput
from adalflow.components.model_client import BedrockAPIClient


def getenv_side_effect(key):
# This dictionary can hold more keys and values as needed
env_vars = {
"AWS_ACCESS_KEY_ID": "fake_api_key",
"AWS_SECRET_ACCESS_KEY": "fake_api_key",
"AWS_REGION_NAME": "fake_api_key",
}
return env_vars.get(key, None) # Returns None if key is not found


# modified from test_openai_client.py
class TestBedrockClient(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self.client = BedrockAPIClient()
self.mock_response = {
"ResponseMetadata": {
Expand All @@ -41,21 +30,34 @@ def setUp(self):
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
"metrics": {"latencyMs": 430},
}
self.mock_stream_response = {
"ResponseMetadata": {
"RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56",
"HTTPStatusCode": 200,
"HTTPHeaders": {
"date": "Sun, 12 Jan 2025 15:10:00 GMT",
"content-type": "application/vnd.amazon.eventstream",
"transfer-encoding": "chunked",
"connection": "keep-alive",
"x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56",
},
"RetryAttempts": 0,
},
"stream": iter(()),
}

self.api_kwargs = {
"messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-3.5-turbo",
}

@patch.object(BedrockAPIClient, "init_sync_client")
@patch("adalflow.components.model_client.bedrock_client.boto3")
def test_call(self, MockBedrock, mock_init_sync_client):
def test_call(self) -> None:
"""Test that the converse method is called correctly."""
mock_sync_client = Mock()
MockBedrock.return_value = mock_sync_client
mock_init_sync_client.return_value = mock_sync_client

# Mock the client's api: converse
# Mock the converse API calls.
mock_sync_client.converse = Mock(return_value=self.mock_response)
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)

# Set the sync client
self.client.sync_client = mock_sync_client
Expand All @@ -65,6 +67,7 @@ def test_call(self, MockBedrock, mock_init_sync_client):

# Assertions
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
mock_sync_client.converse_stream.assert_not_called()
self.assertEqual(result, self.mock_response)

# test parse_chat_completion
Expand All @@ -75,6 +78,60 @@ def test_call(self, MockBedrock, mock_init_sync_client):
self.assertEqual(output.usage.completion_tokens, 10)
self.assertEqual(output.usage.total_tokens, 30)

def test_streaming_call(self) -> None:
"""Test that a streaming call calls the converse_stream method."""
mock_sync_client = Mock()

# Mock the converse API calls.
mock_sync_client.converse_stream = Mock(return_value=self.mock_response)
mock_sync_client.converse = Mock(return_value=self.mock_response)

# Set the sync client.
self.client.sync_client = mock_sync_client

# Call the call method.
stream_kwargs = self.api_kwargs | {"stream": True}
self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM)

# Assert the streaming call was made.
mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs)
mock_sync_client.converse.assert_not_called()

def test_call_value_error(self) -> None:
"""Test that a ValueError is raised when an invalid model_type is passed."""
mock_sync_client = Mock()

# Set the sync client
self.client.sync_client = mock_sync_client

# Test that ValueError is raised
with self.assertRaises(ValueError):
self.client.call(
api_kwargs={},
model_type=ModelType.UNDEFINED, # This should trigger ValueError
)

def test_parse_chat_completion(self) -> None:
"""Test that the parse_chat_completion does not call usage completion when
streaming."""
mock_track_completion_usage = Mock()
self.client.track_completion_usage = mock_track_completion_usage

self.client.chat_completion_parser = self.client.handle_stream_response
generator_output = self.client.parse_chat_completion(self.mock_stream_response)

mock_track_completion_usage.assert_not_called()
assert isinstance(generator_output, GeneratorOutput)

def test_parse_chat_completion_call_usage(self) -> None:
"""Test that the parse_chat_completion calls usage completion when streaming."""
mock_track_completion_usage = Mock()
self.client.track_completion_usage = mock_track_completion_usage
generator_output = self.client.parse_chat_completion(self.mock_response)

mock_track_completion_usage.assert_called_once()
assert isinstance(generator_output, GeneratorOutput)


if __name__ == "__main__":
unittest.main()

0 comments on commit 49ac890

Please sign in to comment.