diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/dependency_graph_utils.py b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/dependency_graph_utils.py index 3f7cea625fd..62266e5b921 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/dependency_graph_utils.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/dependency_graph_utils.py @@ -74,10 +74,28 @@ def __init__(self, connected_graph, node_name_to_input_names, static_tensor_name self.static_tensor_name_to_proto = static_tensor_name_to_proto self.starting_ops = list() self.graph_outputs = [output.name for output in self.connected_graph.model.graph.output] + self.input_ops_name = list() self._fill_indegree() + self._fill_input_ops_name() self._init_name_to_dependent_on_supported_module() + def _fill_input_ops_name(self): + """ + Fill the input op names list with ops having at least one graph input + """ + + graph_inputs = list() + + for input_tensor in self.connected_graph.model.graph.input: + graph_inputs.append(input_tensor.name) + + for node_name, input_names in self.node_name_to_input_names.items(): + for input_name in input_names: + if input_name in graph_inputs: + self.input_ops_name.append(node_name) + break + def _fill_indegree(self): """ Initializes the indegree using the connected graph @@ -158,7 +176,7 @@ def _create_dependency_graph_helper(self, src_op: Op, dependency_graph: Dependen if src_op.model_module is not None: module = src_op.model_module.get_module() - if self.is_dependency_module(module) or src_op in self.connected_graph.starting_ops: + if self.is_dependency_module(module) or src_op.name in self.input_ops_name: is_module_supported = True op_name = src_op.name_op op_type = src_op.type diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py index 41f20e7cea2..ba01a5e8ed4 100644 ---
a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py @@ -201,13 +201,13 @@ def _update_value_info(self): Value info for QcQuantizeOp is not present in _sim_extractor """ + self._update_value_info_for_graph_output() + for node in self.sim.model.nodes(): if node.op_type == "QcQuantizeOp": self._update_value_info_for_output(node) self._update_value_info_for_input(node) - self._update_value_info_for_graph_output() - def _fill_static_tensor_name_to_proto(self): """ Fills the mapping from static tensor name to the prop buf diff --git a/TrainingExtensions/onnx/test/python/models/test_models_onnx.py b/TrainingExtensions/onnx/test/python/models/test_models_onnx.py new file mode 100644 index 00000000000..516458f0f93 --- /dev/null +++ b/TrainingExtensions/onnx/test/python/models/test_models_onnx.py @@ -0,0 +1,327 @@ +# -*- mode: python -*- +# ============================================================================= +# @@-COPYRIGHT-START-@@ +# +# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# SPDX-License-Identifier: BSD-3-Clause +# +# @@-COPYRIGHT-END-@@ +# ============================================================================= +"""Dummy onnx models for testing""" + + +import onnx +import numpy as np + + +def create_initializer_tensor( + name: str, + tensor_array: np.ndarray, + data_type: onnx.TensorProto = onnx.TensorProto.FLOAT +) -> onnx.TensorProto: + + initializer_tensor = onnx.helper.make_tensor( + name=name, + data_type=data_type, + dims=tensor_array.shape, + vals=tensor_array.flatten().tolist()) + + return initializer_tensor + +class ModelWithMultipleInputs: + + @staticmethod + def get_model(): + model_input_name = "X1" + X1 = onnx.helper.make_tensor_value_info(model_input_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + model_input_name = "X2" + X2 = onnx.helper.make_tensor_value_info(model_input_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + conv1_output_node_name = "Conv1_Y" + + # Dummy weights for conv. 
+ conv1_in_channels = 3 + conv1_out_channels = 3 + conv1_kernel_shape = (3, 3) + conv1_pads = (1, 1, 1, 1) + conv1_W = np.ones(shape=(conv1_out_channels, conv1_in_channels, + *conv1_kernel_shape)).astype(np.float32) + conv1_B = np.ones(shape=(conv1_out_channels)).astype(np.float32) + # Create the initializer tensor for the weights. + conv1_W_initializer_tensor_name = "Conv1_W" + conv1_W_initializer_tensor = create_initializer_tensor( + name=conv1_W_initializer_tensor_name, + tensor_array=conv1_W, + data_type=onnx.TensorProto.FLOAT) + conv1_B_initializer_tensor_name = "Conv1_B" + conv1_B_initializer_tensor = create_initializer_tensor( + name=conv1_B_initializer_tensor_name, + tensor_array=conv1_B, + data_type=onnx.TensorProto.FLOAT) + + conv1_node = onnx.helper.make_node( + name="Conv1", + op_type="Conv", + inputs=[ + "X1", conv1_W_initializer_tensor_name, + conv1_B_initializer_tensor_name + ], + outputs=[conv1_output_node_name], + kernel_shape=conv1_kernel_shape, + pads=conv1_pads, + ) + + add_0_node_name = "ADD_0" + + add_0_node = onnx.helper.make_node( + name=add_0_node_name, # Name is optional. + op_type="Add", + # Must follow the order of input and output definitions. + # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + conv1_output_node_name, "X2" + ], + outputs=[add_0_node_name] + ) + + conv2_output_node_name = "Conv2_Y" + + # Dummy weights for conv. + conv2_in_channels = 3 + conv2_out_channels = 3 + conv2_kernel_shape = (3, 3) + conv2_pads = (1, 1, 1, 1) + conv2_W = np.ones(shape=(conv2_out_channels, conv2_in_channels, + *conv2_kernel_shape)).astype(np.float32) + conv2_B = np.ones(shape=(conv2_out_channels)).astype(np.float32) + # Create the initializer tensor for the weights. 
+ conv2_W_initializer_tensor_name = "Conv2_W" + conv2_W_initializer_tensor = create_initializer_tensor( + name=conv2_W_initializer_tensor_name, + tensor_array=conv2_W, + data_type=onnx.TensorProto.FLOAT) + conv2_B_initializer_tensor_name = "Conv2_B" + conv2_B_initializer_tensor = create_initializer_tensor( + name=conv2_B_initializer_tensor_name, + tensor_array=conv2_B, + data_type=onnx.TensorProto.FLOAT) + + conv2_node = onnx.helper.make_node( + name="Conv2", # Name is optional. + op_type="Conv", + # Must follow the order of input and output definitions. + # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + add_0_node_name, conv2_W_initializer_tensor_name, + conv2_B_initializer_tensor_name + ], + outputs=[conv2_output_node_name], + # The following arguments are attributes. + kernel_shape=conv2_kernel_shape, + # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1 + pads=conv2_pads, + ) + + add_1_node_name = "ADD_1" + + add_1_node = onnx.helper.make_node( + name=add_1_node_name, # Name is optional. + op_type="Add", + # Must follow the order of input and output definitions. 
+ # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + conv2_output_node_name, "X1" + ], + outputs=["Y"] + ) + + + model_output_name = "Y" + Y = onnx.helper.make_tensor_value_info(model_output_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + + graph_def = onnx.helper.make_graph( + nodes=[conv1_node, add_0_node, conv2_node, add_1_node], + name="ConvReluNet", + inputs=[X1, X2], # Graph input + outputs=[Y], # Graph output + initializer=[conv1_W_initializer_tensor, conv1_B_initializer_tensor, conv2_W_initializer_tensor, conv2_B_initializer_tensor], + ) + # Create the model (ModelProto) + return onnx.helper.make_model(graph_def, producer_name="onnx-example") + + +def model_with_multiple_inputs(): + return ModelWithMultipleInputs.get_model() + + +class ModelWithMultipleOutputs: + + @staticmethod + def get_model(): + model_input_name = "X1" + X1 = onnx.helper.make_tensor_value_info(model_input_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + model_input_name = "X2" + X2 = onnx.helper.make_tensor_value_info(model_input_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + conv1_output_node_name = "Conv1_Y" + + # Dummy weights for conv. + conv1_in_channels = 3 + conv1_out_channels = 3 + conv1_kernel_shape = (3, 3) + conv1_pads = (1, 1, 1, 1) + conv1_W = np.ones(shape=(conv1_out_channels, conv1_in_channels, + *conv1_kernel_shape)).astype(np.float32) + conv1_B = np.ones(shape=(conv1_out_channels)).astype(np.float32) + # Create the initializer tensor for the weights. 
+ conv1_W_initializer_tensor_name = "Conv1_W" + conv1_W_initializer_tensor = create_initializer_tensor( + name=conv1_W_initializer_tensor_name, + tensor_array=conv1_W, + data_type=onnx.TensorProto.FLOAT) + conv1_B_initializer_tensor_name = "Conv1_B" + conv1_B_initializer_tensor = create_initializer_tensor( + name=conv1_B_initializer_tensor_name, + tensor_array=conv1_B, + data_type=onnx.TensorProto.FLOAT) + + conv1_node = onnx.helper.make_node( + name="Conv1", + op_type="Conv", + inputs=[ + "X1", conv1_W_initializer_tensor_name, + conv1_B_initializer_tensor_name + ], + outputs=[conv1_output_node_name], + kernel_shape=conv1_kernel_shape, + pads=conv1_pads, + ) + + add_0_node_name = "ADD_0" + + add_0_node = onnx.helper.make_node( + name=add_0_node_name, # Name is optional. + op_type="Add", + # Must follow the order of input and output definitions. + # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + conv1_output_node_name, "X2" + ], + outputs=[add_0_node_name] + ) + + conv2_output_node_name = "Conv2_Y" + + # Dummy weights for conv. + conv2_in_channels = 3 + conv2_out_channels = 3 + conv2_kernel_shape = (3, 3) + conv2_pads = (1, 1, 1, 1) + conv2_W = np.ones(shape=(conv2_out_channels, conv2_in_channels, + *conv2_kernel_shape)).astype(np.float32) + conv2_B = np.ones(shape=(conv2_out_channels)).astype(np.float32) + # Create the initializer tensor for the weights. + conv2_W_initializer_tensor_name = "Conv2_W" + conv2_W_initializer_tensor = create_initializer_tensor( + name=conv2_W_initializer_tensor_name, + tensor_array=conv2_W, + data_type=onnx.TensorProto.FLOAT) + conv2_B_initializer_tensor_name = "Conv2_B" + conv2_B_initializer_tensor = create_initializer_tensor( + name=conv2_B_initializer_tensor_name, + tensor_array=conv2_B, + data_type=onnx.TensorProto.FLOAT) + + conv2_node = onnx.helper.make_node( + name="Conv2", # Name is optional. + op_type="Conv", + # Must follow the order of input and output definitions. 
+ # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + add_0_node_name, conv2_W_initializer_tensor_name, + conv2_B_initializer_tensor_name + ], + outputs=[conv2_output_node_name], + # The following arguments are attributes. + kernel_shape=conv2_kernel_shape, + # Default values for other attributes: strides=[1, 1], dilations=[1, 1], groups=1 + pads=conv2_pads, + ) + + add_1_node_name = "ADD_1" + + add_1_node = onnx.helper.make_node( + name=add_1_node_name, # Name is optional. + op_type="Add", + # Must follow the order of input and output definitions. + # https://github.com/onnx/onnx/blob/rel-1.9.0/docs/Operators.md#inputs-2---3 + inputs=[ + conv2_output_node_name, "X1" + ], + outputs=["Y"] + ) + + + model_output_name = "Y" + Y = onnx.helper.make_tensor_value_info(model_output_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + conv1_output = onnx.helper.make_tensor_value_info(conv1_output_node_name, + onnx.TensorProto.FLOAT, + [1, 3, 32, 32]) + + + graph_def = onnx.helper.make_graph( + nodes=[conv1_node, add_0_node, conv2_node, add_1_node], + name="ConvReluNet", + inputs=[X1, X2], # Graph input + outputs=[Y, conv1_output], # Graph output + initializer=[conv1_W_initializer_tensor, conv1_B_initializer_tensor, conv2_W_initializer_tensor, conv2_B_initializer_tensor], + ) + # Create the model (ModelProto) + return onnx.helper.make_model(graph_def, producer_name="onnx-example") + + +def model_with_multiple_outputs(): + return ModelWithMultipleOutputs.get_model() diff --git a/TrainingExtensions/onnx/test/python/test_seq_mse.py b/TrainingExtensions/onnx/test/python/test_seq_mse.py index 6139ea2a653..24e55b3fafd 100644 --- a/TrainingExtensions/onnx/test/python/test_seq_mse.py +++ b/TrainingExtensions/onnx/test/python/test_seq_mse.py @@ -63,6 +63,8 @@ from models.test_models import single_conv_layer_model from models.test_models import model_with_split from models.test_models import single_residual_model +from models.test_models_onnx 
import model_with_multiple_inputs +from models.test_models_onnx import model_with_multiple_outputs torch.manual_seed(42) @@ -73,29 +75,32 @@ def __init__(self, data): self.data = data def __getitem__(self, index): - return self.data[index] + return tuple(d[index] for d in self.data) def __len__(self): - return len(self.data) - - dataset = MyDataset([[dummy_input]]) + return len(self.data[0]) + dataset = MyDataset(dummy_input) return DataLoader(dataset) def dummy_input_for_linear_layer(): - return torch.randn((100, 100)) + return [torch.randn((1, 100, 100))] def dummy_input_for_conv_layer(): - return torch.randn((5, 5, 5)) + return [torch.randn((1, 5, 5, 5))] def dummy_input_for_dependency_graph(): - return torch.randn((1, 10, 10)) + return [torch.randn((1, 1, 10, 10))] def dummy_input_for_residual_model(): - return torch.randn((3, 32, 32)) + return [torch.randn((1, 3, 32, 32))] + + +def dummy_input_for_model_with_multiple_input(): + return [torch.randn((1, 3, 32, 32)), torch.randn((1, 3, 32, 32))] def get_single_linear_layer_model(): @@ -110,6 +115,13 @@ def get_model_with_split(): return model_with_split() +def get_model_with_multiple_inputs(): + return model_with_multiple_inputs() + + +def get_model_with_multiple_outputs(): + return model_with_multiple_outputs() + @staticmethod def _get_config_file(is_symmetric: bool, strict_symmetric: bool, unsigned_symmetric:bool, pcq: bool) -> str: """ Temporary fix until the config file can be read from beq_config directory""" @@ -432,4 +444,38 @@ def test_apply_seq_mse_for_residual_model(inp_symmetry, param_bw, loss_fn, enabl assert weight_quantizer_fc.is_encoding_frozen() == False +def test_model_with_multiple_inputs_dependency_graph_utils(): + + model = get_model_with_multiple_inputs() + sim = QuantizationSimModel(model=copy.deepcopy(model), + quant_scheme=QuantScheme.post_training_tf, + default_activation_bw=8, + default_param_bw=4, + use_cuda=False, + config_file=_get_config_file(is_symmetric=True, 
strict_symmetric=False, + unsigned_symmetric=False, pcq=True)) + seq_params = SeqMseParams() + dataloader = unlabeled_data_loader(dummy_input_for_model_with_multiple_input()) + seq_mse = SequentialMse(model, sim, seq_params, dataloader) + + starting_ops_names = [op.name_op for op in seq_mse.dependency_graph_utils.starting_ops] + + assert starting_ops_names == ["Conv1"] + assert seq_mse.dependency_graph_utils.indegree == {"Conv1": 0, "ADD_0": 1, "ADD_1": 1, "Conv2": 1} + assert seq_mse.dependency_graph_utils.input_ops_name == ["Conv1", "ADD_0", "ADD_1"] + +def test_model_with_multiple_outputs_value_info(): + + model = get_model_with_multiple_outputs() + sim = QuantizationSimModel(model=copy.deepcopy(model), + quant_scheme=QuantScheme.post_training_tf, + default_activation_bw=8, + default_param_bw=4, + use_cuda=False, + config_file=_get_config_file(is_symmetric=True, strict_symmetric=False, + unsigned_symmetric=False, pcq=True)) + seq_params = SeqMseParams() + dataloader = unlabeled_data_loader(dummy_input_for_model_with_multiple_input()) + seq_mse = SequentialMse(model, sim, seq_params, dataloader) + assert 'Conv1_Y' in seq_mse._sim_extractor.vimap \ No newline at end of file