Add support for multimodal OpenAI - early version #313

Open · wants to merge 8 commits into base: main
187 changes: 177 additions & 10 deletions adalflow/adalflow/components/model_client/openai_client.py
@@ -1,6 +1,7 @@
"""OpenAI ModelClient integration."""

import os
import base64
from typing import (
Dict,
Sequence,
@@ -35,6 +36,7 @@
from openai.types import (
Completion,
CreateEmbeddingResponse,
Image,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletion

@@ -99,7 +101,7 @@ def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
class OpenAIClient(ModelClient):
__doc__ = r"""A component wrapper for the OpenAI API client.

Support both embedding and chat completion API.
Support both embedding and chat completion APIs, including multimodal capabilities.

Users can (1) simply use the ``Embedder`` and ``Generator`` components by passing ``OpenAIClient()`` as the model_client,
or (2) use this as an example to create their own API client or extend this class (copying and modifying the code) in their own project.
@@ -110,6 +112,17 @@ class OpenAIClient(ModelClient):
Instead
- use :ref:`OutputParser<components-output_parsers>` for response parsing and formatting.

For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them.
The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini).

For image generation, use model_type=ModelType.IMAGE_GENERATION and provide:
- model: "dall-e-3" or "dall-e-2"
- prompt: Text description of the image to generate
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
- quality: "standard" or "hd" (DALL-E 3 only)
- n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
- response_format: "url" or "b64_json"

Args:
api_key (Optional[str], optional): OpenAI API key. Defaults to None.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
@@ -118,6 +131,8 @@ class OpenAIClient(ModelClient):
References:
- Embeddings models: https://platform.openai.com/docs/guides/embeddings
- Chat models: https://platform.openai.com/docs/guides/text-generation
- Vision models: https://platform.openai.com/docs/guides/vision
- Image models: https://platform.openai.com/docs/guides/images
- OpenAI docs: https://platform.openai.com/docs/introduction
"""

@@ -200,7 +215,7 @@ def track_completion_usage(
def parse_embedding_response(
self, response: CreateEmbeddingResponse
) -> EmbedderOutput:
r"""Parse the embedding response to a structure LightRAG components can understand.
r"""Parse the embedding response to a structure Adalflow components can understand.

Should be called in ``Embedder``.
"""
@@ -218,7 +233,20 @@ def convert_inputs_to_api_kwargs(
) -> Dict:
r"""
Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format
Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format.
For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them.
The model specified in model_kwargs["model"] must support multimodal capabilities when using images.

Args:
input: The input text or messages to process
model_kwargs: Additional parameters including:
- images: Optional image source(s) as path, URL, or list of them
- detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
- model: The model to use (must support multimodal inputs if images are provided)
model_type: The type of model (EMBEDDER or LLM)

Returns:
Dict: API-specific kwargs for the model call
"""

final_model_kwargs = model_kwargs.copy()
@@ -232,6 +260,8 @@ def convert_inputs_to_api_kwargs(
elif model_type == ModelType.LLM:
# convert input to messages
messages: List[Dict[str, str]] = []
images = final_model_kwargs.pop("images", None)
detail = final_model_kwargs.pop("detail", "auto")

if self._input_type == "messages":
system_start_tag = "<START_OF_SYSTEM_PROMPT>"
@@ -248,19 +278,78 @@ def convert_inputs_to_api_kwargs(
if match:
system_prompt = match.group(1)
input_str = match.group(2)

else:
print("No match found.")
if system_prompt and input_str:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": input_str})
if images:
content = [{"type": "text", "text": input_str}]
if isinstance(images, (str, dict)):
images = [images]
for img in images:
content.append(self._prepare_image_content(img, detail))
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "user", "content": input_str})
if len(messages) == 0:
messages.append({"role": "system", "content": input})
if images:
content = [{"type": "text", "text": input}]
if isinstance(images, (str, dict)):
images = [images]
for img in images:
content.append(self._prepare_image_content(img, detail))
messages.append({"role": "user", "content": content})
else:
messages.append({"role": "system", "content": input})
final_model_kwargs["messages"] = messages
elif model_type == ModelType.IMAGE_GENERATION:
# For image generation, input is the prompt
final_model_kwargs["prompt"] = input
# Set defaults for DALL-E 3 if not specified
if "model" not in final_model_kwargs:
final_model_kwargs["model"] = "dall-e-3"
if "size" not in final_model_kwargs:
final_model_kwargs["size"] = "1024x1024"
if "quality" not in final_model_kwargs:
final_model_kwargs["quality"] = "standard"
if "n" not in final_model_kwargs:
final_model_kwargs["n"] = 1
if "response_format" not in final_model_kwargs:
final_model_kwargs["response_format"] = "url"

# Handle image edits and variations
if "image" in final_model_kwargs:
if isinstance(final_model_kwargs["image"], str):
# If it's a file path, encode it
if os.path.isfile(final_model_kwargs["image"]):
final_model_kwargs["image"] = self._encode_image(final_model_kwargs["image"])
if "mask" in final_model_kwargs and isinstance(final_model_kwargs["mask"], str):
if os.path.isfile(final_model_kwargs["mask"]):
final_model_kwargs["mask"] = self._encode_image(final_model_kwargs["mask"])
else:
raise ValueError(f"model_type {model_type} is not supported")
return final_model_kwargs
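To make the conversion concrete, here is a hedged sketch of roughly what this method returns for an LLM call with one remote image. The prompt and URL are placeholders, and it assumes the prompt carries no system-prompt tags, so the whole text lands in a single user message.

# Hypothetical walk-through (placeholders only).
from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="What is in this image?",
    model_kwargs={"model": "gpt-4o", "images": "https://example.com/cat.png"},
    model_type=ModelType.LLM,
)
# Expected shape, based on the branch above:
# {
#     "model": "gpt-4o",
#     "messages": [
#         {"role": "user",
#          "content": [
#              {"type": "text", "text": "What is in this image?"},
#              {"type": "image_url",
#               "image_url": {"url": "https://example.com/cat.png", "detail": "auto"}},
#          ]},
#     ],
# }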

def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
"""Parse the image generation response into a GeneratorOutput."""
try:
# Extract URLs or base64 data from the response
data = [img.url or img.b64_json for img in response]
# For single image responses, unwrap from list
if len(data) == 1:
data = data[0]
return GeneratorOutput(
data=data,
raw_response=str(response),
)
except Exception as e:
log.error(f"Error parsing image generation response: {e}")
return GeneratorOutput(
data=None,
error=str(e),
raw_response=str(response)
)

@backoff.on_exception(
backoff.expo,
(
@@ -285,6 +374,19 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
self.chat_completion_parser = handle_streaming_response
return self.sync_client.chat.completions.create(**api_kwargs)
return self.sync_client.chat.completions.create(**api_kwargs)
elif model_type == ModelType.IMAGE_GENERATION:
# Determine which image API to call based on the presence of image/mask
if "image" in api_kwargs:
if "mask" in api_kwargs:
# Image edit
response = self.sync_client.images.edit(**api_kwargs)
else:
# Image variation
response = self.sync_client.images.create_variation(**api_kwargs)
else:
# Image generation
response = self.sync_client.images.generate(**api_kwargs)
return response.data
else:
raise ValueError(f"model_type {model_type} is not supported")

@@ -311,6 +413,19 @@ async def acall(
return await self.async_client.embeddings.create(**api_kwargs)
elif model_type == ModelType.LLM:
return await self.async_client.chat.completions.create(**api_kwargs)
elif model_type == ModelType.IMAGE_GENERATION:
# Determine which image API to call based on the presence of image/mask
if "image" in api_kwargs:
if "mask" in api_kwargs:
# Image edit
response = await self.async_client.images.edit(**api_kwargs)
else:
# Image variation
response = await self.async_client.images.create_variation(**api_kwargs)
else:
# Image generation
response = await self.async_client.images.generate(**api_kwargs)
return response.data
else:
raise ValueError(f"model_type {model_type} is not supported")
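The async path mirrors the sync one. A hedged sketch follows, with a placeholder prompt; it assumes acall accepts the same api_kwargs/model_type arguments as call above.

# Hypothetical async sketch (placeholder prompt; not part of this diff).
import asyncio
from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

async def generate_images_async():
    client = OpenAIClient()
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="A pixel-art spaceship",
        model_kwargs={"model": "dall-e-2", "n": 2, "size": "512x512"},
        model_type=ModelType.IMAGE_GENERATION,
    )
    images = await client.acall(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)
    return client.parse_image_generation_response(images)

# output = asyncio.run(generate_images_async())
# output.data would be a list of two URLs, since n > 1 skips the single-image unwrap.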

@@ -332,22 +447,74 @@ def to_dict(self) -> Dict[str, Any]:
output = super().to_dict(exclude=exclude)
return output

def _encode_image(self, image_path: str) -> str:
"""Encode image to base64 string.

Args:
image_path: Path to image file.

Returns:
Base64 encoded image string.

Raises:
ValueError: If the file cannot be read or doesn't exist.
"""
try:
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
except FileNotFoundError:
raise ValueError(f"Image file not found: {image_path}")
except PermissionError:
raise ValueError(f"Permission denied when reading image file: {image_path}")
except Exception as e:
raise ValueError(f"Error encoding image {image_path}: {str(e)}")

def _prepare_image_content(
self, image_source: Union[str, Dict[str, Any]], detail: str = "auto"
) -> Dict[str, Any]:
"""Prepare image content for API request.

Args:
image_source: Either a path to local image or a URL.
detail: Image detail level ('auto', 'low', or 'high').

Returns:
Formatted image content for API request.
"""
if isinstance(image_source, str):
if image_source.startswith(("http://", "https://")):
return {
"type": "image_url",
"image_url": {"url": image_source, "detail": detail},
}
else:
base64_image = self._encode_image(image_source)
return {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": detail,
},
}
return image_source
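A brief hedged illustration of the two branches above, using placeholder sources; note that local files are always labeled image/jpeg in the data URL, whatever their actual format.

# Hypothetical inputs/outputs (placeholder path and URL):
# _prepare_image_content("https://example.com/cat.png", detail="low")
#   -> {"type": "image_url",
#       "image_url": {"url": "https://example.com/cat.png", "detail": "low"}}
# _prepare_image_content("photos/cat.jpg")
#   -> {"type": "image_url",
#       "image_url": {"url": "data:image/jpeg;base64,<encoded bytes>", "detail": "auto"}}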


# Example usage:
# if __name__ == "__main__":
# from adalflow.core import Generator
# from adalflow.utils import setup_env, get_logger

#
# log = get_logger(level="DEBUG")

#
# setup_env()
# prompt_kwargs = {"input_str": "What is the meaning of life?"}

#
# gen = Generator(
# model_client=OpenAIClient(),
# model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
# )
# gen_response = gen(prompt_kwargs)
# print(f"gen_response: {gen_response}")

#
# for genout in gen_response.data:
# print(f"genout: {genout}")
1 change: 1 addition & 0 deletions adalflow/adalflow/core/types.py
@@ -58,6 +58,7 @@ class ModelType(Enum):
EMBEDDER = auto()
LLM = auto()
RERANKER = auto() # ranking model
IMAGE_GENERATION = auto() # image generation models like DALL-E
UNDEFINED = auto()

