diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py index e2289c48e78..41f20e7cea2 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/sequential_mse/seq_mse.py @@ -99,7 +99,7 @@ def __init__(self, :param model: float model :param sim: QuantizationSimModel object - :param data_loader: Data loader + :param data_loader: Torch Dataloader :param params: Sequential MSE parameters """ @@ -151,6 +151,50 @@ def __init__(self, self.static_tensor_name_to_proto) self.quantizers_to_be_disabled = self._get_quantizers_to_be_disabled() # check this + def _update_value_info_for_output(self, node): + """ + Updates the value info for output of a node in sim model. + Value info for QcQuantizeOp is not present in _sim_extractor + + :param node: onnx node + """ + + input_name = node.input[0] + output_name = node.output[0] + if input_name in self._sim_extractor.vimap and output_name not in self._sim_extractor.vimap: + value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name]) + value_info_for_output.name = node.output[0] + self._sim_extractor.vimap[node.output[0]] = value_info_for_output + + def _update_value_info_for_input(self, node): + """ + Updates the value info for input of a node in sim model. + Value info for QcQuantizeOp is not present in _sim_extractor + + :param node: onnx node + """ + + input_name = node.input[0] + output_name = node.output[0] + if output_name in self._sim_extractor.vimap and input_name not in self._sim_extractor.vimap: + value_info_for_input = copy.deepcopy(self._sim_extractor.vimap[output_name]) + value_info_for_input.name = node.input[0] + self._sim_extractor.vimap[node.input[0]] = value_info_for_input + + def _update_value_info_for_graph_output(self): + """ + Updates the value info for input of a node in sim model. + Value info for QcQuantizeOp is not present in _sim_extractor + + :param node: onnx node + """ + + for value_info in self.model.model.graph.output: + self._float_extractor.vimap[value_info.name] = value_info + + for value_info in self.sim.model.model.graph.output: + self._sim_extractor.vimap[value_info.name] = value_info + def _update_value_info(self): """ Updates the value info for sim model. @@ -159,11 +203,10 @@ def _update_value_info(self): for node in self.sim.model.nodes(): if node.op_type == "QcQuantizeOp": - input_name = node.input[0] - if input_name in self._sim_extractor.vimap: - value_info_for_output = copy.deepcopy(self._sim_extractor.vimap[input_name]) - value_info_for_output.name = node.output[0] - self._sim_extractor.vimap[node.output[0]] = value_info_for_output + self._update_value_info_for_output(node) + self._update_value_info_for_input(node) + + self._update_value_info_for_graph_output() def _fill_static_tensor_name_to_proto(self): """