Validate input image dimensions for image-to-image requests in the
Stability v2 adapter.

Summary of the hunks below:
  * adapter.py: pass (640, 1536) width and height constraints to one
    adapter factory call that sets image_to_image_supported=True.
  * stability/v2.py: add module-level _validate_image_size(), move
    _validate_last_message() from a method to a module-level function,
    and download the input image once before invoking the model so its
    size can be checked first.
  * tests: the too-small-picture test now expects status 422 and the new
    error message; the dog sample image is replaced.

NOTE(review): the line structure of this patch was reconstructed from a
whitespace-collapsed copy. Hunk line counts match the @@ headers, but
leading indentation was inferred -- confirm against the target tree
before applying (or apply with --ignore-whitespace).

diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py
index 2a52e45..dcc39b7 100644
--- a/aidial_adapter_bedrock/llm/model/adapter.py
+++ b/aidial_adapter_bedrock/llm/model/adapter.py
@@ -98,6 +98,8 @@ async def get_bedrock_adapter(
                 model,
                 api_key,
                 image_to_image_supported=True,
+                image_width_constraints=(640, 1536),
+                image_height_constraints=(640, 1536),
             )
         case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE:
             return AmazonAdapter.create(
diff --git a/aidial_adapter_bedrock/llm/model/stability/v2.py b/aidial_adapter_bedrock/llm/model/stability/v2.py
index dd8787c..ca33130 100644
--- a/aidial_adapter_bedrock/llm/model/stability/v2.py
+++ b/aidial_adapter_bedrock/llm/model/stability/v2.py
@@ -1,4 +1,5 @@
-from typing import List, Optional, assert_never
+from io import BytesIO
+from typing import List, Optional, Tuple, assert_never
 
 from aidial_sdk.chat_completion import (
     Message,
@@ -7,6 +8,8 @@
     Role,
 )
 from aidial_sdk.chat_completion.request import ImageURL
+from aidial_sdk.exceptions import RequestValidationError
+from PIL import Image
 from pydantic import BaseModel
 
 from aidial_adapter_bedrock.bedrock import Bedrock
@@ -27,10 +30,51 @@
 from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
 from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
 from aidial_adapter_bedrock.utils.json import remove_nones
+from aidial_adapter_bedrock.utils.resource import Resource
 
 SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]
 
 
+def _validate_image_size(
+    image: Resource,
+    width_constraints: Tuple[int, int] | None,
+    height_constraints: Tuple[int, int] | None,
+) -> None:
+    if width_constraints is None and height_constraints is None:
+        return
+
+    with Image.open(BytesIO(image.data)) as img:
+        width, height = img.size
+
+    for constraints, value, name in [
+        (width_constraints, width, "width"),
+        (height_constraints, height, "height"),
+    ]:
+        if constraints is None:
+            continue
+        min_value, max_value = constraints
+        if not (min_value <= value <= max_value):
+            error_msg = (
+                f"Image {name} is {value}, but should be "
+                f"between {min_value} and {max_value}"
+            )
+            raise RequestValidationError(
+                message=error_msg,
+                display_message=error_msg,
+                code="invalid_argument",
+            )
+
+
+def _validate_last_message(messages: List[Message]):
+    if not messages:
+        raise ValidationError("No messages provided")
+
+    last_message = messages[-1]
+    if last_message.role != Role.USER:
+        raise ValidationError("Last message must be from user")
+    return last_message
+
+
 class StabilityV2Response(BaseModel):
     seeds: List[int]
     images: List[str]
@@ -74,6 +118,8 @@ class StabilityV2Adapter(ChatCompletionAdapter):
     client: Bedrock
     storage: Optional[FileStorage]
     image_to_image_supported: bool
+    width_constraints: Tuple[int, int] | None
+    height_constraints: Tuple[int, int] | None
 
     @classmethod
     def create(
@@ -82,6 +128,8 @@ def create(
         model: str,
         api_key: str,
         image_to_image_supported: bool,
+        image_width_constraints: Tuple[int, int] | None = None,
+        image_height_constraints: Tuple[int, int] | None = None,
     ):
         storage: Optional[FileStorage] = create_file_storage(api_key)
         return cls(
@@ -89,21 +137,14 @@ def create(
             model=model,
             storage=storage,
             image_to_image_supported=image_to_image_supported,
+            width_constraints=image_width_constraints,
+            height_constraints=image_height_constraints,
         )
 
-    def _validate_last_message(self, messages: List[Message]):
-        if not messages:
-            raise ValidationError("No messages provided")
-
-        last_message = messages[-1]
-        if last_message.role != Role.USER:
-            raise ValidationError("Last message must be from user")
-        return last_message
-
     async def compute_discarded_messages(
         self, params: ModelParameters, messages: List[Message]
     ) -> DiscardedMessages | None:
-        self._validate_last_message(messages)
+        _validate_last_message(messages)
         return list(range(len(messages) - 1))
 
     async def chat(
@@ -115,7 +156,7 @@ async def chat(
         text_prompt = None
         image_resources: List[DialResource] = []
 
-        last_message = self._validate_last_message(messages)
+        last_message = _validate_last_message(messages)
         # Handle text content
         match last_message.content:
             case str(text):
@@ -166,26 +207,30 @@ async def chat(
         if len(image_resources) > 1:
             raise ValidationError("Only one input image is supported")
 
+        if self.image_to_image_supported and image_resources:
+            image_resource = await image_resources[0].download(self.storage)
+            _validate_image_size(
+                image_resource, self.width_constraints, self.height_constraints
+            )
+        else:
+            image_resource = None
+
         response, _ = await self.client.ainvoke_non_streaming(
             self.model,
             remove_nones(
                 {
                     "prompt": text_prompt,
                     "image": (
-                        (
-                            await image_resources[0].download(self.storage)
-                        ).data_base64
-                        if image_resources
-                        else None
+                        image_resource.data_base64 if image_resource else None
                     ),
                     "mode": (
-                        "image-to-image" if image_resources else "text-to-image"
+                        "image-to-image" if image_resource else "text-to-image"
                     ),
                     "output_format": "png",
                     # This parameter controls how much input image will affect generation from 0 to 1,
                     # where 0 means that output will be identical to input image and 1 means that model will ignore input image
                     # Since there is no recommended default value, we use 0.5 as a middle ground
-                    "strength": 0.5 if image_resources else None,
+                    "strength": 0.5 if image_resource else None,
                 }
             ),
         )
diff --git a/tests/integration_tests/images/dog-sample-image.png b/tests/integration_tests/images/dog-sample-image.png
index 8428942..3a6b953 100644
Binary files a/tests/integration_tests/images/dog-sample-image.png and b/tests/integration_tests/images/dog-sample-image.png differ
diff --git a/tests/integration_tests/test_stable_diffusion.py b/tests/integration_tests/test_stable_diffusion.py
index 955b8d4..dd412d3 100644
--- a/tests/integration_tests/test_stable_diffusion.py
+++ b/tests/integration_tests/test_stable_diffusion.py
@@ -155,8 +155,12 @@ async def test_image_to_image_with_too_small_picture(
             model=deployment.value,
             messages=[user_with_image_content_part("test", BLUE_PNG_PICTURE)],
         )
-    assert exc_info.value.status_code == 400
-    assert "width must be between 640 and 1536" in exc_info.value.message
+
+    assert exc_info.value.status_code == 422
+    assert (
+        "Image width is 3, but should be between 640 and 1536"
+        in exc_info.value.message
+    )
 
 
 @pytest.mark.parametrize(