From 37c6759ecebe4ba28a6ec3259d270e487afb22d4 Mon Sep 17 00:00:00 2001 From: Rushi Chaudhari Date: Mon, 25 Apr 2022 14:50:45 -0400 Subject: [PATCH] added deit onnx config (#16887) * added deit onnx config --- docs/source/en/serialization.mdx | 1 + src/transformers/models/deit/__init__.py | 4 ++-- .../models/deit/configuration_deit.py | 22 +++++++++++++++++++ src/transformers/onnx/features.py | 14 ++++++++++-- tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 4255b8f6e1a652..87d327322812b2 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -56,6 +56,7 @@ Ready-made configurations include the following architectures: - ConvBERT - Data2VecText - Data2VecVision +- DeiT - DistilBERT - ELECTRA - FlauBERT diff --git a/src/transformers/models/deit/__init__.py b/src/transformers/models/deit/__init__.py index bcded61602f94b..913e53f9ae8632 100644 --- a/src/transformers/models/deit/__init__.py +++ b/src/transformers/models/deit/__init__.py @@ -21,7 +21,7 @@ _import_structure = { - "configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"], + "configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"], } if is_vision_available(): @@ -39,7 +39,7 @@ if TYPE_CHECKING: - from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig + from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig if is_vision_available(): from .feature_extraction_deit import DeiTFeatureExtractor diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index 616e1288e0bea5..022df1727f5830 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -14,7 +14,13 @@ # limitations under the License. """ DeiT model configuration""" +from collections import OrderedDict +from typing import Mapping + +from packaging import version + from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -120,3 +126,19 @@ def __init__( self.num_channels = num_channels self.qkv_bias = qkv_bias self.encoder_stride = encoder_stride + + +class DeiTOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 31bf0b45d77cc7..4133d6918c9b42 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -12,6 +12,7 @@ from ..models.camembert import CamembertOnnxConfig from ..models.convbert import ConvBertOnnxConfig from ..models.data2vec import Data2VecTextOnnxConfig +from ..models.deit import DeiTOnnxConfig from ..models.distilbert import DistilBertOnnxConfig from ..models.electra import ElectraOnnxConfig from ..models.flaubert import FlaubertOnnxConfig @@ -38,6 +39,7 @@ AutoModel, AutoModelForCausalLM, AutoModelForImageClassification, + AutoModelForMaskedImageModeling, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering, @@ -103,6 +105,7 @@ class FeaturesManager: "multiple-choice": AutoModelForMultipleChoice, "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, + "masked-im": AutoModelForMaskedImageModeling, } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { @@ -294,8 +297,15 @@ class FeaturesManager: "question-answering", onnx_config_cls=ElectraOnnxConfig, ), - "vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig), - "beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig), + "vit": supported_features_mapping( + "default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig + ), + "beit": supported_features_mapping( + "default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig + ), + "deit": supported_features_mapping( + "default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig + ), "blenderbot": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index e8b403e54f19cb..40a964550976cd 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -182,6 +182,7 @@ def test_values_override(self): ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"), + ("deit", "facebook/deit-small-patch16-224"), ("beit", "microsoft/beit-base-patch16-224"), ("data2vec-text", "facebook/data2vec-text-base"), }