From a042ec4ec304512fbe0bf72a62fec0a361f54569 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Tue, 25 Jun 2024 14:37:48 +0200
Subject: [PATCH] clip vision model onnx export

---
Reviewer notes (placed after the separator, so not part of the commit message):
* model_configs.py: adds CLIPVisionModelOnnxConfig. It declares one input,
  `pixel_values`, with all four axes dynamic (batch_size, num_channels,
  height, width). Its outputs start from the parent config's outputs and then
  set `last_hidden_state` and `pooler_output`, each with only axis 0
  (batch_size) marked dynamic.
* tasks.py: registers the `clip-vision-model` model type for the
  `feature-extraction` task, pointing ONNX export at the new config class.
* exporters_utils.py: maps `clip-vision-model` to a tiny checkpoint for the
  exporter tests.
* NOTE(review): `outputs` mutates the mapping returned by `super().outputs`;
  this presumes the parent returns a fresh copy on each access -- confirm.

 optimum/exporters/onnx/model_configs.py | 16 ++++++++++++++++
 optimum/exporters/tasks.py              |  4 ++++
 tests/exporters/exporters_utils.py      |  1 +
 3 files changed, 21 insertions(+)

diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index e23716d4b74..dadc44c7ff9 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -886,6 +886,22 @@ class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
     VISION_CONFIG = "vision_config"
 
 
+class CLIPVisionModelOnnxConfig(VisionOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        common_outputs = super().outputs
+        common_outputs["last_hidden_state"] = {0: "batch_size"}
+        common_outputs["pooler_output"] = {0: "batch_size"}
+
+        return common_outputs
+
+
 class CLIPOnnxConfig(TextAndVisionOnnxConfig):
     NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
     DEFAULT_ONNX_OPSET = 14
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 608b3df0d7c..12feda333e8 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -446,6 +446,10 @@ class TasksManager:
             "zero-shot-image-classification",
             onnx="CLIPOnnxConfig",
         ),
+        "clip-vision-model": supported_tasks_mapping(
+            "feature-extraction",
+            onnx="CLIPVisionModelOnnxConfig",
+        ),
         "codegen": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 0c52754ff60..9c5d2c8991f 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -56,6 +56,7 @@
     "bloom": "hf-internal-testing/tiny-random-BloomModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
     "clip": "hf-internal-testing/tiny-random-CLIPModel",
+    "clip-vision-model": "fxmarty/clip-vision-model-tiny",
     "convbert": "hf-internal-testing/tiny-random-ConvBertModel",
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",