From 73089ff582fb608396659096b9c2ed1725ed0263 Mon Sep 17 00:00:00 2001 From: fm1320 Date: Mon, 6 Jan 2025 20:29:37 +0000 Subject: [PATCH 1/8] add multi modal support for openai draft --- adalflow/adalflow/__init__.py | 2 + .../components/model_client/__init__.py | 5 + .../model_client/openai_multimodal_client.py | 112 ++++++++++++++ adalflow/adalflow/utils/lazy_import.py | 10 ++ docs/source/tutorials/multimodal.rst | 143 ++++++++++++++++++ notebooks/tutorials/adalflow_multimodal.ipynb | 0 6 files changed, 272 insertions(+) create mode 100644 adalflow/adalflow/components/model_client/openai_multimodal_client.py create mode 100644 docs/source/tutorials/multimodal.rst create mode 100644 notebooks/tutorials/adalflow_multimodal.ipynb diff --git a/adalflow/adalflow/__init__.py b/adalflow/adalflow/__init__.py index 4c9b45ba..2ee67015 100644 --- a/adalflow/adalflow/__init__.py +++ b/adalflow/adalflow/__init__.py @@ -61,6 +61,7 @@ AnthropicAPIClient, CohereAPIClient, BedrockAPIClient, + OpenAIMultimodalClient, ) # data pipeline @@ -129,4 +130,5 @@ "AnthropicAPIClient", "CohereAPIClient", "BedrockAPIClient", + "OpenAIMultimodalClient", ] diff --git a/adalflow/adalflow/components/model_client/__init__.py b/adalflow/adalflow/components/model_client/__init__.py index ae508ece..258285ec 100644 --- a/adalflow/adalflow/components/model_client/__init__.py +++ b/adalflow/adalflow/components/model_client/__init__.py @@ -64,6 +64,10 @@ "adalflow.components.model_client.openai_client.get_probabilities", OptionalPackages.OPENAI, ) +OpenAIMultimodalClient = LazyImport( + "adalflow.components.model_client.openai_multimodal_client.OpenAIMultimodalClient", + OptionalPackages.OPENAI, +) __all__ = [ "CohereAPIClient", @@ -76,6 +80,7 @@ "GroqAPIClient", "OpenAIClient", "GoogleGenAIClient", + "OpenAIMultimodalClient", ] for name in __all__: diff --git a/adalflow/adalflow/components/model_client/openai_multimodal_client.py b/adalflow/adalflow/components/model_client/openai_multimodal_client.py new file mode 100644 index 00000000..9cb8053c --- /dev/null +++ b/adalflow/adalflow/components/model_client/openai_multimodal_client.py @@ -0,0 +1,112 @@ +"""OpenAI multimodal client for handling image and text inputs.""" + +import base64 +from typing import Any, Dict, List, Optional, Union +from adalflow.utils.lazy_import import safe_import, OptionalPackages + +openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) +from openai import OpenAI + +from adalflow.core.model_client import ModelClient +from adalflow.core.types import GeneratorOutput + + +class OpenAIMultimodalClient(ModelClient): + """OpenAI client for multimodal models.""" + + def __init__(self, api_key: Optional[str] = None): + """Initialize the OpenAI multimodal client. + + Args: + api_key: OpenAI API key. If None, will try to get from environment variable. + """ + super().__init__() + self.client = OpenAI(api_key=api_key) + + def _encode_image(self, image_path: str) -> str: + """Encode image to base64 string. + + Args: + image_path: Path to image file. + + Returns: + Base64 encoded image string. + """ + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def _prepare_image_content( + self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" + ) -> Dict[str, Any]: + """Prepare image content for API request. + + Args: + image_source: Either a path to local image or a URL. + detail: Image detail level ('auto', 'low', or 'high'). 
+ + Returns: + Formatted image content for API request. + """ + if isinstance(image_source, str): + if image_source.startswith(("http://", "https://")): + return { + "type": "image_url", + "image_url": {"url": image_source, "detail": detail}, + } + else: + base64_image = self._encode_image(image_source) + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": detail, + }, + } + return image_source + + def generate( + self, + prompt: str, + images: Optional[ + Union[str, List[str], Dict[str, Any], List[Dict[str, Any]]] + ] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> GeneratorOutput: + """Generate text response for given prompt and images. + + Args: + prompt: Text prompt. + images: Image source(s) - can be path(s), URL(s), or formatted dict(s). + model_kwargs: Additional model parameters. + + Returns: + GeneratorOutput containing the model's response. + """ + model_kwargs = model_kwargs or {} + model = model_kwargs.get("model", "gpt-4o-mini") + max_tokens = model_kwargs.get("max_tokens", 300) + detail = model_kwargs.get("detail", "auto") + + # Prepare message content + content = [{"type": "text", "text": prompt}] + + if images: + if not isinstance(images, list): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + + try: + response = self.client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": content}], + max_tokens=max_tokens, + ) + return GeneratorOutput( + id=response.id, + data=response.choices[0].message.content, + usage=response.usage.model_dump() if response.usage else None, + raw_response=response.model_dump(), + ) + except Exception as e: + return GeneratorOutput(error=str(e)) diff --git a/adalflow/adalflow/utils/lazy_import.py b/adalflow/adalflow/utils/lazy_import.py index 16ad8d1f..ceded5eb 100644 --- a/adalflow/adalflow/utils/lazy_import.py +++ b/adalflow/adalflow/utils/lazy_import.py @@ -215,3 +215,13 @@ def safe_import( raise ImportError(f"{install_message}") return return_modules[0] if len(return_modules) == 1 else return_modules + + +OPTIONAL_PACKAGES = { + "openai": "openai", # For OpenAI API clients + "transformers": "transformers", # For local models + "torch": "torch", # For PyTorch models + "anthropic": "anthropic", # For Claude models + "groq": "groq", # For Groq models + "cohere": "cohere", # For Cohere models +} diff --git a/docs/source/tutorials/multimodal.rst b/docs/source/tutorials/multimodal.rst new file mode 100644 index 00000000..2cbb08ab --- /dev/null +++ b/docs/source/tutorials/multimodal.rst @@ -0,0 +1,143 @@ +.. _tutorials-multimodal: + +Multimodal Generation +=================== + +.. raw:: html + +
+    <div style="display: flex; justify-content: flex-start; align-items: center; margin-bottom: 20px;">
+        <a href="https://colab.research.google.com/github/SylphAI-Inc/AdalFlow/blob/main/notebooks/tutorials/adalflow_multimodal.ipynb" target="_blank" style="margin-right: 10px;">
+            <img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align: middle;">
+        </a>
+        <a href="https://github.com/SylphAI-Inc/AdalFlow/blob/main/notebooks/tutorials/adalflow_multimodal.ipynb" target="_blank" style="display: flex; align-items: center;">
+            <img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 20px; width: 20px; margin-right: 5px;">
+            View Source
+        </a>
+    </div>
+ +What you will learn? +------------------ + +1. How to use the OpenAI multimodal client for image understanding +2. Different ways to input images (local files, URLs) +3. Controlling image detail levels +4. Working with multiple images + +The OpenAIMultimodalClient +------------------------ + +The :class:`OpenAIMultimodalClient` extends AdalFlow's model client capabilities to handle images along with text. It supports: + +- Local image files (automatically encoded to base64) +- Image URLs +- Multiple images in a single request +- Control over image detail level + +Basic Usage +---------- + +First, install AdalFlow with OpenAI support: + +.. code-block:: bash + + pip install "adalflow[openai]" + +Then you can use the client with the Generator: + +.. code-block:: python + + from adalflow import Generator, OpenAIMultimodalClient + + generator = Generator( + model_client=OpenAIMultimodalClient(), + model_kwargs={ + "model": "gpt-4o-mini", + "max_tokens": 300 + } + ) + + # Using an image URL + response = generator( + prompt="Describe this image.", + images="https://example.com/image.jpg" + ) + +Image Detail Levels +----------------- + +The client supports three detail levels: + +- ``auto``: Let the model decide based on image size (default) +- ``low``: Low-resolution mode (512px x 512px) +- ``high``: High-resolution mode with detailed crops + +.. code-block:: python + + generator = Generator( + model_client=OpenAIMultimodalClient(), + model_kwargs={ + "model": "gpt-4o-mini", + "detail": "high" # or "low" or "auto" + } + ) + +Multiple Images +------------- + +You can analyze multiple images in one request: + +.. code-block:: python + + images = [ + "path/to/local/image.jpg", + "https://example.com/image.jpg" + ] + + response = generator( + prompt="Compare these images.", + images=images + ) + +Implementation Details +------------------- + +The client handles: + +1. Image Processing: + - Automatic base64 encoding for local files + - URL validation and formatting + - Detail level configuration + +2. API Integration: + - Proper message formatting for OpenAI's vision models + - Error handling and response parsing + - Usage tracking + +3. Output Format: + - Returns standard :class:`GeneratorOutput` format + - Includes model usage information + - Preserves error messages if any occur + +Limitations +--------- + +Be aware of these limitations when using the multimodal client: + +1. Image Size: + - Maximum file size: 20MB per image + - Supported formats: PNG, JPEG, WEBP, non-animated GIF + +2. Model Capabilities: + - Best for general visual understanding + - May struggle with: + - Small text + - Precise spatial relationships + - Complex graphs + - Non-Latin text + +3. Cost Considerations: + - Image inputs are metered in tokens + - High detail mode uses more tokens + - Consider using low detail mode for cost efficiency + +For more details, see the :class:`OpenAIMultimodalClient` API reference. 
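+Under the Hood: Message Format
+------------------------------
+
+Image inputs are folded into a single OpenAI chat message: the prompt becomes a ``text`` part and each image becomes an ``image_url`` part, with local files inlined as base64 data URLs. Below is a rough sketch of the payload the client builds (note that ``image/jpeg`` is the media type the client currently assumes for local files):
+
+.. code-block:: python
+
+    import base64
+
+    # Local files are read and base64-encoded before being sent
+    with open("path/to/local/image.jpg", "rb") as f:
+        b64 = base64.b64encode(f.read()).decode("utf-8")
+
+    message = {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe this image."},
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{b64}",
+                    "detail": "auto",  # or "low" / "high"
+                },
+            },
+        ],
+    }
+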
diff --git a/notebooks/tutorials/adalflow_multimodal.ipynb b/notebooks/tutorials/adalflow_multimodal.ipynb new file mode 100644 index 00000000..e69de29b From c6c4663aec7374890eaa04eda7ac9a690195fde0 Mon Sep 17 00:00:00 2001 From: fm1320 Date: Mon, 6 Jan 2025 22:46:11 +0000 Subject: [PATCH 2/8] Change multimodal to one client --- .../components/model_client/__init__.py | 5 - .../components/model_client/openai_client.py | 105 ++++++++++++++++ .../model_client/openai_multimodal_client.py | 112 ------------------ docs/source/tutorials/multimodal.rst | 70 +++++++---- notebooks/tutorials/adalflow_multimodal.ipynb | 0 5 files changed, 154 insertions(+), 138 deletions(-) delete mode 100644 adalflow/adalflow/components/model_client/openai_multimodal_client.py delete mode 100644 notebooks/tutorials/adalflow_multimodal.ipynb diff --git a/adalflow/adalflow/components/model_client/__init__.py b/adalflow/adalflow/components/model_client/__init__.py index 258285ec..ae508ece 100644 --- a/adalflow/adalflow/components/model_client/__init__.py +++ b/adalflow/adalflow/components/model_client/__init__.py @@ -64,10 +64,6 @@ "adalflow.components.model_client.openai_client.get_probabilities", OptionalPackages.OPENAI, ) -OpenAIMultimodalClient = LazyImport( - "adalflow.components.model_client.openai_multimodal_client.OpenAIMultimodalClient", - OptionalPackages.OPENAI, -) __all__ = [ "CohereAPIClient", @@ -80,7 +76,6 @@ "GroqAPIClient", "OpenAIClient", "GoogleGenAIClient", - "OpenAIMultimodalClient", ] for name in __all__: diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py index 809fd3e0..7b08f887 100644 --- a/adalflow/adalflow/components/model_client/openai_client.py +++ b/adalflow/adalflow/components/model_client/openai_client.py @@ -1,6 +1,7 @@ """OpenAI ModelClient integration.""" import os +import base64 from typing import ( Dict, Sequence, @@ -51,6 +52,14 @@ log = logging.getLogger(__name__) T = TypeVar("T") +# Models that support multimodal inputs +MULTIMODAL_MODELS = { + "gpt-4o", # Versatile, high-intelligence flagship model + "gpt-4o-mini", # Fast, affordable small model for focused tasks + "o1", # Reasoning model that excels at complex, multi-step tasks + "o1-mini", # Smaller reasoning model for complex tasks +} + # completion parsing functions and you can combine them into one singple chat completion parser def get_first_message_content(completion: ChatCompletion) -> str: @@ -332,6 +341,102 @@ def to_dict(self) -> Dict[str, Any]: output = super().to_dict(exclude=exclude) return output + def _encode_image(self, image_path: str) -> str: + """Encode image to base64 string. + + Args: + image_path: Path to image file. + + Returns: + Base64 encoded image string. + """ + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def _prepare_image_content( + self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" + ) -> Dict[str, Any]: + """Prepare image content for API request. + + Args: + image_source: Either a path to local image or a URL. + detail: Image detail level ('auto', 'low', or 'high'). + + Returns: + Formatted image content for API request. 
+ """ + if isinstance(image_source, str): + if image_source.startswith(("http://", "https://")): + return { + "type": "image_url", + "image_url": {"url": image_source, "detail": detail}, + } + else: + base64_image = self._encode_image(image_source) + return { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": detail, + }, + } + return image_source + + def generate( + self, + prompt: str, + images: Optional[ + Union[str, List[str], Dict[str, Any], List[Dict[str, Any]]] + ] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> GeneratorOutput: + """Generate text response for given prompt and optionally images. + + Args: + prompt: Text prompt. + images: Optional image source(s) - can be path(s), URL(s), or formatted dict(s). + model_kwargs: Additional model parameters. + + Returns: + GeneratorOutput containing the model's response. + """ + model_kwargs = model_kwargs or {} + model = model_kwargs.get("model", "gpt-4o-mini") + max_tokens = model_kwargs.get("max_tokens", 300) + detail = model_kwargs.get("detail", "auto") + + # Check if model supports multimodal inputs when images are provided + if images and model not in MULTIMODAL_MODELS: + return GeneratorOutput( + error=f"Model {model} does not support multimodal inputs. Supported models: {MULTIMODAL_MODELS}" + ) + + # Prepare message content + if images: + content = [{"type": "text", "text": prompt}] + if not isinstance(images, list): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + messages = [{"role": "user", "content": content}] + else: + messages = [{"role": "user", "content": prompt}] + + try: + response = self.client.chat.completions.create( + model=model, + messages=messages, + max_tokens=max_tokens, + ) + return GeneratorOutput( + id=response.id, + data=response.choices[0].message.content, + usage=response.usage.model_dump() if response.usage else None, + raw_response=response.model_dump(), + ) + except Exception as e: + return GeneratorOutput(error=str(e)) + # if __name__ == "__main__": # from adalflow.core import Generator diff --git a/adalflow/adalflow/components/model_client/openai_multimodal_client.py b/adalflow/adalflow/components/model_client/openai_multimodal_client.py deleted file mode 100644 index 9cb8053c..00000000 --- a/adalflow/adalflow/components/model_client/openai_multimodal_client.py +++ /dev/null @@ -1,112 +0,0 @@ -"""OpenAI multimodal client for handling image and text inputs.""" - -import base64 -from typing import Any, Dict, List, Optional, Union -from adalflow.utils.lazy_import import safe_import, OptionalPackages - -openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1]) -from openai import OpenAI - -from adalflow.core.model_client import ModelClient -from adalflow.core.types import GeneratorOutput - - -class OpenAIMultimodalClient(ModelClient): - """OpenAI client for multimodal models.""" - - def __init__(self, api_key: Optional[str] = None): - """Initialize the OpenAI multimodal client. - - Args: - api_key: OpenAI API key. If None, will try to get from environment variable. - """ - super().__init__() - self.client = OpenAI(api_key=api_key) - - def _encode_image(self, image_path: str) -> str: - """Encode image to base64 string. - - Args: - image_path: Path to image file. - - Returns: - Base64 encoded image string. 
- """ - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") - - def _prepare_image_content( - self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" - ) -> Dict[str, Any]: - """Prepare image content for API request. - - Args: - image_source: Either a path to local image or a URL. - detail: Image detail level ('auto', 'low', or 'high'). - - Returns: - Formatted image content for API request. - """ - if isinstance(image_source, str): - if image_source.startswith(("http://", "https://")): - return { - "type": "image_url", - "image_url": {"url": image_source, "detail": detail}, - } - else: - base64_image = self._encode_image(image_source) - return { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": detail, - }, - } - return image_source - - def generate( - self, - prompt: str, - images: Optional[ - Union[str, List[str], Dict[str, Any], List[Dict[str, Any]]] - ] = None, - model_kwargs: Optional[Dict[str, Any]] = None, - ) -> GeneratorOutput: - """Generate text response for given prompt and images. - - Args: - prompt: Text prompt. - images: Image source(s) - can be path(s), URL(s), or formatted dict(s). - model_kwargs: Additional model parameters. - - Returns: - GeneratorOutput containing the model's response. - """ - model_kwargs = model_kwargs or {} - model = model_kwargs.get("model", "gpt-4o-mini") - max_tokens = model_kwargs.get("max_tokens", 300) - detail = model_kwargs.get("detail", "auto") - - # Prepare message content - content = [{"type": "text", "text": prompt}] - - if images: - if not isinstance(images, list): - images = [images] - for img in images: - content.append(self._prepare_image_content(img, detail)) - - try: - response = self.client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": content}], - max_tokens=max_tokens, - ) - return GeneratorOutput( - id=response.id, - data=response.choices[0].message.content, - usage=response.usage.model_dump() if response.usage else None, - raw_response=response.model_dump(), - ) - except Exception as e: - return GeneratorOutput(error=str(e)) diff --git a/docs/source/tutorials/multimodal.rst b/docs/source/tutorials/multimodal.rst index 2cbb08ab..6c32c60d 100644 --- a/docs/source/tutorials/multimodal.rst +++ b/docs/source/tutorials/multimodal.rst @@ -18,15 +18,22 @@ Multimodal Generation What you will learn? ------------------ -1. How to use the OpenAI multimodal client for image understanding +1. How to use OpenAI's multimodal capabilities in AdalFlow 2. Different ways to input images (local files, URLs) 3. Controlling image detail levels 4. Working with multiple images -The OpenAIMultimodalClient ------------------------- +Multimodal Support in OpenAIClient +-------------------------------- -The :class:`OpenAIMultimodalClient` extends AdalFlow's model client capabilities to handle images along with text. It supports: +The :class:`OpenAIClient` supports both text and image inputs. 
For multimodal generation, you can use the following models: + +- ``gpt-4o``: Versatile, high-intelligence flagship model +- ``gpt-4o-mini``: Fast, affordable small model for focused tasks (default) +- ``o1``: Reasoning model that excels at complex, multi-step tasks +- ``o1-mini``: Smaller reasoning model for complex tasks + +The client supports: - Local image files (automatically encoded to base64) - Image URLs @@ -42,16 +49,17 @@ First, install AdalFlow with OpenAI support: pip install "adalflow[openai]" -Then you can use the client with the Generator: +Then you can use the client with the Generator. By default, it uses ``gpt-4o-mini``, but you can specify any supported model: .. code-block:: python - from adalflow import Generator, OpenAIMultimodalClient + from adalflow import Generator, OpenAIClient + # Using the default gpt-4o-mini model generator = Generator( - model_client=OpenAIMultimodalClient(), + model_client=OpenAIClient(), model_kwargs={ - "model": "gpt-4o-mini", + "model": "gpt-4o-mini", # or "gpt-4o", "o1", "o1-mini" "max_tokens": 300 } ) @@ -62,6 +70,15 @@ Then you can use the client with the Generator: images="https://example.com/image.jpg" ) + # Using the flagship model for more complex tasks + generator_flagship = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "max_tokens": 300 + } + ) + Image Detail Levels ----------------- @@ -74,7 +91,7 @@ The client supports three detail levels: .. code-block:: python generator = Generator( - model_client=OpenAIMultimodalClient(), + model_client=OpenAIClient(), model_kwargs={ "model": "gpt-4o-mini", "detail": "high" # or "low" or "auto" @@ -111,6 +128,7 @@ The client handles: 2. API Integration: - Proper message formatting for OpenAI's vision models - Error handling and response parsing + - Model compatibility checking - Usage tracking 3. Output Format: @@ -121,23 +139,33 @@ The client handles: Limitations --------- -Be aware of these limitations when using the multimodal client: +Be aware of these limitations when using multimodal features: + +1. Model Support and Capabilities: + - Four models available with different strengths: + - ``gpt-4o``: Best for complex visual analysis and detailed understanding + - ``gpt-4o-mini``: Good balance of speed and accuracy for common tasks + - ``o1``: Excels at multi-step reasoning with visual inputs + - ``o1-mini``: Efficient for focused visual reasoning tasks + - The client will return an error if using an unsupported model with images -1. Image Size: +2. Image Size and Format: - Maximum file size: 20MB per image - Supported formats: PNG, JPEG, WEBP, non-animated GIF -2. Model Capabilities: - - Best for general visual understanding +3. Common Limitations: - May struggle with: - - Small text - - Precise spatial relationships - - Complex graphs - - Non-Latin text + - Very small or blurry text + - Complex spatial relationships + - Detailed technical diagrams + - Non-Latin text or symbols -3. Cost Considerations: - - Image inputs are metered in tokens +4. Cost and Performance Considerations: + - Image inputs increase token usage - High detail mode uses more tokens - - Consider using low detail mode for cost efficiency + - Consider using: + - ``gpt-4o-mini`` for routine tasks + - ``o1-mini`` for basic reasoning tasks + - ``gpt-4o`` or ``o1`` for complex analysis -For more details, see the :class:`OpenAIMultimodalClient` API reference. +For more details, see the :class:`OpenAIClient` API reference. 
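+Handling Unsupported Models
+---------------------------
+
+As noted in the limitations above, the client returns an error rather than raising when images are passed to a model outside the supported set. The message surfaces in the ``error`` field of the returned :class:`GeneratorOutput`. A minimal sketch, assuming the same calling convention as the examples above:
+
+.. code-block:: python
+
+    generator = Generator(
+        model_client=OpenAIClient(),
+        model_kwargs={"model": "gpt-3.5-turbo"}  # text-only model
+    )
+
+    response = generator(
+        prompt="Describe this image.",
+        images="https://example.com/image.jpg"
+    )
+
+    if response.error:
+        # e.g. "Model gpt-3.5-turbo does not support multimodal inputs.
+        # Supported models: {...}"
+        print(response.error)
+    else:
+        print(response.data)
+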
diff --git a/notebooks/tutorials/adalflow_multimodal.ipynb b/notebooks/tutorials/adalflow_multimodal.ipynb deleted file mode 100644 index e69de29b..00000000 From b0a473bb2859da495e379c17401faa606a21a6aa Mon Sep 17 00:00:00 2001 From: fm1320 Date: Mon, 6 Jan 2025 23:09:20 +0000 Subject: [PATCH 3/8] remove separate file refs --- adalflow/adalflow/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/adalflow/adalflow/__init__.py b/adalflow/adalflow/__init__.py index 2ee67015..4c9b45ba 100644 --- a/adalflow/adalflow/__init__.py +++ b/adalflow/adalflow/__init__.py @@ -61,7 +61,6 @@ AnthropicAPIClient, CohereAPIClient, BedrockAPIClient, - OpenAIMultimodalClient, ) # data pipeline @@ -130,5 +129,4 @@ "AnthropicAPIClient", "CohereAPIClient", "BedrockAPIClient", - "OpenAIMultimodalClient", ] From 00ea1d5e8582ded0043687f3cedffad998b0efe8 Mon Sep 17 00:00:00 2001 From: fm1320 Date: Wed, 8 Jan 2025 04:05:11 +0000 Subject: [PATCH 4/8] Single function openaiclient and test --- .../components/model_client/openai_client.py | 132 ++++++-------- adalflow/adalflow/utils/lazy_import.py | 10 - adalflow/tests/test_openai_client.py | 101 +++++++++++ docs/source/tutorials/model_client.rst | 40 ++++ docs/source/tutorials/multimodal.rst | 171 ------------------ 5 files changed, 198 insertions(+), 256 deletions(-) delete mode 100644 docs/source/tutorials/multimodal.rst diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py index 7b08f887..aca8a566 100644 --- a/adalflow/adalflow/components/model_client/openai_client.py +++ b/adalflow/adalflow/components/model_client/openai_client.py @@ -52,14 +52,6 @@ log = logging.getLogger(__name__) T = TypeVar("T") -# Models that support multimodal inputs -MULTIMODAL_MODELS = { - "gpt-4o", # Versatile, high-intelligence flagship model - "gpt-4o-mini", # Fast, affordable small model for focused tasks - "o1", # Reasoning model that excels at complex, multi-step tasks - "o1-mini", # Smaller reasoning model for complex tasks -} - # completion parsing functions and you can combine them into one singple chat completion parser def get_first_message_content(completion: ChatCompletion) -> str: @@ -108,7 +100,7 @@ def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]: class OpenAIClient(ModelClient): __doc__ = r"""A component wrapper for the OpenAI API client. - Support both embedding and chat completion API. + Support both embedding and chat completion API, including multimodal capabilities. Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client. (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project. @@ -119,6 +111,9 @@ class OpenAIClient(ModelClient): Instead - use :ref:`OutputParser` for response parsing and formating. + For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them. + The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini). + Args: api_key (Optional[str], optional): OpenAI API key. Defaults to None. chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None. 
@@ -127,6 +122,7 @@ class OpenAIClient(ModelClient): References: - Embeddings models: https://platform.openai.com/docs/guides/embeddings - Chat models: https://platform.openai.com/docs/guides/text-generation + - Vision models: https://platform.openai.com/docs/guides/vision - OpenAI docs: https://platform.openai.com/docs/introduction """ @@ -209,7 +205,7 @@ def track_completion_usage( def parse_embedding_response( self, response: CreateEmbeddingResponse ) -> EmbedderOutput: - r"""Parse the embedding response to a structure LightRAG components can understand. + r"""Parse the embedding response to a structure Adalflow components can understand. Should be called in ``Embedder``. """ @@ -227,7 +223,20 @@ def convert_inputs_to_api_kwargs( ) -> Dict: r""" Specify the API input type and output api_kwargs that will be used in _call and _acall methods. - Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format + Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format. + For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them. + The model specified in model_kwargs["model"] must support multimodal capabilities when using images. + + Args: + input: The input text or messages to process + model_kwargs: Additional parameters including: + - images: Optional image source(s) as path, URL, or list of them + - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto' + - model: The model to use (must support multimodal inputs if images are provided) + model_type: The type of model (EMBEDDER or LLM) + + Returns: + Dict: API-specific kwargs for the model call """ final_model_kwargs = model_kwargs.copy() @@ -241,6 +250,8 @@ def convert_inputs_to_api_kwargs( elif model_type == ModelType.LLM: # convert input to messages messages: List[Dict[str, str]] = [] + images = final_model_kwargs.pop("images", None) + detail = final_model_kwargs.pop("detail", "auto") if self._input_type == "messages": system_start_tag = "" @@ -257,14 +268,29 @@ def convert_inputs_to_api_kwargs( if match: system_prompt = match.group(1) input_str = match.group(2) - else: print("No match found.") if system_prompt and input_str: messages.append({"role": "system", "content": system_prompt}) - messages.append({"role": "user", "content": input_str}) + if images: + content = [{"type": "text", "text": input_str}] + if isinstance(images, (str, dict)): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + messages.append({"role": "user", "content": content}) + else: + messages.append({"role": "user", "content": input_str}) if len(messages) == 0: - messages.append({"role": "system", "content": input}) + if images: + content = [{"type": "text", "text": input}] + if isinstance(images, (str, dict)): + images = [images] + for img in images: + content.append(self._prepare_image_content(img, detail)) + messages.append({"role": "user", "content": content}) + else: + messages.append({"role": "system", "content": input}) final_model_kwargs["messages"] = messages else: raise ValueError(f"model_type {model_type} is not supported") @@ -349,9 +375,19 @@ def _encode_image(self, image_path: str) -> str: Returns: Base64 encoded image string. + + Raises: + ValueError: If the file cannot be read or doesn't exist. 
""" - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode("utf-8") + try: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + except FileNotFoundError: + raise ValueError(f"Image file not found: {image_path}") + except PermissionError: + raise ValueError(f"Permission denied when reading image file: {image_path}") + except Exception as e: + raise ValueError(f"Error encoding image {image_path}: {str(e)}") def _prepare_image_content( self, image_source: Union[str, Dict[str, Any]], detail: str = "auto" @@ -382,77 +418,23 @@ def _prepare_image_content( } return image_source - def generate( - self, - prompt: str, - images: Optional[ - Union[str, List[str], Dict[str, Any], List[Dict[str, Any]]] - ] = None, - model_kwargs: Optional[Dict[str, Any]] = None, - ) -> GeneratorOutput: - """Generate text response for given prompt and optionally images. - - Args: - prompt: Text prompt. - images: Optional image source(s) - can be path(s), URL(s), or formatted dict(s). - model_kwargs: Additional model parameters. - - Returns: - GeneratorOutput containing the model's response. - """ - model_kwargs = model_kwargs or {} - model = model_kwargs.get("model", "gpt-4o-mini") - max_tokens = model_kwargs.get("max_tokens", 300) - detail = model_kwargs.get("detail", "auto") - - # Check if model supports multimodal inputs when images are provided - if images and model not in MULTIMODAL_MODELS: - return GeneratorOutput( - error=f"Model {model} does not support multimodal inputs. Supported models: {MULTIMODAL_MODELS}" - ) - - # Prepare message content - if images: - content = [{"type": "text", "text": prompt}] - if not isinstance(images, list): - images = [images] - for img in images: - content.append(self._prepare_image_content(img, detail)) - messages = [{"role": "user", "content": content}] - else: - messages = [{"role": "user", "content": prompt}] - - try: - response = self.client.chat.completions.create( - model=model, - messages=messages, - max_tokens=max_tokens, - ) - return GeneratorOutput( - id=response.id, - data=response.choices[0].message.content, - usage=response.usage.model_dump() if response.usage else None, - raw_response=response.model_dump(), - ) - except Exception as e: - return GeneratorOutput(error=str(e)) - +# Example usage: # if __name__ == "__main__": # from adalflow.core import Generator # from adalflow.utils import setup_env, get_logger - +# # log = get_logger(level="DEBUG") - +# # setup_env() # prompt_kwargs = {"input_str": "What is the meaning of life?"} - +# # gen = Generator( # model_client=OpenAIClient(), # model_kwargs={"model": "gpt-3.5-turbo", "stream": True}, # ) # gen_response = gen(prompt_kwargs) # print(f"gen_response: {gen_response}") - +# # for genout in gen_response.data: # print(f"genout: {genout}") diff --git a/adalflow/adalflow/utils/lazy_import.py b/adalflow/adalflow/utils/lazy_import.py index ceded5eb..16ad8d1f 100644 --- a/adalflow/adalflow/utils/lazy_import.py +++ b/adalflow/adalflow/utils/lazy_import.py @@ -215,13 +215,3 @@ def safe_import( raise ImportError(f"{install_message}") return return_modules[0] if len(return_modules) == 1 else return_modules - - -OPTIONAL_PACKAGES = { - "openai": "openai", # For OpenAI API clients - "transformers": "transformers", # For local models - "torch": "torch", # For PyTorch models - "anthropic": "anthropic", # For Claude models - "groq": "groq", # For Groq models - "cohere": "cohere", # For Cohere models -} diff --git 
a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index 2bfe2fd9..82046c38 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -1,5 +1,7 @@ import unittest from unittest.mock import patch, AsyncMock, Mock +import os +import base64 from openai.types import CompletionUsage from openai.types.chat import ChatCompletion @@ -42,6 +44,105 @@ def setUp(self): "model": "gpt-3.5-turbo", } + def test_encode_image(self): + # Create a temporary test image file + test_image_path = "test_image.jpg" + test_content = b"fake image content" + try: + with open(test_image_path, "wb") as f: + f.write(test_content) + + # Test successful encoding + encoded = self.client._encode_image(test_image_path) + self.assertEqual(encoded, base64.b64encode(test_content).decode("utf-8")) + + # Test file not found + with self.assertRaises(ValueError) as context: + self.client._encode_image("nonexistent.jpg") + self.assertIn("Image file not found", str(context.exception)) + + finally: + # Cleanup + if os.path.exists(test_image_path): + os.remove(test_image_path) + + def test_prepare_image_content(self): + # Test URL image + url = "https://example.com/image.jpg" + result = self.client._prepare_image_content(url) + self.assertEqual( + result, + {"type": "image_url", "image_url": {"url": url, "detail": "auto"}}, + ) + + # Test with custom detail level + result = self.client._prepare_image_content(url, detail="high") + self.assertEqual( + result, + {"type": "image_url", "image_url": {"url": url, "detail": "high"}}, + ) + + # Test with pre-formatted content + pre_formatted = { + "type": "image_url", + "image_url": {"url": url, "detail": "low"}, + } + result = self.client._prepare_image_content(pre_formatted) + self.assertEqual(result, pre_formatted) + + def test_convert_inputs_to_api_kwargs_with_images(self): + # Test with single image URL + model_kwargs = { + "model": "gpt-4-vision-preview", + "images": "https://example.com/image.jpg", + } + result = self.client.convert_inputs_to_api_kwargs( + input="Describe this image", + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + expected_content = [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg", "detail": "auto"}, + }, + ] + self.assertEqual(result["messages"][0]["content"], expected_content) + + # Test with multiple images + model_kwargs = { + "model": "gpt-4-vision-preview", + "images": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + ], + "detail": "high", + } + result = self.client.convert_inputs_to_api_kwargs( + input="Compare these images", + model_kwargs=model_kwargs, + model_type=ModelType.LLM, + ) + expected_content = [ + {"type": "text", "text": "Compare these images"}, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image1.jpg", + "detail": "high", + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image2.jpg", + "detail": "high", + }, + }, + ] + self.assertEqual(result["messages"][0]["content"], expected_content) + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") async def test_acall_llm(self, MockAsyncOpenAI): mock_async_client = AsyncMock() diff --git a/docs/source/tutorials/model_client.rst b/docs/source/tutorials/model_client.rst index 438d34d3..47e83298 100644 --- a/docs/source/tutorials/model_client.rst +++ b/docs/source/tutorials/model_client.rst @@ -1513,6 +1513,46 @@ This is the function call 
that triggers the execution of the custom model client build_custom_model_client() + +OPENAI LLM Chat - Multimodal Example +------------------------------------------------- + +The OpenAI client also supports multimodal inputs. Here's a quick example: + +.. code-block:: python + + from adalflow import Generator, OpenAIClient + + generator = Generator( + model_client=OpenAIClient(), + model_kwargs={ + "model": "gpt-4o", + "max_tokens": 300 + } + ) + + # Single image + response = generator( + prompt_kwargs={ + "input_str": "What's in this image?", + "images": "path/to/image.jpg" # Local file or URL + } + ) + + # Multiple images + response = generator( + prompt_kwargs={ + "input_str": "Compare these images.", + "images": [ + "path/to/first.jpg", + "https://example.com/second.jpg" + ] + } + ) + +The client handles both local files and URLs, with support for PNG, JPEG, WEBP, and non-animated GIF formats. + + .. admonition:: API reference :class: highlight diff --git a/docs/source/tutorials/multimodal.rst b/docs/source/tutorials/multimodal.rst deleted file mode 100644 index 6c32c60d..00000000 --- a/docs/source/tutorials/multimodal.rst +++ /dev/null @@ -1,171 +0,0 @@ -.. _tutorials-multimodal: - -Multimodal Generation -=================== - -.. raw:: html - - - -What you will learn? ------------------- - -1. How to use OpenAI's multimodal capabilities in AdalFlow -2. Different ways to input images (local files, URLs) -3. Controlling image detail levels -4. Working with multiple images - -Multimodal Support in OpenAIClient --------------------------------- - -The :class:`OpenAIClient` supports both text and image inputs. For multimodal generation, you can use the following models: - -- ``gpt-4o``: Versatile, high-intelligence flagship model -- ``gpt-4o-mini``: Fast, affordable small model for focused tasks (default) -- ``o1``: Reasoning model that excels at complex, multi-step tasks -- ``o1-mini``: Smaller reasoning model for complex tasks - -The client supports: - -- Local image files (automatically encoded to base64) -- Image URLs -- Multiple images in a single request -- Control over image detail level - -Basic Usage ----------- - -First, install AdalFlow with OpenAI support: - -.. code-block:: bash - - pip install "adalflow[openai]" - -Then you can use the client with the Generator. By default, it uses ``gpt-4o-mini``, but you can specify any supported model: - -.. code-block:: python - - from adalflow import Generator, OpenAIClient - - # Using the default gpt-4o-mini model - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o-mini", # or "gpt-4o", "o1", "o1-mini" - "max_tokens": 300 - } - ) - - # Using an image URL - response = generator( - prompt="Describe this image.", - images="https://example.com/image.jpg" - ) - - # Using the flagship model for more complex tasks - generator_flagship = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o", - "max_tokens": 300 - } - ) - -Image Detail Levels ------------------ - -The client supports three detail levels: - -- ``auto``: Let the model decide based on image size (default) -- ``low``: Low-resolution mode (512px x 512px) -- ``high``: High-resolution mode with detailed crops - -.. code-block:: python - - generator = Generator( - model_client=OpenAIClient(), - model_kwargs={ - "model": "gpt-4o-mini", - "detail": "high" # or "low" or "auto" - } - ) - -Multiple Images -------------- - -You can analyze multiple images in one request: - -.. 
code-block:: python - - images = [ - "path/to/local/image.jpg", - "https://example.com/image.jpg" - ] - - response = generator( - prompt="Compare these images.", - images=images - ) - -Implementation Details -------------------- - -The client handles: - -1. Image Processing: - - Automatic base64 encoding for local files - - URL validation and formatting - - Detail level configuration - -2. API Integration: - - Proper message formatting for OpenAI's vision models - - Error handling and response parsing - - Model compatibility checking - - Usage tracking - -3. Output Format: - - Returns standard :class:`GeneratorOutput` format - - Includes model usage information - - Preserves error messages if any occur - -Limitations ---------- - -Be aware of these limitations when using multimodal features: - -1. Model Support and Capabilities: - - Four models available with different strengths: - - ``gpt-4o``: Best for complex visual analysis and detailed understanding - - ``gpt-4o-mini``: Good balance of speed and accuracy for common tasks - - ``o1``: Excels at multi-step reasoning with visual inputs - - ``o1-mini``: Efficient for focused visual reasoning tasks - - The client will return an error if using an unsupported model with images - -2. Image Size and Format: - - Maximum file size: 20MB per image - - Supported formats: PNG, JPEG, WEBP, non-animated GIF - -3. Common Limitations: - - May struggle with: - - Very small or blurry text - - Complex spatial relationships - - Detailed technical diagrams - - Non-Latin text or symbols - -4. Cost and Performance Considerations: - - Image inputs increase token usage - - High detail mode uses more tokens - - Consider using: - - ``gpt-4o-mini`` for routine tasks - - ``o1-mini`` for basic reasoning tasks - - ``gpt-4o`` or ``o1`` for complex analysis - -For more details, see the :class:`OpenAIClient` API reference. 
From 578a1654d7f52a28e4b1084d5769ebd3a45a3fc7 Mon Sep 17 00:00:00 2001 From: fm1320 Date: Wed, 8 Jan 2025 16:31:06 +0000 Subject: [PATCH 5/8] add more tests with mock --- adalflow/tests/test_openai_client.py | 93 +++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index 82046c38..c0a41f34 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -39,10 +39,45 @@ def setUp(self): ), } self.mock_response = ChatCompletion(**self.mock_response) + self.mock_vision_response = { + "id": "cmpl-4Q8Z5J9Z1Z5z5", + "created": 1635820005, + "object": "chat.completion", + "model": "gpt-4o", + "choices": [ + { + "message": { + "content": "The image shows a beautiful sunset over mountains.", + "role": "assistant", + }, + "index": 0, + "finish_reason": "stop", + } + ], + "usage": CompletionUsage( + completion_tokens=15, prompt_tokens=25, total_tokens=40 + ), + } + self.mock_vision_response = ChatCompletion(**self.mock_vision_response) self.api_kwargs = { "messages": [{"role": "user", "content": "Hello"}], "model": "gpt-3.5-turbo", } + self.vision_api_kwargs = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg", "detail": "auto"}, + }, + ], + } + ], + "model": "gpt-4o", + } def test_encode_image(self): # Create a temporary test image file @@ -93,7 +128,7 @@ def test_prepare_image_content(self): def test_convert_inputs_to_api_kwargs_with_images(self): # Test with single image URL model_kwargs = { - "model": "gpt-4-vision-preview", + "model": "gpt-4o", "images": "https://example.com/image.jpg", } result = self.client.convert_inputs_to_api_kwargs( @@ -112,7 +147,7 @@ def test_convert_inputs_to_api_kwargs_with_images(self): # Test with multiple images model_kwargs = { - "model": "gpt-4-vision-preview", + "model": "gpt-4o", "images": [ "https://example.com/image1.jpg", "https://example.com/image2.jpg", @@ -199,6 +234,60 @@ def test_call(self, MockSyncOpenAI, mock_init_sync_client): self.assertEqual(output.usage.prompt_tokens, 20) self.assertEqual(output.usage.total_tokens, 30) + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_acall_llm_with_vision(self, MockAsyncOpenAI): + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + # Mock the vision model response + mock_async_client.chat.completions.create = AsyncMock( + return_value=self.mock_vision_response + ) + + # Call the _acall method with vision model + result = await self.client.acall( + api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM + ) + + # Assertions + MockAsyncOpenAI.assert_called_once() + mock_async_client.chat.completions.create.assert_awaited_once_with( + **self.vision_api_kwargs + ) + self.assertEqual(result, self.mock_vision_response) + + @patch( + "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" + ) + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): + mock_sync_client = Mock() + MockSyncOpenAI.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the vision model response + mock_sync_client.chat.completions.create = Mock(return_value=self.mock_vision_response) + + # Set the sync client + self.client.sync_client = 
mock_sync_client + + # Call the call method with vision model + result = self.client.call(api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.chat.completions.create.assert_called_once_with( + **self.vision_api_kwargs + ) + self.assertEqual(result, self.mock_vision_response) + + # Test parse_chat_completion for vision model + output = self.client.parse_chat_completion(completion=self.mock_vision_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.raw_response, "The image shows a beautiful sunset over mountains.") + self.assertEqual(output.usage.completion_tokens, 15) + self.assertEqual(output.usage.prompt_tokens, 25) + self.assertEqual(output.usage.total_tokens, 40) + if __name__ == "__main__": unittest.main() From 852c212e51d612198349b88f1b9bc0110d99f26c Mon Sep 17 00:00:00 2001 From: fm1320 Date: Wed, 8 Jan 2025 16:31:28 +0000 Subject: [PATCH 6/8] add more tests with mock --- adalflow/tests/test_openai_client.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index c0a41f34..258115ab 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -71,7 +71,10 @@ def setUp(self): {"type": "text", "text": "Describe this image"}, { "type": "image_url", - "image_url": {"url": "https://example.com/image.jpg", "detail": "auto"}, + "image_url": { + "url": "https://example.com/image.jpg", + "detail": "auto", + }, }, ], } @@ -266,13 +269,17 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): mock_init_sync_client.return_value = mock_sync_client # Mock the vision model response - mock_sync_client.chat.completions.create = Mock(return_value=self.mock_vision_response) + mock_sync_client.chat.completions.create = Mock( + return_value=self.mock_vision_response + ) # Set the sync client self.client.sync_client = mock_sync_client # Call the call method with vision model - result = self.client.call(api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM) + result = self.client.call( + api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM + ) # Assertions mock_sync_client.chat.completions.create.assert_called_once_with( @@ -283,7 +290,9 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): # Test parse_chat_completion for vision model output = self.client.parse_chat_completion(completion=self.mock_vision_response) self.assertTrue(isinstance(output, GeneratorOutput)) - self.assertEqual(output.raw_response, "The image shows a beautiful sunset over mountains.") + self.assertEqual( + output.raw_response, "The image shows a beautiful sunset over mountains." 
+ ) self.assertEqual(output.usage.completion_tokens, 15) self.assertEqual(output.usage.prompt_tokens, 25) self.assertEqual(output.usage.total_tokens, 40) From ff1060af9dcef181ce7d939247fc1e892a58f39c Mon Sep 17 00:00:00 2001 From: fm1320 Date: Thu, 9 Jan 2025 11:28:16 +0000 Subject: [PATCH 7/8] add image gen --- .../components/model_client/openai_client.py | 80 +++++++++++ adalflow/adalflow/core/types.py | 1 + adalflow/tests/test_openai_client.py | 126 +++++++++++++++++- 3 files changed, 204 insertions(+), 3 deletions(-) diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py index aca8a566..a81f8287 100644 --- a/adalflow/adalflow/components/model_client/openai_client.py +++ b/adalflow/adalflow/components/model_client/openai_client.py @@ -36,6 +36,7 @@ from openai.types import ( Completion, CreateEmbeddingResponse, + Image, ) from openai.types.chat import ChatCompletionChunk, ChatCompletion @@ -114,6 +115,14 @@ class OpenAIClient(ModelClient): For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them. The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini). + For image generation, use model_type=ModelType.IMAGE_GENERATION and provide: + - model: "dall-e-3" or "dall-e-2" + - prompt: Text description of the image to generate + - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2 + - quality: "standard" or "hd" (DALL-E 3 only) + - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2) + - response_format: "url" or "b64_json" + Args: api_key (Optional[str], optional): OpenAI API key. Defaults to None. chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None. 
@@ -123,6 +132,7 @@ class OpenAIClient(ModelClient): - Embeddings models: https://platform.openai.com/docs/guides/embeddings - Chat models: https://platform.openai.com/docs/guides/text-generation - Vision models: https://platform.openai.com/docs/guides/vision + - Image models: https://platform.openai.com/docs/guides/images - OpenAI docs: https://platform.openai.com/docs/introduction """ @@ -292,10 +302,54 @@ def convert_inputs_to_api_kwargs( else: messages.append({"role": "system", "content": input}) final_model_kwargs["messages"] = messages + elif model_type == ModelType.IMAGE_GENERATION: + # For image generation, input is the prompt + final_model_kwargs["prompt"] = input + # Set defaults for DALL-E 3 if not specified + if "model" not in final_model_kwargs: + final_model_kwargs["model"] = "dall-e-3" + if "size" not in final_model_kwargs: + final_model_kwargs["size"] = "1024x1024" + if "quality" not in final_model_kwargs: + final_model_kwargs["quality"] = "standard" + if "n" not in final_model_kwargs: + final_model_kwargs["n"] = 1 + if "response_format" not in final_model_kwargs: + final_model_kwargs["response_format"] = "url" + + # Handle image edits and variations + if "image" in final_model_kwargs: + if isinstance(final_model_kwargs["image"], str): + # If it's a file path, encode it + if os.path.isfile(final_model_kwargs["image"]): + final_model_kwargs["image"] = self._encode_image(final_model_kwargs["image"]) + if "mask" in final_model_kwargs and isinstance(final_model_kwargs["mask"], str): + if os.path.isfile(final_model_kwargs["mask"]): + final_model_kwargs["mask"] = self._encode_image(final_model_kwargs["mask"]) else: raise ValueError(f"model_type {model_type} is not supported") return final_model_kwargs + def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput: + """Parse the image generation response into a GeneratorOutput.""" + try: + # Extract URLs or base64 data from the response + data = [img.url or img.b64_json for img in response] + # For single image responses, unwrap from list + if len(data) == 1: + data = data[0] + return GeneratorOutput( + data=data, + raw_response=str(response), + ) + except Exception as e: + log.error(f"Error parsing image generation response: {e}") + return GeneratorOutput( + data=None, + error=str(e), + raw_response=str(response) + ) + @backoff.on_exception( backoff.expo, ( @@ -320,6 +374,19 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE self.chat_completion_parser = handle_streaming_response return self.sync_client.chat.completions.create(**api_kwargs) return self.sync_client.chat.completions.create(**api_kwargs) + elif model_type == ModelType.IMAGE_GENERATION: + # Determine which image API to call based on the presence of image/mask + if "image" in api_kwargs: + if "mask" in api_kwargs: + # Image edit + response = self.sync_client.images.edit(**api_kwargs) + else: + # Image variation + response = self.sync_client.images.create_variation(**api_kwargs) + else: + # Image generation + response = self.sync_client.images.generate(**api_kwargs) + return response.data else: raise ValueError(f"model_type {model_type} is not supported") @@ -346,6 +413,19 @@ async def acall( return await self.async_client.embeddings.create(**api_kwargs) elif model_type == ModelType.LLM: return await self.async_client.chat.completions.create(**api_kwargs) + elif model_type == ModelType.IMAGE_GENERATION: + # Determine which image API to call based on the presence of image/mask + if "image" in api_kwargs: + if 
"mask" in api_kwargs: + # Image edit + response = await self.async_client.images.edit(**api_kwargs) + else: + # Image variation + response = await self.async_client.images.create_variation(**api_kwargs) + else: + # Image generation + response = await self.async_client.images.generate(**api_kwargs) + return response.data else: raise ValueError(f"model_type {model_type} is not supported") diff --git a/adalflow/adalflow/core/types.py b/adalflow/adalflow/core/types.py index 18724510..251635ca 100644 --- a/adalflow/adalflow/core/types.py +++ b/adalflow/adalflow/core/types.py @@ -58,6 +58,7 @@ class ModelType(Enum): EMBEDDER = auto() LLM = auto() RERANKER = auto() # ranking model + IMAGE_GENERATION = auto() # image generation models like DALL-E UNDEFINED = auto() diff --git a/adalflow/tests/test_openai_client.py b/adalflow/tests/test_openai_client.py index 258115ab..9167c820 100644 --- a/adalflow/tests/test_openai_client.py +++ b/adalflow/tests/test_openai_client.py @@ -3,7 +3,7 @@ import os import base64 -from openai.types import CompletionUsage +from openai.types import CompletionUsage, Image from openai.types.chat import ChatCompletion from adalflow.core.types import ModelType, GeneratorOutput @@ -23,7 +23,7 @@ def setUp(self): "id": "cmpl-3Q8Z5J9Z1Z5z5", "created": 1635820005, "object": "chat.completion", - "model": "gpt-3.5-turbo", + "model": "gpt-4o", "choices": [ { "message": { @@ -59,9 +59,17 @@ def setUp(self): ), } self.mock_vision_response = ChatCompletion(**self.mock_vision_response) + self.mock_image_response = [ + Image( + url="https://example.com/generated_image.jpg", + b64_json=None, + revised_prompt="A white siamese cat sitting elegantly", + model="dall-e-3", + ) + ] self.api_kwargs = { "messages": [{"role": "user", "content": "Hello"}], - "model": "gpt-3.5-turbo", + "model": "gpt-4o", } self.vision_api_kwargs = { "messages": [ @@ -81,6 +89,13 @@ def setUp(self): ], "model": "gpt-4o", } + self.image_generation_kwargs = { + "model": "dall-e-3", + "prompt": "a white siamese cat", + "size": "1024x1024", + "quality": "standard", + "n": 1, + } def test_encode_image(self): # Create a temporary test image file @@ -297,6 +312,111 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client): self.assertEqual(output.usage.prompt_tokens, 25) self.assertEqual(output.usage.total_tokens, 40) + def test_convert_inputs_to_api_kwargs_for_image_generation(self): + # Test basic image generation + result = self.client.convert_inputs_to_api_kwargs( + input="a white siamese cat", + model_kwargs={"model": "dall-e-3"}, + model_type=ModelType.IMAGE_GENERATION, + ) + self.assertEqual(result["prompt"], "a white siamese cat") + self.assertEqual(result["model"], "dall-e-3") + self.assertEqual(result["size"], "1024x1024") # default + self.assertEqual(result["quality"], "standard") # default + self.assertEqual(result["n"], 1) # default + + # Test image edit + test_image = "test_image.jpg" + test_mask = "test_mask.jpg" + try: + # Create test files + with open(test_image, "wb") as f: + f.write(b"fake image content") + with open(test_mask, "wb") as f: + f.write(b"fake mask content") + + result = self.client.convert_inputs_to_api_kwargs( + input="a white siamese cat", + model_kwargs={ + "model": "dall-e-2", + "image": test_image, + "mask": test_mask, + }, + model_type=ModelType.IMAGE_GENERATION, + ) + self.assertEqual(result["prompt"], "a white siamese cat") + self.assertEqual(result["model"], "dall-e-2") + self.assertTrue(isinstance(result["image"], str)) # base64 encoded + 
self.assertTrue(isinstance(result["mask"], str)) # base64 encoded + finally: + # Cleanup + if os.path.exists(test_image): + os.remove(test_image) + if os.path.exists(test_mask): + os.remove(test_mask) + + @patch("adalflow.components.model_client.openai_client.AsyncOpenAI") + async def test_acall_image_generation(self, MockAsyncOpenAI): + mock_async_client = AsyncMock() + MockAsyncOpenAI.return_value = mock_async_client + + # Mock the image generation response + mock_async_client.images.generate = AsyncMock( + return_value=type('Response', (), {'data': self.mock_image_response})() + ) + + # Call the acall method with image generation + result = await self.client.acall( + api_kwargs=self.image_generation_kwargs, + model_type=ModelType.IMAGE_GENERATION, + ) + + # Assertions + MockAsyncOpenAI.assert_called_once() + mock_async_client.images.generate.assert_awaited_once_with( + **self.image_generation_kwargs + ) + self.assertEqual(result, self.mock_image_response) + + # Test parse_image_generation_response + output = self.client.parse_image_generation_response(result) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.data, "https://example.com/generated_image.jpg") + + @patch( + "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client" + ) + @patch("adalflow.components.model_client.openai_client.OpenAI") + def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client): + mock_sync_client = Mock() + MockSyncOpenAI.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the image generation response + mock_sync_client.images.generate = Mock( + return_value=type('Response', (), {'data': self.mock_image_response})() + ) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method with image generation + result = self.client.call( + api_kwargs=self.image_generation_kwargs, + model_type=ModelType.IMAGE_GENERATION, + ) + + # Assertions + mock_sync_client.images.generate.assert_called_once_with( + **self.image_generation_kwargs + ) + self.assertEqual(result, self.mock_image_response) + + # Test parse_image_generation_response + output = self.client.parse_image_generation_response(result) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.data, "https://example.com/generated_image.jpg") + if __name__ == "__main__": unittest.main() From 5144fc4f265907a166f44b7710ea42a155db453e Mon Sep 17 00:00:00 2001 From: fm1320 Date: Fri, 10 Jan 2025 01:37:16 +0000 Subject: [PATCH 8/8] Update .rst file and colab --- docs/source/tutorials/model_client.rst | 66 +++++ .../tutorials/adalflow_modelclient.ipynb | 266 ++++++++++++++++++ 2 files changed, 332 insertions(+) diff --git a/docs/source/tutorials/model_client.rst b/docs/source/tutorials/model_client.rst index 47e83298..e8226398 100644 --- a/docs/source/tutorials/model_client.rst +++ b/docs/source/tutorials/model_client.rst @@ -1552,6 +1552,71 @@ The OpenAI client also supports multimodal inputs. Here's a quick example: The client handles both local files and URLs, with support for PNG, JPEG, WEBP, and non-animated GIF formats. +OPENAI Image Generation +------------------------------------------------- + +The OpenAI client supports image generation, editing, and variation creation through DALL-E models. First, you need to define a Generator class with the correct model type: + +.. 
 
 .. admonition:: API reference
    :class: highlight
@@ -1563,3 +1628,4 @@
    - :class:`components.model_client.anthropic_client.AnthropicAPIClient`
    - :class:`components.model_client.google_client.GoogleGenAIClient`
    - :class:`components.model_client.cohere_client.CohereAPIClient`
+
diff --git a/notebooks/tutorials/adalflow_modelclient.ipynb b/notebooks/tutorials/adalflow_modelclient.ipynb
index 1a2b3aba..f1c89bee 100644
--- a/notebooks/tutorials/adalflow_modelclient.ipynb
+++ b/notebooks/tutorials/adalflow_modelclient.ipynb
@@ -2043,6 +2043,272 @@
    "build_custom_model_client()"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AdalFlow multimodal model client"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from adalflow.core import Generator\n",
+    "from adalflow.components.model_client.openai_client import OpenAIClient\n",
+    "from adalflow.core.types import ModelType\n",
+    "\n",
+    "\n",
+    "# DALL-E calls require ModelType.IMAGE_GENERATION; a plain Generator defaults to LLM\n",
+    "class ImageGenerator(Generator):\n",
+    "    \"\"\"Generator subclass for image generation.\"\"\"\n",
+    "    model_type = ModelType.IMAGE_GENERATION\n",
+    "\n",
+    "\n",
+    "def analyze_single_image():\n",
+    "    \"\"\"Example of analyzing a single image with a vision-capable model\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    gen = Generator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"gpt-4o-mini\",\n",
+    "            \"images\": \"https://raw.githubusercontent.com/openai/openai-cookbook/main/examples/images/happy_cat.jpg\",\n",
+    "            \"max_tokens\": 300\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    response = gen({\"input_str\": \"What do you see in this image? 
Be detailed but concise.\"})\n",
+    "    print(\"\\n=== Single Image Analysis ===\")\n",
+    "    print(f\"Description: {response.raw_response}\")\n",
+    "\n",
+    "def analyze_multiple_images():\n",
+    "    \"\"\"Example of analyzing multiple images in one prompt\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    # List of images to analyze together\n",
+    "    images = [\n",
+    "        \"https://raw.githubusercontent.com/openai/openai-cookbook/main/examples/images/happy_cat.jpg\",\n",
+    "        \"https://raw.githubusercontent.com/openai/openai-cookbook/main/examples/images/sad_cat.jpg\"\n",
+    "    ]\n",
+    "    \n",
+    "    gen = Generator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"gpt-4o-mini\",\n",
+    "            \"images\": images,\n",
+    "            \"max_tokens\": 300\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    response = gen({\"input_str\": \"Compare and contrast these two images. What are the main differences?\"})\n",
+    "    print(\"\\n=== Multiple Images Analysis ===\")\n",
+    "    print(f\"Comparison: {response.raw_response}\")\n",
+    "\n",
+    "def generate_art_with_dalle():\n",
+    "    \"\"\"Example of generating art using DALL-E 3\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    # ImageGenerator (defined at the top of this cell) sets the image-generation model type\n",
+    "    gen = ImageGenerator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"dall-e-3\",\n",
+    "            \"size\": \"1024x1024\",\n",
+    "            \"quality\": \"standard\",\n",
+    "            \"n\": 1\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    response = gen({\n",
+    "        \"input_str\": \"A serene Japanese garden with a small bridge over a koi pond, cherry blossoms falling gently in the breeze\"\n",
+    "    })\n",
+    "    print(\"\\n=== Art Generation with DALL-E 3 ===\")\n",
+    "    print(f\"Generated Image URL: {response.data}\")\n",
+    "\n",
+    "def create_image_variations(image_path=\"path/to/your/image.jpg\"):\n",
+    "    \"\"\"Example of creating variations of an existing image\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    gen = ImageGenerator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"dall-e-2\",\n",
+    "            \"image\": image_path,\n",
+    "            \"n\": 2,  # Generate 2 variations\n",
+    "            \"size\": \"1024x1024\"\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    response = gen({\"input_str\": \"\"})\n",
+    "    print(\"\\n=== Image Variations ===\")\n",
+    "    print(f\"Variation URLs: {response.data}\")\n",
+    "\n",
+    "def edit_image_with_mask(image_path=\"path/to/image.jpg\", mask_path=\"path/to/mask.jpg\"):\n",
+    "    \"\"\"Example of editing specific parts of an image using a mask\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    gen = ImageGenerator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"dall-e-2\",\n",
+    "            \"image\": image_path,\n",
+    "            \"mask\": mask_path,\n",
+    "            \"n\": 1,\n",
+    "            \"size\": \"1024x1024\"\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    response = gen({\n",
+    "        \"input_str\": \"Replace the masked area with a beautiful sunset\"\n",
+    "    })\n",
+    "    print(\"\\n=== Image Editing ===\")\n",
+    "    print(f\"Edited Image URL: {response.data}\")\n",
+    "\n",
+    "def mixed_image_text_conversation():\n",
+    "    \"\"\"Example of having a conversation that includes both images and text\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    gen = Generator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"gpt-4o-mini\",\n",
+    "            \"images\": [\n",
+    "                \"https://raw.githubusercontent.com/openai/openai-cookbook/main/examples/images/happy_cat.jpg\",\n",
+    "                \"https://path/to/local/image.jpg\"  # Replace with your local image path\n",
+    "            ],\n",
+    "            \"max_tokens\": 300\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    conversation = \"\"\"You are a helpful 
assistant skilled in analyzing images and providing detailed descriptions.\n",
+    "    I'm showing you two images. Please analyze them and tell me what emotions they convey.\"\"\"\n",
+    "    \n",
+    "    response = gen({\"input_str\": conversation})\n",
+    "    print(\"\\n=== Mixed Image-Text Conversation ===\")\n",
+    "    print(f\"Assistant's Analysis: {response.raw_response}\")\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if __name__ == \"__main__\":\n",
+    "    print(\"OpenAI Image Processing Examples\\n\")\n",
+    "    \n",
+    "    # Basic image analysis\n",
+    "    analyze_single_image()\n",
+    "    \n",
+    "    # Multiple image analysis\n",
+    "    analyze_multiple_images()\n",
+    "    \n",
+    "    # Image generation\n",
+    "    generate_art_with_dalle()\n",
+    "    \n",
+    "    # Optional examples (require local image files):\n",
+    "    # create_image_variations()\n",
+    "    # edit_image_with_mask(\"path/to/image.jpg\", \"path/to/mask.jpg\")\n",
+    "    # mixed_image_text_conversation()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Image generation with DALL-E and image understanding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from adalflow.core import Generator\n",
+    "from adalflow.components.model_client.openai_client import OpenAIClient\n",
+    "from adalflow.core.types import ModelType"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class ImageGenerator(Generator):\n",
+    "    \"\"\"Generator subclass for image generation.\"\"\"\n",
+    "    model_type = ModelType.IMAGE_GENERATION\n",
+    "\n",
+    "def test_vision_and_generation():\n",
+    "    \"\"\"Test both vision analysis and image generation\"\"\"\n",
+    "    client = OpenAIClient()\n",
+    "    \n",
+    "    # 1. Test Vision Analysis\n",
+    "    vision_gen = Generator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"gpt-4o-mini\",\n",
+    "            \"images\": \"https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png\",\n",
+    "            \"max_tokens\": 300\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    vision_response = vision_gen({\"input_str\": \"What do you see in this image? Be detailed but concise.\"})\n",
+    "    print(\"\\n=== Vision Analysis ===\")\n",
+    "    print(f\"Description: {vision_response.raw_response}\")\n",
+    "\n",
+    "    # 2. Test DALL-E Image Generation\n",
+    "    dalle_gen = ImageGenerator(\n",
+    "        model_client=client,\n",
+    "        model_kwargs={\n",
+    "            \"model\": \"dall-e-3\",\n",
+    "            \"size\": \"1024x1024\",\n",
+    "            \"quality\": \"standard\",\n",
+    "            \"n\": 1\n",
+    "        }\n",
+    "    )\n",
+    "    \n",
+    "    # For image generation, input_str becomes the prompt\n",
+    "    response = dalle_gen({\"input_str\": \"A happy siamese cat playing with a red ball of yarn\"})\n",
+    "    print(\"\\n=== DALL-E Generation ===\")\n",
+    "    print(f\"Generated Image URL: {response.data}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Invalid image URL - Generator output still works!"
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_invalid_image_url():\n", + " \"\"\"Test Generator output with invalid image URL\"\"\"\n", + " client = OpenAIClient()\n", + " gen = Generator(\n", + " model_client=client,\n", + " model_kwargs={\n", + " \"model\": \"gpt-4o-mini\",\n", + " \"images\": \"https://invalid.url/nonexistent.jpg\",\n", + " \"max_tokens\": 300\n", + " }\n", + " )\n", + " \n", + " print(\"\\n=== Testing Invalid Image URL ===\")\n", + " response = gen({\"input_str\": \"What do you see in this image?\"})\n", + " print(f\"Response with invalid image URL: {response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if __name__ == \"__main__\":\n", + " print(\"Starting OpenAI Vision and DALL-E test...\\n\")\n", + " test_invalid_image_url()\n", + " test_vision_and_generation() " + ] + }, { "cell_type": "markdown", "metadata": {