diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index c7620752884..c5f438b170d 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -12,5 +12,4 @@ specific language governing permissions and limitations under the License. # Overview -🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides -classes, functions, and a command line interface to perform the export easily. +🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides classes, functions, and a command line interface to perform the export easily. diff --git a/docs/source/exporters/onnx/usage_guides/contribute.mdx b/docs/source/exporters/onnx/usage_guides/contribute.mdx index e7fa515d7dc..798d2f39a2e 100644 --- a/docs/source/exporters/onnx/usage_guides/contribute.mdx +++ b/docs/source/exporters/onnx/usage_guides/contribute.mdx @@ -96,7 +96,7 @@ Once you have implemented an ONNX configuration, you can instantiate it by provi ```python >>> from transformers import AutoConfig - +>>> from optimum.exporters.onnx.model_configs import BertOnnxConfig >>> config = AutoConfig.from_pretrained("bert-base-uncased") >>> onnx_config = BertOnnxConfig(config) ``` @@ -182,10 +182,10 @@ This function expects the ONNX configuration, along with the base model, and the >>> base_model = AutoModel.from_pretrained("bert-base-uncased") >>> onnx_path = Path("model.onnx") ->>> onnx_config_constructor = TasksManager.get_exporter_config_constructor(base_model, "onnx") +>>> onnx_config_constructor = TasksManager.get_exporter_config_constructor("onnx", base_model) >>> onnx_config = onnx_config_constructor(base_model.config) ->>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, onnx_path) +>>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_path, onnx_config.DEFAULT_ONNX_OPSET) ``` The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are lists of the keys defined in the [`~optimum.exporters.onnx.OnnxConfig.inputs`] diff --git a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx index 1e75c8d8d61..a32278bf81c 100644 --- a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx @@ -16,13 +16,13 @@ specific language governing permissions and limitations under the License. Exporting a model to ONNX is as simple as -``` +```bash optimum-cli export onnx --model gpt2 gpt2_onnx/ ``` Check out the help for more options: -``` +```bash optimum-cli export onnx --help ``` @@ -50,7 +50,6 @@ graph optimization and quantization. Check the `optimum.onnxruntime` subpackage - 🤗 Optimum provides support for the ONNX export by leveraging configuration objects. These configuration objects come ready made for a number of model architectures, and are designed to be easily extendable to other architectures. @@ -71,19 +70,31 @@ The Optimum ONNX export can be used through Optimum command-line: ```bash optimum-cli export onnx --help -usage: Hugging Face Optimum ONNX exporter [-h] -m MODEL [--task TASK] [--opset OPSET] [--atol ATOL] [--framework {pt,tf}] [--pad_token_id PAD_TOKEN_ID] [--cache_dir CACHE_DIR] output - -positional arguments: - output Path indicating the directory where to store generated ONNX model. 
+usage: optimum-cli [] export onnx [-h] -m MODEL [--task TASK] [--for-ort] [--device DEVICE] [--opset OPSET] [--atol ATOL] + [--framework {pt,tf}] [--pad_token_id PAD_TOKEN_ID] [--cache_dir CACHE_DIR] [--batch_size BATCH_SIZE] + [--sequence_length SEQUENCE_LENGTH] [--num_choices NUM_CHOICES] [--width WIDTH] [--height HEIGHT] + [--num_channels NUM_CHANNELS] [--feature_size FEATURE_SIZE] [--nb_max_frames NB_MAX_FRAMES] + [--audio_sequence_length AUDIO_SEQUENCE_LENGTH] + output -optional arguments: - -h, --help show this help message and exit +Required arguments: -m MODEL, --model MODEL Model ID on huggingface.co or path on disk to load model from. - --task TASK The type of task to export the model with. - --opset OPSET ONNX opset version to export the model with. - --atol ATOL Absolute difference tolerance when validating the model. - --framework {pt,tf} The framework to use for the ONNX export. If not provided, will attempt to use the local checkpoint's original framework or what is available in the environment. + output Path indicating the directory where to store generated ONNX model. + +Optional arguments: + --task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on + the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification', + 'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic- + segmentation', 'speech2seq-lm', 'stable-diffusion']. For decoder models, use `xxx-with-past` to export the model using past key + values in the decoder. + --for-ort This exports models ready to be run with Optimum's ORTModel. Useful for encoder-decoder models forconditional generation. If + enabled the encoder and decoder of the model are exported separately. + --device DEVICE The device to use to do the export. Defaults to "cpu". + --opset OPSET If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used. + --atol ATOL If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used. + --framework {pt,tf} The framework to use for the ONNX export. If not provided, will attempt to use the local checkpoint's original framework or what is + available in the environment. --pad_token_id PAD_TOKEN_ID This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it. --cache_dir CACHE_DIR @@ -161,7 +172,7 @@ It is also possible to export the model to ONNX directly from the `ORTModelForQu >>> model = ORTModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad", from_transformers=True) ``` -For more information, check the `optimum.onnxrutime` documentation [page on this topic](/onnxruntime/overview). +For more information, check the `optimum.onnxruntime` documentation [page on this topic](/onnxruntime/overview). @@ -172,6 +183,38 @@ organization](https://huggingface.co/keras-io) as follows: optimum-cli export onnx --model keras-io/transformers-qa distilbert_base_cased_squad_onnx/ ``` +### Exporting a model to be used with Optimum's ORTModel + +Models exported through `optimum-cli export onnx` can be used directly in [`~onnxruntime.ORTModel`] by passing the parameter `--for-ort`. 
This is especially useful for encoder-decoder models, in which case the export splits the encoder and decoder into two separate `.onnx` files, since the encoder usually runs only once while the decoder may run many times during autoregressive generation.
+
+### Exporting a model using past keys/values in the decoder
+
+When exporting a decoder model used for generation, it can be useful to encapsulate the [reuse of past keys and values](https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/2) in the exported ONNX model. This avoids recomputing the same intermediate activations at each generation step.
+
+In the ONNX export, past keys/values are reused by default. This behavior corresponds to `--task seq2seq-lm-with-past`, `--task causal-lm-with-past`, or `--task speech2seq-lm-with-past`. If you would like to disable past keys/values reuse in the export, you must explicitly pass the task `seq2seq-lm`, `causal-lm` or `speech2seq-lm` to `optimum-cli export onnx`.
+
+A model exported with past keys/values can be used directly with Optimum's [`~onnxruntime.ORTModel`]:
+
+```bash
+optimum-cli export onnx --model gpt2 --for-ort --task causal-lm-with-past gpt2_onnx/
+```
+
+and then
+
+```python
+from transformers import AutoTokenizer
+from optimum.onnxruntime import ORTModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("/path/to/gpt2_onnx/")
+model = ORTModelForCausalLM.from_pretrained("/path/to/gpt2_onnx/")
+
+inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
+
+gen_tokens = model.generate(**inputs)
+print(tokenizer.batch_decode(gen_tokens))
+# prints ['My name is Arthur and I live in the United States of America. I am a member of the']
+```
+
 ## Selecting a task
 
 Specifying a `--task` should not be necessary in most cases when exporting from a model on the Hugging Face Hub.
diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx
index f4423d08b55..f1e02204fc8 100644
--- a/docs/source/quicktour.mdx
+++ b/docs/source/quicktour.mdx
@@ -207,13 +207,13 @@ The Optimum library handles out of the box the ONNX export of Transformers and D
 
 Exporting a model to ONNX is as simple as
 
-```
+```bash
 optimum-cli export onnx --model gpt2 gpt2_onnx/
 ```
 
 Check out the help for more options:
 
-```
+```bash
 optimum-cli export onnx --help
 ```
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 8f5972b771e..33511bbfccb 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -149,6 +149,7 @@ class OnnxConfig(ExportConfig, ABC):
                 "end_logits": {0: "batch_size", 1: "sequence_length"},
             }
         ),
+        "visual-question-answering": OrderedDict({"logits": {0: "batch_size"}}),
         "semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
         "seq2seq-lm": OrderedDict(
             {
@@ -187,6 +188,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
         forces the other generators to use the same batch size, meaning they will all produce inputs of the same
         batch size. Override this method for custom behavior.
""" + first_inputs_gen = self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config, **kwargs) dummy_inputs_generators = [ cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:] diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 8281ca51dc9..b65f4bae21d 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -346,6 +346,7 @@ def export_pytorch( input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES # Check that inputs match, and order them properly + dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes) device = torch.device(device) if device.type == "cuda" and torch.cuda.is_available(): diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 23b516349d3..ab767c4fd86 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -27,6 +27,7 @@ DummyTextInputGenerator, DummyTimestepInputGenerator, DummyVisionInputGenerator, + DummyVisualBertInputGenerator, NormalizedConfig, NormalizedSeq2SeqConfig, NormalizedTextAndVisionConfig, @@ -71,6 +72,29 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]: } +class VisualBertOnnxConfig(BertOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisualBertInputGenerator) + ATOL_FOR_VALIDATION = 1e-4 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + dynamic_axis = {0: "batch_size", 1: "sequence_length"} + result = { + "input_ids": dynamic_axis, + "attention_mask": dynamic_axis, + "token_type_ids": dynamic_axis, + "visual_embeds": {0: "batch_size", 1: "visual_seq_length", 2: "visual_embedding_dim"}, + "visual_token_type_ids": {0: "batch_size", 1: "visual_seq_length"}, + "visual_attention_mask": {0: "batch_size", 1: "visual_seq_length"}, + } + if self.task == "region-to-phrase-alignment": + result.update({"region_to_phrase_position": {0: "batch_size", 1: "total_sequence_length"}}) + return result + + + + class AlbertOnnxConfig(BertOnnxConfig): pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 29742e67f47..b359a26f81f 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -168,6 +168,15 @@ class TasksManager: "question-answering", onnx="BertOnnxConfig", ), + "visual-bert": supported_tasks_mapping( + "default", + "visual-question-answering", + + "multiple-choice", + "visual-reasoning", + "region-to-phrase-alignment", + onnx="VisualBertOnnxConfig", + ), "big-bird": supported_tasks_mapping( "default", "masked-lm", diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 5797957732c..8ae22a3beb7 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -38,6 +38,7 @@ DummyTimestepInputGenerator, DummyTrainingLabelsInputGenerator, DummyVisionInputGenerator, + DummyVisualBertInputGenerator, ) from .normalized_config import ( NormalizedConfig, diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 2b08b99258b..57f2200a148 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -53,6 +53,7 @@ def wrapper(*args, **kwargs): "width": 64, "height": 64, "num_channels": 3, + "num_of_detection_patches": 5, # audio "feature_size": 80, "nb_max_frames": 3000, @@ -627,3 +628,44 @@ def generate(self, input_name: str, framework: str = "pt"): shape = [self.batch_size] return 
self.random_int_tensor(shape, max_value=max_value, framework=framework)
+
+
+class DummyVisualBertInputGenerator(DummyTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = ("visual_embeds", "visual_token_type_ids", "visual_attention_mask")
+    # TODO: see how to add "region_to_phrase_position", since that input name currently raises
+    # "ValueError: Config dummy inputs are not a subset of the model inputs".
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        num_of_detection_patches: int = DEFAULT_DUMMY_SHAPES["num_of_detection_patches"],
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
+    ):
+        super().__init__(
+            task,
+            normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            num_choices=num_choices,
+        )
+        self.num_of_detection_patches = num_of_detection_patches
+
+    def generate(self, input_name: str, framework: str = "pt"):
+        # TODO: it may be more robust to infer the visual embedding dimension from the checkpoint
+        # (for example by checking for a "vqa" substring in its name) rather than from the task.
+        visual_embedding_dim = None
+        if self.task in ["visual-question-answering", "region-to-phrase-alignment"]:
+            visual_embedding_dim = 2048
+        elif self.task == "multiple-choice":
+            visual_embedding_dim = 512
+        elif self.task == "visual-reasoning":
+            visual_embedding_dim = 1024
+
+        shape = [self.batch_size, self.num_of_detection_patches, visual_embedding_dim]
+        visual_embeddings = self.random_float_tensor(shape, framework=framework)
+        if self.task == "multiple-choice":
+            # Add the leading dimensions expected for multiple-choice inputs.
+            visual_embeddings = visual_embeddings.expand(1, 2, *shape)
+
+        if input_name == "visual_embeds":
+            return visual_embeddings
+        elif input_name == "visual_attention_mask":
+            return torch.ones(visual_embeddings.shape[:-1], dtype=torch.long)
+        elif input_name == "visual_token_type_ids":
+            return torch.ones(visual_embeddings.shape[:-1], dtype=torch.long)
+        elif input_name == "region_to_phrase_position":
+            return torch.ones((1, self.sequence_length + visual_embeddings.shape[-2]))
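
As an illustrative sketch (not part of the diff above, with the checkpoint name and output path chosen purely as examples), the new `VisualBertOnnxConfig` could be exercised programmatically with the same `export()` helper shown in `contribute.mdx`, assuming the configuration and dummy input generator added here behave as intended:

```python
from pathlib import Path

from transformers import AutoConfig, VisualBertForQuestionAnswering

from optimum.exporters.onnx import export
from optimum.exporters.onnx.model_configs import VisualBertOnnxConfig

# Example checkpoint; any VisualBert checkpoint from the Hub could be used here.
config = AutoConfig.from_pretrained("uclanlp/visualbert-vqa")
model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa")

# Build the ONNX config for the "visual-question-answering" task added in this diff
# and export with the default opset declared on the config.
onnx_config = VisualBertOnnxConfig(config, task="visual-question-answering")
onnx_inputs, onnx_outputs = export(
    model, onnx_config, Path("visualbert_vqa.onnx"), onnx_config.DEFAULT_ONNX_OPSET
)
```

The exported file could then be checked against the PyTorch model with `validate_model_outputs`, using the `ATOL_FOR_VALIDATION` defined on the configuration.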