VisualBertOnnx #663

Closed
wants to merge 8 commits
3 changes: 1 addition & 2 deletions docs/source/exporters/onnx/overview.mdx
@@ -12,5 +12,4 @@ specific language governing permissions and limitations under the License.

# Overview

🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides
classes, functions, and a command line interface to perform the export easily.
🤗 Optimum handles the export of PyTorch or TensorFlow models to ONNX in the `exporters.onnx` module. It provides classes, functions, and a command line interface to perform the export easily.
6 changes: 3 additions & 3 deletions docs/source/exporters/onnx/usage_guides/contribute.mdx
@@ -96,7 +96,7 @@ Once you have implemented an ONNX configuration, you can instantiate it by provi

```python
>>> from transformers import AutoConfig

>>> from optimum.exporters.onnx.model_configs import BertOnnxConfig
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> onnx_config = BertOnnxConfig(config)
```
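The instantiated configuration exposes the input (and output) specification used during the export. Below is a minimal sketch of inspecting it; the axes shown are what `BertOnnxConfig` is expected to declare, and the exact printed representation may differ:

```python
>>> # Each input name maps to its dynamic axes (axis index -> symbolic dimension name).
>>> onnx_config.inputs
{'input_ids': {0: 'batch_size', 1: 'sequence_length'}, 'attention_mask': {0: 'batch_size', 1: 'sequence_length'}, 'token_type_ids': {0: 'batch_size', 1: 'sequence_length'}}
```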
@@ -182,10 +182,10 @@ This function expects the ONNX configuration, along with the base model, and the
>>> base_model = AutoModel.from_pretrained("bert-base-uncased")

>>> onnx_path = Path("model.onnx")
>>> onnx_config_constructor = TasksManager.get_exporter_config_constructor(base_model, "onnx")
>>> onnx_config_constructor = TasksManager.get_exporter_config_constructor("onnx", base_model)
>>> onnx_config = onnx_config_constructor(base_model.config)

>>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_config.DEFAULT_ONNX_OPSET, onnx_path)
>>> onnx_inputs, onnx_outputs = export(base_model, onnx_config, onnx_path, onnx_config.DEFAULT_ONNX_OPSET)
```

The `onnx_inputs` and `onnx_outputs` returned by the `export()` function are lists of the keys defined in the [`~optimum.exporters.onnx.OnnxConfig.inputs`]
71 changes: 57 additions & 14 deletions docs/source/exporters/onnx/usage_guides/export_a_model.mdx
@@ -16,13 +16,13 @@ specific language governing permissions and limitations under the License.

Exporting a model to ONNX is as simple as

```
```bash
optimum-cli export onnx --model gpt2 gpt2_onnx/
```

Check out the help for more options:

```
```bash
optimum-cli export onnx --help
```

@@ -50,7 +50,6 @@ graph optimization and quantization. Check the `optimum.onnxruntime` subpackage

</Tip>


🤗 Optimum provides support for the ONNX export by leveraging configuration objects.
These configuration objects come ready made for a number of model architectures, and are
designed to be easily extendable to other architectures.
@@ -71,19 +70,31 @@ The Optimum ONNX export can be used through Optimum command-line:
```bash
optimum-cli export onnx --help

usage: Hugging Face Optimum ONNX exporter [-h] -m MODEL [--task TASK] [--opset OPSET] [--atol ATOL] [--framework {pt,tf}] [--pad_token_id PAD_TOKEN_ID] [--cache_dir CACHE_DIR] output

positional arguments:
output Path indicating the directory where to store generated ONNX model.
usage: optimum-cli <command> [<args>] export onnx [-h] -m MODEL [--task TASK] [--for-ort] [--device DEVICE] [--opset OPSET] [--atol ATOL]
[--framework {pt,tf}] [--pad_token_id PAD_TOKEN_ID] [--cache_dir CACHE_DIR] [--batch_size BATCH_SIZE]
[--sequence_length SEQUENCE_LENGTH] [--num_choices NUM_CHOICES] [--width WIDTH] [--height HEIGHT]
[--num_channels NUM_CHANNELS] [--feature_size FEATURE_SIZE] [--nb_max_frames NB_MAX_FRAMES]
[--audio_sequence_length AUDIO_SEQUENCE_LENGTH]
output

optional arguments:
-h, --help show this help message and exit
Required arguments:
-m MODEL, --model MODEL
Model ID on huggingface.co or path on disk to load model from.
--task TASK The type of task to export the model with.
--opset OPSET ONNX opset version to export the model with.
--atol ATOL Absolute difference tolerance when validating the model.
--framework {pt,tf} The framework to use for the ONNX export. If not provided, will attempt to use the local checkpoint's original framework or what is available in the environment.
output Path indicating the directory where to store generated ONNX model.

Optional arguments:
--task TASK The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on
the model, but are among: ['default', 'masked-lm', 'causal-lm', 'seq2seq-lm', 'sequence-classification', 'token-classification',
'multiple-choice', 'object-detection', 'question-answering', 'image-classification', 'image-segmentation', 'masked-im', 'semantic-
segmentation', 'speech2seq-lm', 'stable-diffusion']. For decoder models, use `xxx-with-past` to export the model using past key
values in the decoder.
--for-ort This exports models ready to be run with Optimum's ORTModel. Useful for encoder-decoder models for conditional generation. If
enabled, the encoder and decoder of the model are exported separately.
--device DEVICE The device to use to do the export. Defaults to "cpu".
--opset OPSET If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used.
--atol ATOL If specified, the absolute difference tolerance when validating the model. Otherwise, the default atol for the model will be used.
--framework {pt,tf} The framework to use for the ONNX export. If not provided, will attempt to use the local checkpoint's original framework or what is
available in the environment.
--pad_token_id PAD_TOKEN_ID
This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it.
--cache_dir CACHE_DIR
@@ -161,7 +172,7 @@ It is also possible to export the model to ONNX directly from the `ORTModelForQu
>>> model = ORTModelForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad", from_transformers=True)
```
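The exported model then behaves like any other ONNX Runtime model. Below is a minimal sketch of running it with a Transformers pipeline; the question and context strings are only illustrative:

```python
>>> from transformers import AutoTokenizer, pipeline

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
>>> # The ORT model plugs into the regular question-answering pipeline.
>>> onnx_qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
>>> onnx_qa(question="What was exported?", context="A DistilBERT question answering model was exported to ONNX.")
```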

For more information, check the `optimum.onnxrutime` documentation [page on this topic](/onnxruntime/overview).
For more information, check the `optimum.onnxruntime` documentation [page on this topic](/onnxruntime/overview).

</Tip>

@@ -172,6 +183,38 @@ organization](https://huggingface.co/keras-io) as follows:
optimum-cli export onnx --model keras-io/transformers-qa distilbert_base_cased_squad_onnx/
```

### Exporting a model to be used with Optimum's ORTModel

Models exported through `optimum-cli export onnx` can be used directly with [`~onnxruntime.ORTModel`] by passing the `--for-ort` parameter. This is especially useful for encoder-decoder models: in that case the export splits the encoder and decoder into two `.onnx` files, since the encoder is usually run only once while the decoder may be run many times during autoregressive generation.
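For instance, an encoder-decoder checkpoint can be exported in this split layout as follows (a sketch; `t5-small` is only an illustrative model id):

```bash
optimum-cli export onnx --model t5-small --for-ort t5_onnx/
```

The resulting directory then contains separate encoder and decoder `.onnx` files that the [`~onnxruntime.ORTModel`] classes can load directly.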

### Exporting a model using past keys/values in the decoder

When exporting a decoder model used for generation, it can be useful to encapsulate the [reuse of past keys and values](https://discuss.huggingface.co/t/what-is-the-purpose-of-use-cache-in-decoder/958/2) in the exported ONNX. This avoids recomputing the same intermediate activations at every generation step.

In the ONNX export, past keys/values are reused by default. This behavior corresponds to `--task seq2seq-lm-with-past`, `--task causal-lm-with-past`, or `--task speech2seq-lm-with-past`. To disable past key/value reuse, explicitly pass the task `seq2seq-lm`, `causal-lm`, or `speech2seq-lm` to `optimum-cli export onnx`, as in the following sketch.
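For example, to export GPT-2 without past key/value reuse (compare with the `--task causal-lm-with-past` command below; the output directory name is illustrative):

```bash
optimum-cli export onnx --model gpt2 --task causal-lm gpt2_onnx_no_cache/
```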

A model exported with past keys/values can be used directly with Optimum's [`~onnxruntime.ORTModel`]:

```bash
optimum-cli export onnx --model gpt2 --for-ort --task causal-lm-with-past gpt2_onnx/
```

and

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("/path/to/gpt2_onnx/")
model = ORTModelForCausalLM.from_pretrained("/path/to/gpt2_onnx/")

inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")

gen_tokens = model.generate(**inputs)
print(tokenizer.batch_decode(gen_tokens))
# prints ['My name is Arthur and I live in the United States of America. I am a member of the']
```

## Selecting a task

Specifying a `--task` should not be necessary in most cases when exporting a model from the Hugging Face Hub.
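If the automatically inferred task is not the one you want, it can be overridden explicitly (a sketch; the model id and output directory are illustrative):

```bash
optimum-cli export onnx --model distilbert-base-uncased-finetuned-sst-2-english --task sequence-classification distilbert_sst2_onnx/
```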
4 changes: 2 additions & 2 deletions docs/source/quicktour.mdx
@@ -207,13 +207,13 @@ The Optimum library handles out of the box the ONNX export of Transformers and D

Exporting a model to ONNX is as simple as

```
```bash
optimum-cli export onnx --model gpt2 gpt2_onnx/
```

Check out the help for more options:

```
```bash
optimum-cli export onnx --help
```

2 changes: 2 additions & 0 deletions optimum/exporters/onnx/base.py
@@ -149,6 +149,7 @@ class OnnxConfig(ExportConfig, ABC):
"end_logits": {0: "batch_size", 1: "sequence_length"},
}
),
"visual-question-answering": OrderedDict({"logits": {0: "batch_size"}}),
"semantic-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"seq2seq-lm": OrderedDict(
{
@@ -187,6 +188,7 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGene
forces the other generators to use the same batch size, meaning they will all produce inputs of the same batch
size. Override this method for custom behavior.
"""

first_inputs_gen = self.DUMMY_INPUT_GENERATOR_CLASSES[0](self.task, self._normalized_config, **kwargs)
dummy_inputs_generators = [
cls_(self.task, self._normalized_config, **kwargs) for cls_ in self.DUMMY_INPUT_GENERATOR_CLASSES[1:]
1 change: 1 addition & 0 deletions optimum/exporters/onnx/convert.py
@@ -346,6 +346,7 @@ def export_pytorch(
input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES

# Check that inputs match, and order them properly

dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
24 changes: 24 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -27,6 +27,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
DummyVisualBertInputGenerator,
NormalizedConfig,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
@@ -71,6 +72,29 @@ def inputs(self) -> Mapping[str, Mapping[int, str]]:
}


class VisualBertOnnxConfig(BertOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisualBertInputGenerator)
ATOL_FOR_VALIDATION = 1e-4

@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
dynamic_axis = {0: "batch_size", 1: "sequence_length"}
result = {
"input_ids": dynamic_axis,
"attention_mask": dynamic_axis,
"token_type_ids": dynamic_axis,
"visual_embeds": {0: "batch_size", 1: "visual_seq_length", 2: "visual_embedding_dim"},
"visual_token_type_ids": {0: "batch_size", 1: "visual_seq_length"},
"visual_attention_mask": {0: "batch_size", 1: "visual_seq_length"},
}
if self.task == "region-to-phrase-alignment":
result.update({"region_to_phrase_position": {0: "batch_size", 1: "total_sequence_length"}})
return result


class AlbertOnnxConfig(BertOnnxConfig):
pass

9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -168,6 +168,15 @@ class TasksManager:
"question-answering",
onnx="BertOnnxConfig",
),
"visual-bert": supported_tasks_mapping(
"default",
"visual-question-answering",
"multiple-choice",
"visual-reasoning",
"region-to-phrase-alignment",
onnx="VisualBertOnnxConfig",
),
"big-bird": supported_tasks_mapping(
"default",
"masked-lm",
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -38,6 +38,7 @@
DummyTimestepInputGenerator,
DummyTrainingLabelsInputGenerator,
DummyVisionInputGenerator,
DummyVisualBertInputGenerator,
)
from .normalized_config import (
NormalizedConfig,
42 changes: 42 additions & 0 deletions optimum/utils/input_generators.py
@@ -53,6 +53,7 @@ def wrapper(*args, **kwargs):
"width": 64,
"height": 64,
"num_channels": 3,
"num_of_detection_patches": 5,
# audio
"feature_size": 80,
"nb_max_frames": 3000,
@@ -627,3 +628,44 @@ def generate(self, input_name: str, framework: str = "pt"):
shape = [self.batch_size]

return self.random_int_tensor(shape, max_value=max_value, framework=framework)


class DummyVisualBertInputGenerator(DummyTextInputGenerator):
    SUPPORTED_INPUT_NAMES = ("visual_embeds", "visual_token_type_ids", "visual_attention_mask")
    # TODO: figure out how to also support "region_to_phrase_position", since that input name currently raises
    # ValueError: Config dummy inputs are not a subset of the model inputs.

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        num_of_detection_patches: int = DEFAULT_DUMMY_SHAPES["num_of_detection_patches"],
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"],
    ):
        # Forward the text-related shapes to the parent generator instead of silently dropping them.
        super().__init__(
            task,
            normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_choices=num_choices,
        )
        self.num_of_detection_patches = num_of_detection_patches

    def generate(self, input_name: str, framework: str = "pt"):
        visual_embedding_dim = None
        # TODO: it may be better to infer this from the model checkpoint (e.g. checking for a "vqa" substring
        # in its name) rather than from the task.
        if self.task in ["visual-question-answering", "region-to-phrase-alignment"]:
            visual_embedding_dim = 2048
        elif self.task == "multiple-choice":
            visual_embedding_dim = 512
        elif self.task == "visual-reasoning":
            visual_embedding_dim = 1024

        shape = [self.batch_size, self.num_of_detection_patches, visual_embedding_dim]
        visual_embeddings = self.random_float_tensor(shape, framework=framework)
        if self.task == "multiple-choice":
            # torch.Tensor.expand is not in-place, so the result must be reassigned.
            visual_embeddings = visual_embeddings.expand(1, 2, *shape)

        if input_name == "visual_embeds":
            return visual_embeddings
        elif input_name == "visual_attention_mask":
            return torch.ones(visual_embeddings.shape[:-1], dtype=torch.long)
        elif input_name == "visual_token_type_ids":
            return torch.ones(visual_embeddings.shape[:-1], dtype=torch.float)
        elif input_name == "region_to_phrase_position":
            return torch.ones((1, self.sequence_length + visual_embeddings.shape[-2]))