From 49ac890b93c03d089b378b5dca9b5f0540c01f3e Mon Sep 17 00:00:00 2001 From: Lloyd Hamilton Date: Sun, 12 Jan 2025 17:08:24 +0000 Subject: [PATCH] test: add tests for parse_chat_completion method --- .../components/model_client/bedrock_client.py | 3 +- adalflow/tests/test_aws_bedrock_client.py | 95 +++++++++++++++---- 2 files changed, 77 insertions(+), 21 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 715c2ead..493147e6 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -174,7 +174,7 @@ def handle_stream_response(self, stream: dict) -> GeneratorType: log.debug(f"Raw chunk: {chunk}") yield chunk except Exception as e: - print(f"Error in handle_stream_response: {e}") # Debug print + log.debug(f"Error in handle_stream_response: {e}") # Debug print raise def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": @@ -193,7 +193,6 @@ def parse_chat_completion(self, completion: dict) -> "GeneratorOutput": """ try: usage = None - print(completion) data = self.chat_completion_parser(completion) if not isinstance(data, GeneratorType): # Streaming completion usage tracking is not implemented. diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py index 77dc9fba..1beaa4ee 100644 --- a/adalflow/tests/test_aws_bedrock_client.py +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import Mock # use the openai for mocking standard data types @@ -7,19 +7,8 @@ from adalflow.components.model_client import BedrockAPIClient -def getenv_side_effect(key): - # This dictionary can hold more keys and values as needed - env_vars = { - "AWS_ACCESS_KEY_ID": "fake_api_key", - "AWS_SECRET_ACCESS_KEY": "fake_api_key", - "AWS_REGION_NAME": "fake_api_key", - } - return env_vars.get(key, None) # Returns None if key is not found - - -# modified from test_openai_client.py class TestBedrockClient(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.client = BedrockAPIClient() self.mock_response = { "ResponseMetadata": { @@ -41,21 +30,34 @@ def setUp(self): "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, "metrics": {"latencyMs": 430}, } + self.mock_stream_response = { + "ResponseMetadata": { + "RequestId": "c76d625e-9fdb-4173-8138-debdd724fc56", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Sun, 12 Jan 2025 15:10:00 GMT", + "content-type": "application/vnd.amazon.eventstream", + "transfer-encoding": "chunked", + "connection": "keep-alive", + "x-amzn-requestid": "c76d625e-9fdb-4173-8138-debdd724fc56", + }, + "RetryAttempts": 0, + }, + "stream": iter(()), + } self.api_kwargs = { "messages": [{"role": "user", "content": "Hello"}], "model": "gpt-3.5-turbo", } - @patch.object(BedrockAPIClient, "init_sync_client") - @patch("adalflow.components.model_client.bedrock_client.boto3") - def test_call(self, MockBedrock, mock_init_sync_client): + def test_call(self) -> None: + """Test that the converse method is called correctly.""" mock_sync_client = Mock() - MockBedrock.return_value = mock_sync_client - mock_init_sync_client.return_value = mock_sync_client - # Mock the client's api: converse + # Mock the converse API calls. mock_sync_client.converse = Mock(return_value=self.mock_response) + mock_sync_client.converse_stream = Mock(return_value=self.mock_response) # Set the sync client self.client.sync_client = mock_sync_client @@ -65,6 +67,7 @@ def test_call(self, MockBedrock, mock_init_sync_client): # Assertions mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) + mock_sync_client.converse_stream.assert_not_called() self.assertEqual(result, self.mock_response) # test parse_chat_completion @@ -75,6 +78,60 @@ def test_call(self, MockBedrock, mock_init_sync_client): self.assertEqual(output.usage.completion_tokens, 10) self.assertEqual(output.usage.total_tokens, 30) + def test_streaming_call(self) -> None: + """Test that a streaming call calls the converse_stream method.""" + mock_sync_client = Mock() + + # Mock the converse API calls. + mock_sync_client.converse_stream = Mock(return_value=self.mock_response) + mock_sync_client.converse = Mock(return_value=self.mock_response) + + # Set the sync client. + self.client.sync_client = mock_sync_client + + # Call the call method. + stream_kwargs = self.api_kwargs | {"stream": True} + self.client.call(api_kwargs=stream_kwargs, model_type=ModelType.LLM) + + # Assert the streaming call was made. + mock_sync_client.converse_stream.assert_called_once_with(**stream_kwargs) + mock_sync_client.converse.assert_not_called() + + def test_call_value_error(self) -> None: + """Test that a ValueError is raised when an invalid model_type is passed.""" + mock_sync_client = Mock() + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Test that ValueError is raised + with self.assertRaises(ValueError): + self.client.call( + api_kwargs={}, + model_type=ModelType.UNDEFINED, # This should trigger ValueError + ) + + def test_parse_chat_completion(self) -> None: + """Test that the parse_chat_completion does not call usage completion when + streaming.""" + mock_track_completion_usage = Mock() + self.client.track_completion_usage = mock_track_completion_usage + + self.client.chat_completion_parser = self.client.handle_stream_response + generator_output = self.client.parse_chat_completion(self.mock_stream_response) + + mock_track_completion_usage.assert_not_called() + assert isinstance(generator_output, GeneratorOutput) + + def test_parse_chat_completion_call_usage(self) -> None: + """Test that the parse_chat_completion calls usage completion when streaming.""" + mock_track_completion_usage = Mock() + self.client.track_completion_usage = mock_track_completion_usage + generator_output = self.client.parse_chat_completion(self.mock_response) + + mock_track_completion_usage.assert_called_once() + assert isinstance(generator_output, GeneratorOutput) + if __name__ == "__main__": unittest.main()