Skip to content

Commit

Permalink
fix: fix bug in onnx CLS pattern matching (#3682)
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Tuttle <[email protected]>
Co-authored-by: Michael Tuttle <[email protected]>
  • Loading branch information
quic-kyunggeu and quic-mtuttle authored Dec 20, 2024
1 parent 16a817c commit bcc2735
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ def find_downstream_layer_groups_to_scale(self, op, layer_groups: List, current_
current_group.append(op)

# Terminating condition for current group
if not op.get_module() or not op.type in self._cls_supported_layer_types + self._cls_supported_activation_types:
if not op.get_module() or not op.type in self._cls_supported_layer_types + self._cls_supported_activation_types \
or len(op.output_ops) > 1:
if (len(current_group) > 1) and (current_group not in layer_groups):
layer_groups.append(current_group)
current_group = []
Expand Down
11 changes: 11 additions & 0 deletions TrainingExtensions/onnx/test/python/models/models_for_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2698,3 +2698,14 @@ def conv_with_weight_identity_input():
)
onnx.checker.check_model(model, True)
return model


def squeezenet1_0(tmpdir):
import torchvision
filepath = os.path.join(os.path.join(tmpdir, "squeezenet1_0.onnx"))
model = torchvision.models.squeezenet1_0()
torch.onnx.export(model.eval(), torch.randn(1, 3, 224, 224), filepath,
training=torch.onnx.TrainingMode.EVAL, do_constant_folding=False,
input_names=["input"], output_names=["output"])
model = onnx.load(filepath)
return ONNXModel(model)
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def test_cle_transpose1D_model(self):
output_after_cle = session.run(None, {'input': test_data})
assert np.allclose(output_after_cle, output_before_cle, rtol=1e-2, atol=1e-5)

def test_cls_squeezenet(self, tmp_path):
model = models_for_tests.squeezenet1_0(tmp_path)
cls = CrossLayerScaling(model)
cls_set_infos = cls.scale_model()
# Squeezenet1_0 doesn't have any scalable sets
assert not cls_set_infos


class TestHighBiasFold:
""" Test methods for HighBiasFold"""
Expand Down

0 comments on commit bcc2735

Please sign in to comment.