Optimize onnx quantsim init data type inference #3747

Open · wants to merge 5 commits into base: develop
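
Summary: during QuantizationSimModel initialization, activation data types were previously collected by hooking every activation and running the full model through onnxruntime, which allocates buffers for all intermediate tensors. With this change, quantsim first tries ONNX shape inference and reads each activation's elem_type from the graph's value_info (writing the model to disk for shape inference when it exceeds the 2 GB protobuf limit), and only falls back to the session-based _observe_activation_dtypes path when shape inference raises onnx.shape_inference.InferenceError. A minimal sketch of the shape-inference path, not the exact implementation:

    import itertools
    import onnx

    def infer_activation_dtypes(model: onnx.ModelProto) -> dict:
        # Statically annotate every intermediate tensor; no onnxruntime session is
        # built, so no activation buffers are allocated.
        inferred = onnx.shape_inference.infer_shapes(model)
        dtypes = {}
        for val_info in itertools.chain(inferred.graph.value_info,
                                        inferred.graph.input,
                                        inferred.graph.output):
            dtypes[val_info.name] = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[val_info.type.tensor_type.elem_type]
        return dtypes
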
48 changes: 39 additions & 9 deletions TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py
@@ -42,6 +42,7 @@
from pathlib import Path
import os
from typing import Dict, List, Union, Tuple, Optional
import itertools
import json
import warnings
import numpy as np
@@ -66,7 +67,8 @@
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.qc_quantize_op import QcQuantizeOp, OpMode, TensorQuantizerParams, GroupedBlockQuantizeDequantize
from aimet_onnx.quantsim_config.quantsim_config import QuantSimConfigurator
from aimet_onnx.utils import make_dummy_input, add_hook_to_get_activation, remove_activation_hooks
from aimet_onnx.utils import make_dummy_input, save_model_with_external_weights, add_hook_to_get_activation, \
remove_activation_hooks

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

@@ -270,7 +272,11 @@ def _get_activations_to_quantize(self, dummy_input: Dict[str, np.ndarray]):

:param dummy_input: Sample input to be run through the model
"""
self.fill_activation_dtypes(dummy_input)
try:
self.activation_dtypes = self._infer_activation_dtypes()
except onnx.shape_inference.InferenceError:
self.activation_dtypes = self._observe_activation_dtypes(dummy_input)

self.input_name_to_nodes = self.model.input_name_to_nodes()
self.output_name_to_node = self.model.output_name_to_node()

@@ -366,9 +372,32 @@ def _check_matmul_add_patten(self, node: onnx.NodeProto) -> bool:
return True
return False

def fill_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
def _infer_activation_dtypes(self):
"""
Get the data type for each activation through shape inference
"""
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
with tempfile.TemporaryDirectory(dir=self._path) as tempdir:
save_path = os.path.join(tempdir, "inferred_model.onnx")
save_model_with_external_weights(self.model.model, save_path, location=Path(save_path).name + ".data")
onnx.shape_inference.infer_shapes_path(save_path)
# Do not load the weights for the shape inference model, we only need to access the graph's `value_info`
inferred_model = onnx.load(save_path, load_external_data=False)
else:
inferred_model = onnx.shape_inference.infer_shapes(self.model.model)

activation_dtypes = {}
for val_info in itertools.chain(inferred_model.graph.value_info,
inferred_model.graph.input,
inferred_model.graph.output):
act_name = val_info.name
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[val_info.type.tensor_type.elem_type]
activation_dtypes[act_name] = dtype
return activation_dtypes

def _observe_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
"""
Get the data type for each activation
Get the data type for each activation by returning all activations

:param dummy_input: Sample input to run through the model
"""
@@ -379,11 +408,14 @@ def fill_activation_dtypes(self, dummy_input: Dict[str, np.ndarray]):
sess = QuantizationSimModel.build_session(self.model.model, ['CPUExecutionProvider'],
user_onnx_libs=self._user_onnx_libs, path=self._path)
outputs = sess.run(None, dummy_input)

activation_dtypes = {}
for idx in range(len(self.model.graph().output)):
act_name = self.model.graph().output[idx].name
dtype = outputs[idx].dtype
self.activation_dtypes[act_name] = dtype
activation_dtypes[act_name] = dtype
remove_activation_hooks(self.model.model, hooks)
return activation_dtypes

def _add_quantization_nodes(self):
"""
@@ -524,8 +556,7 @@ def build_session(model: onnx.ModelProto, providers: List, user_onnx_libs: List[
output_path = os.path.join(path, 'model.onnx')
if save_as_external_data:
# Note: Saving as external data mutates the saved model, removing all initializer data
onnx.save_model(model, output_path, save_as_external_data=True, location=Path(output_path).name + ".data")
onnx.load_external_data_for_model(model, base_dir=path)
save_model_with_external_weights(model, output_path, location=Path(output_path).name + ".data")

path_or_bytes = output_path if save_as_external_data else model.SerializeToString()
session = InferenceSession(
@@ -778,8 +809,7 @@ def export(self, path: str, filename_prefix: str):
self.remove_quantization_nodes()
if self.model.model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
# Note: Saving as external data mutates the saved model, removing all initializer data
self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx', use_external_data_format=True)
onnx.load_external_data_for_model(self.model.model, base_dir=path)
save_model_with_external_weights(self.model.model, os.path.join(path, filename_prefix) + '.onnx')
else:
self.model.save_model_to_file(os.path.join(path, filename_prefix) + '.onnx')

12 changes: 12 additions & 0 deletions TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
@@ -403,6 +403,18 @@ def retrieve_constant_input(node: NodeProto, model: ModelProto, index: int
transposed = True
return weight, transposed

def save_model_with_external_weights(model: onnx.ModelProto, f: str, **kwargs):
"""
Saves an onnx model with external weights without mutating the original model

:param model: ONNX ModelProto object to save
:param f: filename to save the model to
:param kwargs: Additional keyword arguments to pass to :func:`onnx.save_model`
"""
onnx.save_model(model, f, save_as_external_data=True, **kwargs)
# Load back weights which are removed when saving as external data
onnx.load_external_data_for_model(model, os.path.dirname(f))


class CachedDataset:
"""
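
The new save_model_with_external_weights helper writes the model with external weights and then reloads the weight data, so the in-memory ModelProto is not left stripped of its initializers (onnx.save_model with save_as_external_data=True moves them out of the proto). A hedged usage sketch, with illustrative paths:

    import onnx
    from aimet_onnx.utils import save_model_with_external_weights

    model = onnx.load("model.onnx")                      # illustrative input path
    save_model_with_external_weights(model, "/tmp/export/model.onnx",
                                     location="model.onnx.data")
    # Unlike a bare onnx.save_model(..., save_as_external_data=True), `model` still
    # holds its initializer data afterwards and can keep being used in memory.
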
43 changes: 43 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
@@ -52,6 +52,7 @@
from torch.nn.modules.batchnorm import _BatchNorm
from onnx import helper, numpy_helper, OperatorSetIdProto, TensorProto, load_model
from onnxruntime.quantization.onnx_quantizer import ONNXModel
from onnxruntime_extensions import PyOp, onnx_op

from aimet_common import libquant_info
from .mobilenet import MockMobileNetV1, MockMobileNetV11
@@ -2709,3 +2710,45 @@ def squeezenet1_0(tmpdir):
input_names=["input"], output_names=["output"])
model = onnx.load(filepath)
return ONNXModel(model)

@onnx_op(op_type="CustomAdd",
inputs=[PyOp.dt_float, PyOp.dt_float],
outputs=[PyOp.dt_float])
def add_op(x, y):
return x + y

def custom_op_model():
model = helper.make_model(
graph=helper.make_graph(
name="CustomAddModel",
inputs=[helper.make_tensor_value_info('model_input', TensorProto.FLOAT, shape=[10, 10])],
outputs=[helper.make_tensor_value_info('model_output', TensorProto.FLOAT, shape=[10, 10])],
initializer=[],
nodes=[
helper.make_node(
"Relu",
inputs=["model_input"],
outputs=["y"],
),
helper.make_node(
"CustomAdd",
inputs=["model_input", "y"],
outputs=["z"],
domain="ai.onnx.contrib"
),
helper.make_node(
"CustomAdd",
inputs=["z", "y"],
outputs=["output"],
domain="ai.onnx.contrib"
),
helper.make_node(
"Exp",
inputs=["output"],
outputs=["model_output"]
)
],
),
opset_imports=[helper.make_operatorsetid('ai.onnx.contrib', 1)]
)
return model
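
The CustomAdd kernel above is a Python op registered through onnxruntime_extensions, so a session can only resolve the ai.onnx.contrib domain after the extensions library has been registered. A hedged, self-contained sketch of that mechanism (it relies on the @onnx_op registration of add_op above; the one-node model and opset versions below are illustrative and not part of this change):

    import numpy as np
    import onnx
    import onnxruntime as ort
    from onnx import helper, TensorProto
    from onnxruntime_extensions import get_library_path

    graph = helper.make_graph(
        [helper.make_node("CustomAdd", ["a", "b"], ["c"], domain="ai.onnx.contrib")],
        "custom_add_only",
        [helper.make_tensor_value_info("a", TensorProto.FLOAT, [2]),
         helper.make_tensor_value_info("b", TensorProto.FLOAT, [2])],
        [helper.make_tensor_value_info("c", TensorProto.FLOAT, [2])])
    model = helper.make_model(graph,
                              opset_imports=[helper.make_operatorsetid("", 17),
                                             helper.make_operatorsetid("ai.onnx.contrib", 1)])

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())  # makes the @onnx_op Python kernels visible
    sess = ort.InferenceSession(model.SerializeToString(), so, providers=["CPUExecutionProvider"])
    print(sess.run(None, {"a": np.ones(2, np.float32), "b": np.ones(2, np.float32)}))
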
46 changes: 46 additions & 0 deletions TrainingExtensions/onnx/test/python/test_quantsim.py
@@ -40,6 +40,8 @@
import json
import os
import tempfile
import tracemalloc

import onnx.numpy_helper
import torch
import numpy as np
@@ -653,6 +655,44 @@ def test_multiple_output_quantsim(self):
path=tempdir)
sim.session.run(None, {'input': sample_input})

def test_quantsim_init_memory_usage(self):
"""
When: Instantiate a quantsim model with high activation memory usage
Then: Memory usage should not spike
"""
num_layers = 2 ** 9
activation_dim = 2 ** 13
batch_size = 2 ** 8
total_act_memory = num_layers * activation_dim * batch_size

# Create a model with very high total activation memory usage
layers = [
onnx.helper.make_node("Constant", inputs=[], outputs=["shape"], name="shape",
value=onnx.numpy_helper.from_array(np.array([batch_size, activation_dim], dtype=np.dtype("int64")))),
onnx.helper.make_node("Expand", inputs=["input", "shape"], outputs=["act0"], name="reshape"),
]
for idx in range(num_layers):
layers.append(
onnx.helper.make_node("Sigmoid", inputs=[f"act{idx}"], outputs=[f"act{idx + 1}"],
name=f"layer_{idx}")
)

input_tensor = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1])
output_tensor = onnx.helper.make_tensor_value_info(f"act{num_layers}", onnx.TensorProto.FLOAT,
[batch_size, activation_dim])
graph = onnx.helper.make_graph(layers, "graph", initializer=[], inputs=[input_tensor],
outputs=[output_tensor])
model = onnx.helper.make_model(graph)

with tempfile.TemporaryDirectory() as tempdir:
tracemalloc.start()
sim = QuantizationSimModel(model, path=tempdir)
current_mem, peak_mem = tracemalloc.get_traced_memory()
tracemalloc.stop()

assert peak_mem < current_mem + 0.25 * total_act_memory
assert peak_mem < current_mem * 5

@pytest.mark.skip(reason="test requires exact version of torch that the code has built against.")
def test_model_with_custom_ops(self):
custom_ops_path = os.path.dirname(libquant_info.__file__)
@@ -1690,3 +1730,9 @@ def test_identity_conv_perchannel(self):
config_file=get_path_for_per_channel_config())
assert sim.qc_quantize_op_dict["identity.input"].quant_info.usePerChannelMode
assert sim.qc_quantize_op_dict["identity.input"].quant_info.channelAxis == 0

def test_customop_model(self):
from onnxruntime_extensions import get_library_path
model = models_for_tests.custom_op_model()
sim = QuantizationSimModel(model, user_onnx_libs=[get_library_path()])
assert {"model_input", "output", "model_output", "y", "z"} == sim.qc_quantize_op_dict.keys()