From 46c4d95784c620d5b4d8ea3781e2972cdac5dfb2 Mon Sep 17 00:00:00 2001 From: Ho Date: Wed, 30 Nov 2022 21:36:34 -0800 Subject: [PATCH 1/2] rembert onnx config --- docs/source/en/serialization.mdx | 1 + src/transformers/models/rembert/__init__.py | 4 ++-- .../models/rembert/configuration_rembert.py | 22 +++++++++++++++++++ src/transformers/onnx/features.py | 10 +++++++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 36 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 9167dc8d4e2358..6fde558464bd1e 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -93,6 +93,7 @@ Ready-made configurations include the following architectures: - OWL-ViT - Perceiver - PLBart +- RemBERT - ResNet - RoBERTa - RoFormer diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py index 10af6c4d27f3be..19c1c143303347 100644 --- a/src/transformers/models/rembert/__init__.py +++ b/src/transformers/models/rembert/__init__.py @@ -28,7 +28,7 @@ ) -_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]} +_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]} try: if not is_sentencepiece_available(): @@ -88,7 +88,7 @@ if TYPE_CHECKING: - from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig + from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig try: if not is_sentencepiece_available(): diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index 732d75c5cc2b3d..4f41da1f6f43d8 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ RemBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -136,3 +139,22 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.tie_word_embeddings = False + + +class RemBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) + @property + def atol_for_validation(self) -> float: + return 1e-4 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 7ae0b509b96479..0fb750ebc551f8 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -447,6 +447,16 @@ class FeaturesManager: "sequence-classification", onnx_config_cls="models.perceiver.PerceiverOnnxConfig", ), + "rembert": supported_features_mapping( + "default", + "masked-lm", + "causal-lm", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls="models.rembert.RemBertOnnxConfig", + ), "resnet": supported_features_mapping( "default", "image-classification", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index fa58c5ce754465..fbc959284db162 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -210,6 +210,7 @@ def test_values_override(self): ("owlvit", "google/owlvit-base-patch32"), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")), ("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)), + ("rembert", "google/rembert"), ("resnet", "microsoft/resnet-50"), ("roberta", "hf-internal-testing/tiny-random-RobertaModel"), ("roformer", "hf-internal-testing/tiny-random-RoFormerModel"), From 4fa263d1e2d4b56fb689f7f7c88e018adf814a46 Mon Sep 17 00:00:00 2001 From: Ho Date: Wed, 30 Nov 2022 23:53:23 -0800 Subject: [PATCH 2/2] formatting --- src/transformers/models/rembert/__init__.py | 4 +++- src/transformers/models/rembert/configuration_rembert.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/rembert/__init__.py b/src/transformers/models/rembert/__init__.py index 19c1c143303347..72d43887039430 100644 --- a/src/transformers/models/rembert/__init__.py +++ b/src/transformers/models/rembert/__init__.py @@ -28,7 +28,9 @@ ) -_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]} +_import_structure = { + "configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"] +} try: if not is_sentencepiece_available(): diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index 4f41da1f6f43d8..22bd7c19d8900f 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -155,6 +155,7 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: ("token_type_ids", dynamic_axis), ] ) + @property def atol_for_validation(self) -> float: return 1e-4