Skip to content

Commit

Permalink
Add RemBERT ONNX config (huggingface#20520)
Browse files Browse the repository at this point in the history
* rembert onnx config

* formatting

Co-authored-by: Ho <[email protected]>
  • Loading branch information
2 people authored and amyeroberts committed Dec 7, 2022
1 parent c4653b0 commit 0135f86
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Ready-made configurations include the following architectures:
- OWL-ViT
- Perceiver
- PLBart
- RemBERT
- ResNet
- RoBERTa
- RoFormer
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/rembert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
)


_import_structure = {"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig"]}
_import_structure = {
"configuration_rembert": ["REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RemBertConfig", "RemBertOnnxConfig"]
}

try:
if not is_sentencepiece_available():
Expand Down Expand Up @@ -88,7 +90,7 @@


if TYPE_CHECKING:
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig
from .configuration_rembert import REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RemBertConfig, RemBertOnnxConfig

try:
if not is_sentencepiece_available():
Expand Down
23 changes: 23 additions & 0 deletions src/transformers/models/rembert/configuration_rembert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" RemBERT model configuration"""
from collections import OrderedDict
from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


Expand Down Expand Up @@ -135,3 +138,23 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.tie_word_embeddings = False


class RemBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

@property
def atol_for_validation(self) -> float:
return 1e-4
10 changes: 10 additions & 0 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,16 @@ class FeaturesManager:
"sequence-classification",
onnx_config_cls="models.perceiver.PerceiverOnnxConfig",
),
"rembert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls="models.rembert.RemBertOnnxConfig",
),
"resnet": supported_features_mapping(
"default",
"image-classification",
Expand Down
1 change: 1 addition & 0 deletions tests/onnx/test_onnx_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def test_values_override(self):
("owlvit", "google/owlvit-base-patch32"),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("masked-lm", "sequence-classification")),
("perceiver", "hf-internal-testing/tiny-random-PerceiverModel", ("image-classification",)),
("rembert", "google/rembert"),
("resnet", "microsoft/resnet-50"),
("roberta", "hf-internal-testing/tiny-random-RobertaModel"),
("roformer", "hf-internal-testing/tiny-random-RoFormerModel"),
Expand Down

0 comments on commit 0135f86

Please sign in to comment.