From 466bf4800b75ec29bd2ff75bad8e8973bd98d01c Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Sat, 30 Apr 2022 11:43:51 +0530 Subject: [PATCH 1/7] update docs of length_penalty --- src/transformers/generation_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 76086c4b7d6330..f66bbf84b0cb33 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -950,9 +950,9 @@ def generate( eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the - model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer - sequences. + Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. 0.0 means no penalty. Set to values < 0.0 in order to encourage the + model to generate longer sequences, to a value > 0.0 in order to encourage the model to produce shorter + sequences. no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): From 7786ce4cb7448c7d93a2d004eb2546bd40a1d766 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Sun, 1 May 2022 17:30:59 +0530 Subject: [PATCH 2/7] Revert "update docs of length_penalty" This reverts commit 466bf4800b75ec29bd2ff75bad8e8973bd98d01c. --- src/transformers/generation_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index f66bbf84b0cb33..76086c4b7d6330 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -950,9 +950,9 @@ def generate( eos_token_id (`int`, *optional*): The id of the *end-of-sequence* token. length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length. 1.0 means that the beam score is penalized by the sequence length. 0.0 means no penalty. Set to values < 0.0 in order to encourage the - model to generate longer sequences, to a value > 0.0 in order to encourage the model to produce shorter - sequences. + Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the + model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer + sequences. no_repeat_ngram_size (`int`, *optional*, defaults to 0): If set to int > 0, all ngrams of that size can only occur once. encoder_no_repeat_ngram_size (`int`, *optional*, defaults to 0): From edc9e165d82592e145e3bf9ed1b535433f8f232c Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Sun, 1 May 2022 17:45:12 +0530 Subject: [PATCH 3/7] add mobilebert onnx config --- docs/source/en/serialization.mdx | 1 + src/transformers/models/mobilebert/__init__.py | 12 ++++++++++-- .../models/mobilebert/configuration_mobilebert.py | 14 ++++++++++++++ src/transformers/onnx/features.py | 11 +++++++++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 37 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 4ae5c9a57ecbdd..b1918bf4609d84 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -67,6 +67,7 @@ Ready-made configurations include the following architectures: - M2M100 - Marian - mBART +- MobileBert - OpenAI GPT-2 - PLBart - RoBERTa diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py index 505dabe1879198..b35fe8a9c11afa 100644 --- a/src/transformers/models/mobilebert/__init__.py +++ b/src/transformers/models/mobilebert/__init__.py @@ -22,7 +22,11 @@ _import_structure = { - "configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"], + "configuration_mobilebert": [ + "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "MobileBertConfig", + "MobileBertOnnxConfig", + ], "tokenization_mobilebert": ["MobileBertTokenizer"], } @@ -62,7 +66,11 @@ if TYPE_CHECKING: - from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig + from .configuration_mobilebert import ( + MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + MobileBertConfig, + MobileBertOnnxConfig, + ) from .tokenization_mobilebert import MobileBertTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py index 27863235b3d7d8..2c49f99a12e790 100644 --- a/src/transformers/models/mobilebert/configuration_mobilebert.py +++ b/src/transformers/models/mobilebert/configuration_mobilebert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ MobileBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -165,3 +168,14 @@ def __init__( self.true_hidden_size = hidden_size self.classifier_dropout = classifier_dropout + + +class MobileBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index a4d3a49388d601..6a02377c5c3aaa 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -24,6 +24,7 @@ from ..models.m2m_100 import M2M100OnnxConfig from ..models.marian import MarianOnnxConfig from ..models.mbart import MBartOnnxConfig +from ..models.mobilebert import MobileBertOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.roformer import RoFormerOnnxConfig from ..models.t5 import T5OnnxConfig @@ -192,6 +193,16 @@ class FeaturesManager: "question-answering", onnx_config_cls=CamembertOnnxConfig, ), + "mobilebert": supported_features_mapping( + "default", + "masked-lm", + "next-sentence-prediction", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=MobileBertOnnxConfig, + ), "convbert": supported_features_mapping( "default", "masked-lm", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index ea5a54763932c9..d9649b7589104b 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ def test_values_override(self): ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), + ("mobilebert", "google/mobilebert-uncased"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"), From 641d28a5c9f99968ec0ac185274098e7717bec43 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Tue, 3 May 2022 18:52:03 +0530 Subject: [PATCH 4/7] address suggestions --- src/transformers/__init__.py | 3 +++ src/transformers/models/auto/__init__.py | 2 ++ .../models/mobilebert/configuration_mobilebert.py | 10 ++++++++-- src/transformers/onnx/features.py | 5 +++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5695ff57c53b07..fdc06e75488ef8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1770,6 +1770,7 @@ "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", @@ -1785,6 +1786,7 @@ "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -3930,6 +3932,7 @@ TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 6dace993cd743b..fa34a11964b041 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -108,6 +108,7 @@ "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -224,6 +225,7 @@ TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py index 2c49f99a12e790..73b8844ed763df 100644 --- a/src/transformers/models/mobilebert/configuration_mobilebert.py +++ b/src/transformers/models/mobilebert/configuration_mobilebert.py @@ -170,12 +170,18 @@ def __init__( self.classifier_dropout = classifier_dropout +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert class MobileBertOnnxConfig(OnnxConfig): @property def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} return OrderedDict( [ - ("input_ids", {0: "batch", 1: "sequence"}), - ("attention_mask", {0: "batch", 1: "sequence"}), + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), ] ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 6a02377c5c3aaa..4a70bf1098ce3b 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -44,6 +44,7 @@ AutoModelForMaskedImageModeling, AutoModelForMaskedLM, AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, @@ -55,6 +56,7 @@ TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, @@ -108,6 +110,7 @@ class FeaturesManager: "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, "masked-im": AutoModelForMaskedImageModeling, + "next-sentence-prediction": AutoModelForNextSentencePrediction, } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { @@ -119,6 +122,7 @@ class FeaturesManager: "token-classification": TFAutoModelForTokenClassification, "multiple-choice": TFAutoModelForMultipleChoice, "question-answering": TFAutoModelForQuestionAnswering, + "next-sentence-prediction": TFAutoModelForNextSentencePrediction, } # Set of model topologies we support associated to the features supported by each topology and the factory @@ -162,6 +166,7 @@ class FeaturesManager: "multiple-choice", "token-classification", "question-answering", + "next-sentence-prediction", onnx_config_cls=BertOnnxConfig, ), "big-bird": supported_features_mapping( From 31c18f51fc3d7441b41fbc5fb8ea366fcde29530 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Tue, 3 May 2022 19:42:10 +0530 Subject: [PATCH 5/7] Update auto.mdx --- docs/source/en/model_doc/auto.mdx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index d941b00318b086..4a4b59e9c16f74 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] TFAutoModelForMultipleChoice +## TFAutoModelForNextSentencePrediction + +[[autodoc]] TFAutoModelForNextSentencePrediction + ## TFAutoModelForTableQuestionAnswering [[autodoc]] TFAutoModelForTableQuestionAnswering From aa3dfc52155e23539995c67d0bc0b3c347073f80 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Tue, 3 May 2022 19:52:54 +0530 Subject: [PATCH 6/7] Update __init__.py --- src/transformers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fdc06e75488ef8..8f7c0ee6e17f18 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1770,7 +1770,6 @@ "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", - "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "TF_MODEL_FOR_VISION_2_SEQ_MAPPING", From 8a3c1f8f0b00885e726d13923382d9f6506318d0 Mon Sep 17 00:00:00 2001 From: Manan Dey Date: Wed, 4 May 2022 19:39:27 +0530 Subject: [PATCH 7/7] Update features.py --- src/transformers/onnx/features.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index c2eb7e7829d450..516288e8d6a74f 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -198,16 +198,6 @@ class FeaturesManager: "question-answering", onnx_config_cls=CamembertOnnxConfig, ), - "mobilebert": supported_features_mapping( - "default", - "masked-lm", - "next-sentence-prediction", - "sequence-classification", - "multiple-choice", - "token-classification", - "question-answering", - onnx_config_cls=MobileBertOnnxConfig, - ), "convbert": supported_features_mapping( "default", "masked-lm", @@ -320,6 +310,16 @@ class FeaturesManager: "question-answering", onnx_config_cls=MBartOnnxConfig, ), + "mobilebert": supported_features_mapping( + "default", + "masked-lm", + "next-sentence-prediction", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=MobileBertOnnxConfig, + ), "m2m-100": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig ),