diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
index f94a7744079..6ce8ac3aaf0 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
@@ -322,8 +322,10 @@ def _compute_loss(cls,
             x_blocks = torch.split(x, block_size, dim=-1)
             xq_blocks = torch.split(xq, block_size, dim=-1)
         else:
-            x_blocks = torch.split(x, block_size, dim=-3)
-            xq_blocks = torch.split(xq, block_size, dim=-3)
+            assert isinstance(quant_module, torch.nn.Conv2d)
+            groups = quant_module.groups
+            x_blocks = torch.split(x, block_size * groups, dim=-3)
+            xq_blocks = torch.split(xq, block_size * groups, dim=-3)
 
         block_losses = []
         for idx, x_block in enumerate(x_blocks):
diff --git a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
index 0ca66c664ba..769d2ce8b55 100644
--- a/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
+++ b/TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
@@ -334,3 +334,35 @@ def test_handle_grouped_block_quantizers(self):
 
         assert torch.equal(out, out_3)
         assert not torch.equal(out, out_2)
+
+    @pytest.mark.parametrize('kwargs', [
+        dict(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=2),
+        dict(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=1),
+        dict(in_channels=16, out_channels=16, kernel_size=(3, 3), dilation=2),
+        dict(in_channels=16, out_channels=16, kernel_size=(3, 3), groups=16),
+        dict(in_channels=16, out_channels=16, kernel_size=(3, 3), groups=4),
+    ])
+    def test_non_default_conv(self, kwargs):
+        """
+        When: Run sequential MSE on a Conv2d with non-default arguments
+              (stride, padding, dilation, groups, ...)
+        Then: Shouldn't raise a runtime error
+        """
+        model = torch.nn.Sequential(
+            torch.nn.Conv2d(**kwargs),
+        )
+        dummy_input = torch.randn(1, 16, 100, 100)
+        data_loader = (dummy_input,) * 2
+        sim = QuantizationSimModel(model, dummy_input, default_param_bw=4,
+                                   quant_scheme=QuantScheme.post_training_tf)
+        qconv = sim.model[0]
+        qconv.param_quantizers['weight'].min.copy_(-1)
+        qconv.param_quantizers['weight'].max.copy_(1)
+        sim.compute_encodings(lambda m: m(dummy_input))
+
+        params = SeqMseParams(num_batches=2, inp_symmetry='asym', loss_fn='mse')
+        apply_seq_mse(model, sim, data_loader, params)
+
+        # Sanity check: seq MSE should have replaced the hand-set encodings
+        assert qconv.param_quantizers['weight'].min != -1
+        assert qconv.param_quantizers['weight'].max != 1
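
Not part of the patch itself, but a minimal sketch of the shape arithmetic behind the `block_size * groups` factor: a grouped `Conv2d` stores its weight as `(out_channels, in_channels // groups, kH, kW)`, so block quantization along the weight's input-channel axis only sees `in_channels // groups` channels, while the activation keeps the full `in_channels` along `dim=-3`. Scaling the activation split by `groups` keeps the number of activation blocks equal to the number of weight blocks. All values below are hypothetical and chosen to mirror the `groups=4` test case.

```python
import torch

# Hypothetical shapes illustrating the block-count mismatch the patch fixes
in_channels, out_channels, groups, block_size = 16, 16, 4, 2

# Grouped Conv2d weight: (out_channels, in_channels // groups, kH, kW)
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=groups)
weight_in_channels = conv.weight.shape[1]             # 16 // 4 == 4
num_weight_blocks = weight_in_channels // block_size  # 2 blocks along the weight's input-channel axis

# The activation still carries all in_channels along dim=-3
x = torch.randn(1, in_channels, 8, 8)

# Old behavior: splitting by block_size alone yields 8 activation blocks vs. 2 weight blocks
assert len(torch.split(x, block_size, dim=-3)) == 8

# Patched behavior: scaling by groups keeps the activation split aligned with the weight blocks
assert len(torch.split(x, block_size * groups, dim=-3)) == num_weight_blocks
```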