From e3d99bae998c0bb09f7d3eaff7415a358b787d95 Mon Sep 17 00:00:00 2001 From: Erin <14718778+hchings@users.noreply.github.com> Date: Mon, 5 Dec 2022 08:39:09 -0800 Subject: [PATCH] Add RemBERT ONNX config (#20520) * rembert onnx config * formatting Co-authored-by: Ho --- docs/source/en/serialization.mdx | 1 + src/transformers/models/rembert/__init__.py | 6 +++-- .../models/rembert/configuration_rembert.py | 23 +++++++++++++++++++ src/transformers/onnx/features.py | 10 ++++++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 9167dc8d4e2358..6fde558464bd1e 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -93,6 +93,7 @@ Ready-made configurations include the following architectures: - OWL-ViT - Perceiver - PLBart +- RemBERT - ResNet - RoBERTa - RoFormer diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py index 10af6c4d27f3be..72d43887039430 100644 --- a/src/transformers/models/rembert/__init__.py +++ b/src/transformers/models/rembert/__init__.py @@ -28,7 +28,9 @@ ) -_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]} +_import_structure = { + "configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"] +} try: if not is_sentencepiece_available(): @@ -88,7 +90,7 @@ if TYPE_CHECKING: - from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig + from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig try: if not is_sentencepiece_available(): diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index 935b56da164db9..6a2d8a13c5e84c 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ RemBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -135,3 +138,23 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.tie_word_embeddings = False + + +class RemBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 7ae0b509b96479..0fb750ebc551f8 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -447,6 +447,16 @@ class FeaturesManager: "sequence-classification", onnx_config_cls="models.perceiver.PerceiverOnnxConfig", ), + "rembert": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls="models.rembert.RemBertOnnxConfig", + ), "resnet": supported_features_mapping( "default", "image-classification", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index fa58c5ce754465..fbc959284db162 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -210,6 +210,7 @@ def test_values_override(self): ("owlvit", "google/owlvit-base-patch32"), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), + ("rembert", "google/rembert"), ("resnet", "microsoft/resnet-50"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"), ("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),