diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx
index 7c89cac4431940..a27e4e53b38390 100644
--- a/docs/source/en/serialization.mdx
+++ b/docs/source/en/serialization.mdx
@@ -67,6 +67,7 @@ Ready-made configurations include the following architectures:
 - DeiT
 - DETR
 - DistilBERT
+- DonutSwin
 - ELECTRA
 - ERNIE
 - FlauBERT
diff --git a/src/transformers/models/donut/__init__.py b/src/transformers/models/donut/__init__.py
index a01f6b11a9a995..2cad3ac041c4c0 100644
--- a/src/transformers/models/donut/__init__.py
+++ b/src/transformers/models/donut/__init__.py
@@ -21,7 +21,7 @@
 _import_structure = {
-    "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig"],
+    "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig", "DonutSwinOnnxConfig"],
     "processing_donut": ["DonutProcessor"],
 }
@@ -47,7 +47,11 @@
 if TYPE_CHECKING:
-    from .configuration_donut_swin import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutSwinConfig
+    from .configuration_donut_swin import (
+        DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        DonutSwinConfig,
+        DonutSwinOnnxConfig,
+    )
     from .processing_donut import DonutProcessor
 
     try:
diff --git a/src/transformers/models/donut/configuration_donut_swin.py b/src/transformers/models/donut/configuration_donut_swin.py
index d3316bdc79f685..a1a05e84f4a21e 100644
--- a/src/transformers/models/donut/configuration_donut_swin.py
+++ b/src/transformers/models/donut/configuration_donut_swin.py
@@ -13,11 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Donut Swin Transformer model configuration"""
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional
 
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...onnx.utils import compute_effective_axis_dimension
 from ...utils import logging
 
 
+if TYPE_CHECKING:
+    from ...feature_extraction_utils import FeatureExtractionMixin
+    from ...utils import TensorType
+
 logger = logging.get_logger(__name__)
 
 DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -138,3 +146,61 @@ def __init__(
         # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+
+
+class DonutSwinOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
+            ]
+        )
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-4
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 16
+
+    def generate_dummy_inputs(
+        self,
+        processor: "FeatureExtractionMixin",
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+    ) -> Mapping[str, Any]:
+        """
+        Generate inputs to provide to the ONNX exporter for the specific framework
+
+        Args:
+            processor ([`ProcessorMixin`]):
+                The processor associated with this model configuration.
+            batch_size (`int`, *optional*, defaults to -1):
+                The batch size to export the model for (-1 means dynamic axis).
+            seq_length (`int`, *optional*, defaults to -1):
+                The sequence length to export the model for (-1 means dynamic axis).
+            is_pair (`bool`, *optional*, defaults to `False`):
+                Indicate if the input is a pair (sentence 1, sentence 2).
+            framework (`TensorType`, *optional*, defaults to `None`):
+                The framework (PyTorch or TensorFlow) that the processor will generate tensors for.
+            num_channels (`int`, *optional*, defaults to 3):
+                The number of channels of the generated images.
+            image_width (`int`, *optional*, defaults to 40):
+                The width of the generated images.
+            image_height (`int`, *optional*, defaults to 40):
+                The height of the generated images.
+
+        Returns:
+            Mapping[str, Any]: holding the kwargs to provide to the model's forward function
+        """
+
+        batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+        dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
+        return dict(processor(images=dummy_input, return_tensors=framework))
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 6a0ec0f7c70794..5e0456d19afd67 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -19,6 +19,7 @@
     from transformers.models.auto import (
         AutoModel,
         AutoModelForCausalLM,
+        AutoModelForDocumentQuestionAnswering,
         AutoModelForImageClassification,
         AutoModelForImageSegmentation,
         AutoModelForMaskedImageModeling,
@@ -36,6 +37,7 @@
     from transformers.models.auto import (
         TFAutoModel,
         TFAutoModelForCausalLM,
+        TFAutoModelForDocumentQuestionAnswering,
         TFAutoModelForMaskedLM,
         TFAutoModelForMultipleChoice,
         TFAutoModelForQuestionAnswering,
@@ -94,6 +96,7 @@ class FeaturesManager:
             "token-classification": AutoModelForTokenClassification,
             "multiple-choice": AutoModelForMultipleChoice,
             "object-detection": AutoModelForObjectDetection,
+            "document-question-answering": AutoModelForDocumentQuestionAnswering,
             "question-answering": AutoModelForQuestionAnswering,
             "image-classification": AutoModelForImageClassification,
             "image-segmentation": AutoModelForImageSegmentation,
@@ -110,6 +113,7 @@ class FeaturesManager:
             "sequence-classification": TFAutoModelForSequenceClassification,
             "token-classification": TFAutoModelForTokenClassification,
             "multiple-choice": TFAutoModelForMultipleChoice,
+            "document-question-answering": TFAutoModelForDocumentQuestionAnswering,
             "question-answering": TFAutoModelForQuestionAnswering,
             "semantic-segmentation": TFAutoModelForSemanticSegmentation,
         }
@@ -282,6 +286,12 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
         ),
+        "donut-swin": supported_features_mapping(
+            "default",
+            "document-question-answering",
+            "image-classification",
+            onnx_config_cls="models.donut.DonutSwinOnnxConfig",
+        ),
         "electra": supported_features_mapping(
             "default",
             "masked-lm",
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 81cd55d3bb5a81..4242740556ec14 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -192,6 +192,7 @@ def test_values_override(self):
     ("convnext", "facebook/convnext-tiny-224"),
     ("detr", "facebook/detr-resnet-50"),
     ("distilbert", "distilbert-base-cased"),
+    ("donut-swin", "naver-clova-ix/donut-base"),
     ("electra", "google/electra-base-generator"),
     ("resnet", "microsoft/resnet-50"),
     ("roberta", "roberta-base"),
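For reference, here is a minimal usage sketch of the new `DonutSwinOnnxConfig` with the generic `transformers.onnx.export` helper; it is not part of the patch. The randomly initialised `DonutSwinModel`, the 224x224 feature-extractor size, and the `donut_swin.onnx` output path are illustrative assumptions only; the checkpoint exercised by the test above (`naver-clova-ix/donut-base`) ships a full vision-encoder-decoder model, so a real export would load the encoder weights separately.

```python
from pathlib import Path

from transformers import DonutFeatureExtractor, DonutSwinConfig, DonutSwinModel
from transformers.models.donut import DonutSwinOnnxConfig
from transformers.onnx import export

# Assumption: a randomly initialised encoder stands in for real weights in this sketch.
config = DonutSwinConfig()
model = DonutSwinModel(config)
model.eval()

# The preprocessor only has to produce `pixel_values`; 224x224 matches the default
# DonutSwinConfig.image_size and is an illustrative choice, not a requirement.
feature_extractor = DonutFeatureExtractor(size=[224, 224])

onnx_config = DonutSwinOnnxConfig(config)
onnx_path = Path("donut_swin.onnx")

# `export` traces the model with dummy inputs built by DonutSwinOnnxConfig.generate_dummy_inputs
# and returns the matched input names and the exported output names.
onnx_inputs, onnx_outputs = export(feature_extractor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
print(onnx_inputs, onnx_outputs)
```

The `FeaturesManager` entries added in `features.py` are what the `python -m transformers.onnx` CLI consults, so registering `"donut-swin"` with the `"default"`, `"document-question-answering"` and `"image-classification"` features there is what makes the model type selectable through that entry point as well.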