Skip to content

Commit

Permalink
fix: Validate image size for Stable Diffusion 3 on adapter side (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-romanov-o authored Nov 6, 2024
1 parent 404a3c9 commit cf0b629
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 21 deletions.
2 changes: 2 additions & 0 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ async def get_bedrock_adapter(
model,
api_key,
image_to_image_supported=True,
image_width_constraints=(640, 1536),
image_height_constraints=(640, 1536),
)
case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE:
return AmazonAdapter.create(
Expand Down
83 changes: 64 additions & 19 deletions aidial_adapter_bedrock/llm/model/stability/v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, assert_never
from io import BytesIO
from typing import List, Optional, Tuple, assert_never

from aidial_sdk.chat_completion import (
Message,
Expand All @@ -7,6 +8,8 @@
Role,
)
from aidial_sdk.chat_completion.request import ImageURL
from aidial_sdk.exceptions import RequestValidationError
from PIL import Image
from pydantic import BaseModel

from aidial_adapter_bedrock.bedrock import Bedrock
Expand All @@ -27,10 +30,51 @@
from aidial_adapter_bedrock.llm.model.stability.storage import save_to_storage
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.utils.json import remove_nones
from aidial_adapter_bedrock.utils.resource import Resource

SUPPORTED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/webp"]


def _validate_image_size(
    image: Resource,
    width_constraints: Tuple[int, int] | None,
    height_constraints: Tuple[int, int] | None,
) -> None:
    """Check that the image dimensions fall within the given bounds.

    Each constraint is an inclusive ``(min, max)`` pair, or ``None`` to
    leave that dimension unchecked. When both constraints are ``None``
    the image bytes are not decoded at all.

    Raises:
        RequestValidationError: with code ``invalid_argument`` when the
            width or the height lies outside its bounds. Width is
            checked first, so it is the one reported when both fail.
    """
    bounds_by_dimension = (width_constraints, height_constraints)
    if all(bounds is None for bounds in bounds_by_dimension):
        return

    # Only the size is needed; the image handle is released right away.
    with Image.open(BytesIO(image.data)) as picture:
        actual_width, actual_height = picture.size

    checks = (
        ("width", actual_width, width_constraints),
        ("height", actual_height, height_constraints),
    )
    for dimension, actual, bounds in checks:
        if bounds is None:
            continue

        lower, upper = bounds
        if lower <= actual <= upper:
            continue

        error_msg = (
            f"Image {dimension} is {actual}, but should be "
            f"between {lower} and {upper}"
        )
        raise RequestValidationError(
            message=error_msg,
            display_message=error_msg,
            code="invalid_argument",
        )


def _validate_last_message(messages: List[Message]):
    """Return the last message after checking that it comes from the user.

    Raises:
        ValidationError: if ``messages`` is empty, or if the role of the
            last message is not ``Role.USER``.
    """
    if not messages:
        raise ValidationError("No messages provided")

    candidate = messages[-1]
    if candidate.role == Role.USER:
        return candidate

    raise ValidationError("Last message must be from user")


class StabilityV2Response(BaseModel):
seeds: List[int]
images: List[str]
Expand Down Expand Up @@ -74,6 +118,8 @@ class StabilityV2Adapter(ChatCompletionAdapter):
client: Bedrock
storage: Optional[FileStorage]
image_to_image_supported: bool
width_constraints: Tuple[int, int] | None
height_constraints: Tuple[int, int] | None

@classmethod
def create(
Expand All @@ -82,28 +128,23 @@ def create(
model: str,
api_key: str,
image_to_image_supported: bool,
image_width_constraints: Tuple[int, int] | None = None,
image_height_constraints: Tuple[int, int] | None = None,
):
storage: Optional[FileStorage] = create_file_storage(api_key)
return cls(
client=client,
model=model,
storage=storage,
image_to_image_supported=image_to_image_supported,
width_constraints=image_width_constraints,
height_constraints=image_height_constraints,
)

def _validate_last_message(self, messages: List[Message]):
if not messages:
raise ValidationError("No messages provided")

last_message = messages[-1]
if last_message.role != Role.USER:
raise ValidationError("Last message must be from user")
return last_message

async def compute_discarded_messages(
self, params: ModelParameters, messages: List[Message]
) -> DiscardedMessages | None:
self._validate_last_message(messages)
_validate_last_message(messages)
return list(range(len(messages) - 1))

async def chat(
Expand All @@ -115,7 +156,7 @@ async def chat(

text_prompt = None
image_resources: List[DialResource] = []
last_message = self._validate_last_message(messages)
last_message = _validate_last_message(messages)
# Handle text content
match last_message.content:
case str(text):
Expand Down Expand Up @@ -166,26 +207,30 @@ async def chat(
if len(image_resources) > 1:
raise ValidationError("Only one input image is supported")

if self.image_to_image_supported and image_resources:
image_resource = await image_resources[0].download(self.storage)
_validate_image_size(
image_resource, self.width_constraints, self.height_constraints
)
else:
image_resource = None

response, _ = await self.client.ainvoke_non_streaming(
self.model,
remove_nones(
{
"prompt": text_prompt,
"image": (
(
await image_resources[0].download(self.storage)
).data_base64
if image_resources
else None
image_resource.data_base64 if image_resource else None
),
"mode": (
"image-to-image" if image_resources else "text-to-image"
"image-to-image" if image_resource else "text-to-image"
),
"output_format": "png",
# This parameter controls how much the input image affects generation, on a scale from 0 to 1,
# where 0 means that the output will be identical to the input image and 1 means that the model will ignore the input image.
# Since there is no recommended default value, we use 0.5 as a middle ground.
"strength": 0.5 if image_resources else None,
"strength": 0.5 if image_resource else None,
}
),
)
Expand Down
Binary file modified tests/integration_tests/images/dog-sample-image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 6 additions & 2 deletions tests/integration_tests/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ async def test_image_to_image_with_too_small_picture(
model=deployment.value,
messages=[user_with_image_content_part("test", BLUE_PNG_PICTURE)],
)
assert exc_info.value.status_code == 400
assert "width must be between 640 and 1536" in exc_info.value.message

assert exc_info.value.status_code == 422
assert (
"Image width is 3, but should be between 640 and 1536"
in exc_info.value.message
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit cf0b629

Please sign in to comment.