Skip to content

Commit

Permalink
added deit onnx config (huggingface#16887)
Browse files Browse the repository at this point in the history
* added deit onnx config
  • Loading branch information
0xrushi authored and elusenji committed Jun 12, 2022
1 parent 5251b69 commit 37c6759
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Ready-made configurations include the following architectures:
- ConvBERT
- Data2VecText
- Data2VecVision
- DeiT
- DistilBERT
- ELECTRA
- FlauBERT
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/deit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


_import_structure = {
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig"],
"configuration_deit": ["DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DeiTConfig", "DeiTOnnxConfig"],
}

if is_vision_available():
Expand All @@ -39,7 +39,7 @@


if TYPE_CHECKING:
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig
from .configuration_deit import DEIT_PRETRAINED_CONFIG_ARCHIVE_MAP, DeiTConfig, DeiTOnnxConfig

if is_vision_available():
from .feature_extraction_deit import DeiTFeatureExtractor
Expand Down
22 changes: 22 additions & 0 deletions src/transformers/models/deit/configuration_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
# limitations under the License.
""" DeiT model configuration"""

from collections import OrderedDict
from typing import Mapping

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -120,3 +126,19 @@ def __init__(
self.num_channels = num_channels
self.qkv_bias = qkv_bias
self.encoder_stride = encoder_stride


class DeiTOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
]
)

@property
def atol_for_validation(self) -> float:
return 1e-4
14 changes: 12 additions & 2 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..models.camembert import CamembertOnnxConfig
from ..models.convbert import ConvBertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.deit import DeiTOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
Expand All @@ -38,6 +39,7 @@
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
Expand Down Expand Up @@ -103,6 +105,7 @@ class FeaturesManager:
"multiple-choice": AutoModelForMultipleChoice,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
"masked-im": AutoModelForMaskedImageModeling,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
Expand Down Expand Up @@ -294,8 +297,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
),
"vit": supported_features_mapping("default", "image-classification", onnx_config_cls=ViTOnnxConfig),
"beit": supported_features_mapping("default", "image-classification", onnx_config_cls=BeitOnnxConfig),
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=ViTOnnxConfig
),
"beit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=BeitOnnxConfig
),
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls=DeiTOnnxConfig
),
"blenderbot": supported_features_mapping(
"default",
"default-with-past",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_values_override(self):
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("deit", "facebook/deit-small-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
}
Expand Down

0 comments on commit 37c6759

Please sign in to comment.