add image gen
fm1320 committed Jan 9, 2025
1 parent 852c212 commit ff1060a
Showing 3 changed files with 204 additions and 3 deletions.
80 changes: 80 additions & 0 deletions adalflow/adalflow/components/model_client/openai_client.py
@@ -36,6 +36,7 @@
from openai.types import (
    Completion,
    CreateEmbeddingResponse,
    Image,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletion

@@ -114,6 +115,14 @@ class OpenAIClient(ModelClient):
    For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or a list of them.
    The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini).

    For image generation, use model_type=ModelType.IMAGE_GENERATION and provide:
        - model: "dall-e-3" or "dall-e-2"
        - prompt: Text description of the image to generate
        - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
        - quality: "standard" or "hd" (DALL-E 3 only)
        - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
        - response_format: "url" or "b64_json"

    Args:
        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a str. Defaults to None.
@@ -123,6 +132,7 @@ class OpenAIClient(ModelClient):
        - Embeddings models: https://platform.openai.com/docs/guides/embeddings
        - Chat models: https://platform.openai.com/docs/guides/text-generation
        - Vision models: https://platform.openai.com/docs/guides/vision
        - Image models: https://platform.openai.com/docs/guides/images
        - OpenAI docs: https://platform.openai.com/docs/introduction
    """

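For orientation, here is a minimal end-to-end sketch of the generation path this commit adds. It uses only methods visible in this diff (convert_inputs_to_api_kwargs, call, parse_image_generation_response); the prompt is illustrative and an OPENAI_API_KEY in the environment is assumed:

from adalflow.components.model_client.openai_client import OpenAIClient
from adalflow.core.types import ModelType

client = OpenAIClient()  # assumes OPENAI_API_KEY is set in the environment

# convert_inputs_to_api_kwargs fills in the DALL-E 3 defaults shown above
# (size "1024x1024", quality "standard", n=1, response_format "url").
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="a white siamese cat",  # illustrative prompt
    model_kwargs={"model": "dall-e-3"},
    model_type=ModelType.IMAGE_GENERATION,
)

# call() returns response.data, a list of openai.types.Image objects.
images = client.call(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)

# A single image is unwrapped from the list; with response_format="url",
# output.data is the image URL.
output = client.parse_image_generation_response(images)
print(output.data)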
@@ -292,10 +302,54 @@ def convert_inputs_to_api_kwargs(
            else:
                messages.append({"role": "system", "content": input})
            final_model_kwargs["messages"] = messages
        elif model_type == ModelType.IMAGE_GENERATION:
            # For image generation, input is the prompt
            final_model_kwargs["prompt"] = input
            # Set defaults for DALL-E 3 if not specified
            if "model" not in final_model_kwargs:
                final_model_kwargs["model"] = "dall-e-3"
            if "size" not in final_model_kwargs:
                final_model_kwargs["size"] = "1024x1024"
            if "quality" not in final_model_kwargs:
                final_model_kwargs["quality"] = "standard"
            if "n" not in final_model_kwargs:
                final_model_kwargs["n"] = 1
            if "response_format" not in final_model_kwargs:
                final_model_kwargs["response_format"] = "url"

            # Handle image edits and variations
            if "image" in final_model_kwargs:
                if isinstance(final_model_kwargs["image"], str):
                    # If it's a file path, encode it
                    if os.path.isfile(final_model_kwargs["image"]):
                        final_model_kwargs["image"] = self._encode_image(final_model_kwargs["image"])
                if "mask" in final_model_kwargs and isinstance(final_model_kwargs["mask"], str):
                    if os.path.isfile(final_model_kwargs["mask"]):
                        final_model_kwargs["mask"] = self._encode_image(final_model_kwargs["mask"])
        else:
            raise ValueError(f"model_type {model_type} is not supported")
        return final_model_kwargs

    def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
        """Parse the image generation response into a GeneratorOutput."""
        try:
            # Extract URLs or base64 data from the response
            data = [img.url or img.b64_json for img in response]
            # For single image responses, unwrap from list
            if len(data) == 1:
                data = data[0]
            return GeneratorOutput(
                data=data,
                raw_response=str(response),
            )
        except Exception as e:
            log.error(f"Error parsing image generation response: {e}")
            return GeneratorOutput(
                data=None,
                error=str(e),
                raw_response=str(response),
            )

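Continuing the sketch above: if response_format="b64_json" is requested instead of the default "url", output.data carries the base64 payload (a list when n > 1). A hedged sketch of decoding it, with a hypothetical output path:

import base64

output = client.parse_image_generation_response(images)
if output.error is None and isinstance(output.data, str):
    with open("generated.png", "wb") as f:  # hypothetical output path
        f.write(base64.b64decode(output.data))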
    @backoff.on_exception(
        backoff.expo,
        (
@@ -320,6 +374,19 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
                self.chat_completion_parser = handle_streaming_response
                return self.sync_client.chat.completions.create(**api_kwargs)
            return self.sync_client.chat.completions.create(**api_kwargs)
        elif model_type == ModelType.IMAGE_GENERATION:
            # Determine which image API to call based on the presence of image/mask
            if "image" in api_kwargs:
                if "mask" in api_kwargs:
                    # Image edit
                    response = self.sync_client.images.edit(**api_kwargs)
                else:
                    # Image variation
                    response = self.sync_client.images.create_variation(**api_kwargs)
            else:
                # Image generation
                response = self.sync_client.images.generate(**api_kwargs)
            return response.data
        else:
            raise ValueError(f"model_type {model_type} is not supported")

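To illustrate the routing in call() above, reusing the client from the earlier sketch: the presence of "image" (and optionally "mask") in the kwargs selects edits or variations instead of plain generation. The file names here are hypothetical:

edit_kwargs = client.convert_inputs_to_api_kwargs(
    input="add a red knitted hat",  # edit instruction, used as the prompt
    model_kwargs={
        "model": "dall-e-2",     # edits and variations are DALL-E 2 features
        "image": "cat.png",      # hypothetical local file; base64-encoded on conversion
        "mask": "cat_mask.png",  # hypothetical mask file
    },
    model_type=ModelType.IMAGE_GENERATION,
)
# "image" plus "mask" routes to images.edit; "image" alone routes to
# images.create_variation; neither routes to images.generate.
images = client.call(api_kwargs=edit_kwargs, model_type=ModelType.IMAGE_GENERATION)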
@@ -346,6 +413,19 @@ async def acall(
            return await self.async_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            return await self.async_client.chat.completions.create(**api_kwargs)
        elif model_type == ModelType.IMAGE_GENERATION:
            # Determine which image API to call based on the presence of image/mask
            if "image" in api_kwargs:
                if "mask" in api_kwargs:
                    # Image edit
                    response = await self.async_client.images.edit(**api_kwargs)
                else:
                    # Image variation
                    response = await self.async_client.images.create_variation(**api_kwargs)
            else:
                # Image generation
                response = await self.async_client.images.generate(**api_kwargs)
            return response.data
        else:
            raise ValueError(f"model_type {model_type} is not supported")

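The async path is a direct mirror; a minimal sketch assuming only what this diff shows about acall():

import asyncio

from adalflow.components.model_client.openai_client import OpenAIClient
from adalflow.core.types import ModelType

async def main() -> None:
    client = OpenAIClient()
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="a watercolor city skyline at dusk",  # illustrative prompt
        model_kwargs={"model": "dall-e-3"},
        model_type=ModelType.IMAGE_GENERATION,
    )
    # acall() mirrors call() and returns response.data.
    images = await client.acall(
        api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION
    )
    print(client.parse_image_generation_response(images).data)

asyncio.run(main())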
1 change: 1 addition & 0 deletions adalflow/adalflow/core/types.py
@@ -58,6 +58,7 @@ class ModelType(Enum):
    EMBEDDER = auto()
    LLM = auto()
    RERANKER = auto()  # ranking model
    IMAGE_GENERATION = auto()  # image generation models like DALL-E
    UNDEFINED = auto()


126 changes: 123 additions & 3 deletions adalflow/tests/test_openai_client.py
@@ -3,7 +3,7 @@
import os
import base64

-from openai.types import CompletionUsage
+from openai.types import CompletionUsage, Image
from openai.types.chat import ChatCompletion

from adalflow.core.types import ModelType, GeneratorOutput
@@ -23,7 +23,7 @@ def setUp(self):
"id": "cmpl-3Q8Z5J9Z1Z5z5",
"created": 1635820005,
"object": "chat.completion",
"model": "gpt-3.5-turbo",
"model": "gpt-4o",
"choices": [
{
"message": {
@@ -59,9 +59,17 @@ def setUp(self):
            ),
        }
        self.mock_vision_response = ChatCompletion(**self.mock_vision_response)
        self.mock_image_response = [
            Image(
                url="https://example.com/generated_image.jpg",
                b64_json=None,
                revised_prompt="A white siamese cat sitting elegantly",
                model="dall-e-3",
            )
        ]
        self.api_kwargs = {
            "messages": [{"role": "user", "content": "Hello"}],
-            "model": "gpt-3.5-turbo",
+            "model": "gpt-4o",
        }
        self.vision_api_kwargs = {
            "messages": [
@@ -81,6 +89,13 @@ def setUp(self):
            ],
            "model": "gpt-4o",
        }
        self.image_generation_kwargs = {
            "model": "dall-e-3",
            "prompt": "a white siamese cat",
            "size": "1024x1024",
            "quality": "standard",
            "n": 1,
        }

    def test_encode_image(self):
        # Create a temporary test image file
@@ -297,6 +312,111 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
        self.assertEqual(output.usage.prompt_tokens, 25)
        self.assertEqual(output.usage.total_tokens, 40)

    def test_convert_inputs_to_api_kwargs_for_image_generation(self):
        # Test basic image generation
        result = self.client.convert_inputs_to_api_kwargs(
            input="a white siamese cat",
            model_kwargs={"model": "dall-e-3"},
            model_type=ModelType.IMAGE_GENERATION,
        )
        self.assertEqual(result["prompt"], "a white siamese cat")
        self.assertEqual(result["model"], "dall-e-3")
        self.assertEqual(result["size"], "1024x1024")  # default
        self.assertEqual(result["quality"], "standard")  # default
        self.assertEqual(result["n"], 1)  # default

        # Test image edit
        test_image = "test_image.jpg"
        test_mask = "test_mask.jpg"
        try:
            # Create test files
            with open(test_image, "wb") as f:
                f.write(b"fake image content")
            with open(test_mask, "wb") as f:
                f.write(b"fake mask content")

            result = self.client.convert_inputs_to_api_kwargs(
                input="a white siamese cat",
                model_kwargs={
                    "model": "dall-e-2",
                    "image": test_image,
                    "mask": test_mask,
                },
                model_type=ModelType.IMAGE_GENERATION,
            )
            self.assertEqual(result["prompt"], "a white siamese cat")
            self.assertEqual(result["model"], "dall-e-2")
            self.assertTrue(isinstance(result["image"], str))  # base64 encoded
            self.assertTrue(isinstance(result["mask"], str))  # base64 encoded
        finally:
            # Cleanup
            if os.path.exists(test_image):
                os.remove(test_image)
            if os.path.exists(test_mask):
                os.remove(test_mask)

@patch("adalflow.components.model_client.openai_client.AsyncOpenAI")
async def test_acall_image_generation(self, MockAsyncOpenAI):
mock_async_client = AsyncMock()
MockAsyncOpenAI.return_value = mock_async_client

# Mock the image generation response
mock_async_client.images.generate = AsyncMock(
return_value=type('Response', (), {'data': self.mock_image_response})()
)

# Call the acall method with image generation
result = await self.client.acall(
api_kwargs=self.image_generation_kwargs,
model_type=ModelType.IMAGE_GENERATION,
)

# Assertions
MockAsyncOpenAI.assert_called_once()
mock_async_client.images.generate.assert_awaited_once_with(
**self.image_generation_kwargs
)
self.assertEqual(result, self.mock_image_response)

# Test parse_image_generation_response
output = self.client.parse_image_generation_response(result)
self.assertTrue(isinstance(output, GeneratorOutput))
self.assertEqual(output.data, "https://example.com/generated_image.jpg")

    @patch(
        "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client"
    )
    @patch("adalflow.components.model_client.openai_client.OpenAI")
    def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
        mock_sync_client = Mock()
        MockSyncOpenAI.return_value = mock_sync_client
        mock_init_sync_client.return_value = mock_sync_client

        # Mock the image generation response
        mock_sync_client.images.generate = Mock(
            return_value=type("Response", (), {"data": self.mock_image_response})()
        )

        # Set the sync client
        self.client.sync_client = mock_sync_client

        # Call the call method with image generation
        result = self.client.call(
            api_kwargs=self.image_generation_kwargs,
            model_type=ModelType.IMAGE_GENERATION,
        )

        # Assertions
        mock_sync_client.images.generate.assert_called_once_with(
            **self.image_generation_kwargs
        )
        self.assertEqual(result, self.mock_image_response)

        # Test parse_image_generation_response
        output = self.client.parse_image_generation_response(result)
        self.assertTrue(isinstance(output, GeneratorOutput))
        self.assertEqual(output.data, "https://example.com/generated_image.jpg")


if __name__ == "__main__":
unittest.main()
