Adds DonutSwin to models exportable with ONNX #19401

Closed
wants to merge 1 commit into from
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
@@ -67,6 +67,7 @@ Ready-made configurations include the following architectures:
- DeiT
- DETR
- DistilBERT
- DonutSwin
- ELECTRA
- ERNIE
- FlauBERT
8 changes: 6 additions & 2 deletions src/transformers/models/donut/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
    "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig"],
    "configuration_donut_swin": ["DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP", "DonutSwinConfig", "DonutSwinOnnxConfig"],
    "processing_donut": ["DonutProcessor"],
}

@@ -47,7 +47,11 @@


if TYPE_CHECKING:
    from .configuration_donut_swin import DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP, DonutSwinConfig
    from .configuration_donut_swin import (
        DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP,
        DonutSwinConfig,
        DonutSwinOnnxConfig,
    )
    from .processing_donut import DonutProcessor

    try:
66 changes: 66 additions & 0 deletions src/transformers/models/donut/configuration_donut_swin.py
@@ -13,11 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Donut Swin Transformer model configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import logging


if TYPE_CHECKING:
    from ...feature_extraction_utils import FeatureExtractionMixin
    from ...utils import TensorType

logger = logging.get_logger(__name__)

DONUT_SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -138,3 +146,61 @@ def __init__(
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))


class DonutSwinOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        return 1e-4

    @property
    def default_onnx_opset(self) -> int:
        return 16

    def generate_dummy_inputs(
        self,
        processor: "FeatureExtractionMixin",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
    ) -> Mapping[str, Any]:
        """
        Generate inputs to provide to the ONNX exporter for the specific framework.

        Args:
            processor ([`FeatureExtractionMixin`]):
                The feature extractor associated with this model configuration.
            batch_size (`int`, *optional*, defaults to -1):
                The batch size to export the model for (-1 means dynamic axis).
            seq_length (`int`, *optional*, defaults to -1):
                The sequence length to export the model for (-1 means dynamic axis).
            is_pair (`bool`, *optional*, defaults to `False`):
                Indicate if the input is a pair (sentence 1, sentence 2).
            framework (`TensorType`, *optional*, defaults to `None`):
                The framework (PyTorch or TensorFlow) that the processor will generate tensors for.
            num_channels (`int`, *optional*, defaults to 3):
                The number of channels of the generated images.
            image_width (`int`, *optional*, defaults to 40):
                The width of the generated images.
            image_height (`int`, *optional*, defaults to 40):
                The height of the generated images.

        Returns:
            Mapping[str, Any]: holding the kwargs to provide to the model's forward function
        """
        batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
        dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
        return dict(processor(images=dummy_input, return_tensors=framework))
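The new `DonutSwinOnnxConfig` describes only the Swin encoder of Donut, with a single dynamic `pixel_values` input, so it can be exercised directly with the library's Python export helper. The sketch below is a hypothetical usage example rather than part of this PR: the checkpoint name is illustrative and the encoder is built from a default, randomly initialised `DonutSwinConfig` instead of pretrained weights.

```python
from pathlib import Path

from transformers import AutoFeatureExtractor, DonutSwinConfig, DonutSwinModel
from transformers.models.donut import DonutSwinOnnxConfig
from transformers.onnx import export

# Illustrative checkpoint; only its feature extractor (image preprocessing) is used here.
feature_extractor = AutoFeatureExtractor.from_pretrained("naver-clova-ix/donut-base")

config = DonutSwinConfig()  # default configuration
model = DonutSwinModel(config)  # randomly initialised encoder, enough to exercise the export path
onnx_config = DonutSwinOnnxConfig(config)

# export(preprocessor, model, config, opset, output) traces the model with the dummy
# pixel_values produced by DonutSwinOnnxConfig.generate_dummy_inputs and writes the graph to disk.
onnx_inputs, onnx_outputs = export(
    feature_extractor,
    model,
    onnx_config,
    onnx_config.default_onnx_opset,  # 16, as defined above
    Path("donut_swin.onnx"),
)
```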
10 changes: 10 additions & 0 deletions src/transformers/onnx/features.py
@@ -19,6 +19,7 @@
from transformers.models.auto import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForDocumentQuestionAnswering,
    AutoModelForImageClassification,
    AutoModelForImageSegmentation,
    AutoModelForMaskedImageModeling,
@@ -36,6 +37,7 @@
from transformers.models.auto import (
    TFAutoModel,
    TFAutoModelForCausalLM,
    TFAutoModelForDocumentQuestionAnswering,
    TFAutoModelForMaskedLM,
    TFAutoModelForMultipleChoice,
    TFAutoModelForQuestionAnswering,
@@ -94,6 +96,7 @@ class FeaturesManager:
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
"object-detection": AutoModelForObjectDetection,
"document-question-answering": AutoModelForDocumentQuestionAnswering,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
"image-segmentation": AutoModelForImageSegmentation,
@@ -110,6 +113,7 @@ class FeaturesManager:
"sequence-classification": TFAutoModelForSequenceClassification,
"token-classification": TFAutoModelForTokenClassification,
"multiple-choice": TFAutoModelForMultipleChoice,
"document-question-answering": TFAutoModelForDocumentQuestionAnswering,
"question-answering": TFAutoModelForQuestionAnswering,
"semantic-segmentation": TFAutoModelForSemanticSegmentation,
}
@@ -282,6 +286,12 @@ class FeaturesManager:
"question-answering",
onnx_config_cls="models.distilbert.DistilBertOnnxConfig",
),
"donut-swin": supported_features_mapping(
"default",
"document-question-answering",
"image-classification",
onnx_config_cls="models.donut.DonutSwinOnnxConfig",
),
"electra": supported_features_mapping(
"default",
"masked-lm",
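For reference, a quick way to see what the new `FeaturesManager` entry exposes once this mapping is in place (a hypothetical snippet, not part of the PR):

```python
from transformers.onnx.features import FeaturesManager

# Features registered for the new "donut-swin" model type by the mapping above.
features = FeaturesManager.get_supported_features_for_model_type("donut-swin")
print(sorted(features))  # expected: ['default', 'document-question-answering', 'image-classification']

# get_config returns the DonutSwinOnnxConfig constructor bound to the requested task.
donut_swin_onnx_config_ctor = FeaturesManager.get_config("donut-swin", "default")
```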
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
@@ -192,6 +192,7 @@ def test_values_override(self):
("convnext", "facebook/convnext-tiny-224"),
("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"),
("donut-swin", "naver-clova-ix/donut-base"),
("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"),
("roberta", "roberta-base"),
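Beyond the slow export test registered above, the exported graph can be sanity-checked with onnxruntime. This is a hypothetical follow-up, not part of the PR: it assumes the `donut_swin.onnx` file produced by the earlier sketch and reuses the same feature extractor so the input shape matches what was traced at export time.

```python
import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoFeatureExtractor

session = ort.InferenceSession("donut_swin.onnx")

# Reuse the export-time preprocessing so the spatial size matches the traced graph.
feature_extractor = AutoFeatureExtractor.from_pretrained("naver-clova-ix/donut-base")
image = Image.new("RGB", (1280, 960), color="white")  # dummy document-like page
pixel_values = feature_extractor(images=image, return_tensors="np")["pixel_values"].astype(np.float32)

# The config declares a single dynamic "pixel_values" input (batch, num_channels, height, width).
outputs = session.run(None, {"pixel_values": pixel_values})
print([o.shape for o in outputs])  # typically just last_hidden_state for the default task
```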