diff --git a/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py b/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py index ff2cf1d3a42..b94ee0d6eb0 100644 --- a/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py +++ b/TrainingExtensions/common/src/python/aimet_common/cross_layer_equalization.py @@ -252,7 +252,8 @@ def find_downstream_layer_groups_to_scale(self, op, layer_groups: List, current_ current_group.append(op) # Terminating condition for current group - if not op.get_module() or not op.type in self._cls_supported_layer_types + self._cls_supported_activation_types: + if not op.get_module() or not op.type in self._cls_supported_layer_types + self._cls_supported_activation_types \ + or len(op.output_ops) > 1: if (len(current_group) > 1) and (current_group not in layer_groups): layer_groups.append(current_group) current_group = [] diff --git a/TrainingExtensions/onnx/test/python/models/models_for_tests.py b/TrainingExtensions/onnx/test/python/models/models_for_tests.py index 33a6d5d325d..1579410c5bc 100644 --- a/TrainingExtensions/onnx/test/python/models/models_for_tests.py +++ b/TrainingExtensions/onnx/test/python/models/models_for_tests.py @@ -2698,3 +2698,14 @@ def conv_with_weight_identity_input(): ) onnx.checker.check_model(model, True) return model + + +def squeezenet1_0(tmpdir): + import torchvision + filepath = os.path.join(os.path.join(tmpdir, "squeezenet1_0.onnx")) + model = torchvision.models.squeezenet1_0() + torch.onnx.export(model.eval(), torch.randn(1, 3, 224, 224), filepath, + training=torch.onnx.TrainingMode.EVAL, do_constant_folding=False, + input_names=["input"], output_names=["output"]) + model = onnx.load(filepath) + return ONNXModel(model) diff --git a/TrainingExtensions/onnx/test/python/test_cross_layer_equalization.py b/TrainingExtensions/onnx/test/python/test_cross_layer_equalization.py index 26869fdd009..bf2f17a61cd 100644 --- a/TrainingExtensions/onnx/test/python/test_cross_layer_equalization.py +++ b/TrainingExtensions/onnx/test/python/test_cross_layer_equalization.py @@ -187,6 +187,13 @@ def test_cle_transpose1D_model(self): output_after_cle = session.run(None, {'input': test_data}) assert np.allclose(output_after_cle, output_before_cle, rtol=1e-2, atol=1e-5) + def test_cls_squeezenet(self, tmp_path): + model = models_for_tests.squeezenet1_0(tmp_path) + cls = CrossLayerScaling(model) + cls_set_infos = cls.scale_model() + # Squeezenet1_0 doesn't have any scalable sets + assert not cls_set_infos + class TestHighBiasFold: """ Test methods for HighBiasFold"""