Skip to content

Commit

Permalink
fix: fix sequential MSE runtime error upon grouped conv2d (#3679)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Dec 20, 2024
1 parent 01fc213 commit 16a817c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
6 changes: 4 additions & 2 deletions TrainingExtensions/torch/src/python/aimet_torch/v2/seq_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,10 @@ def _compute_loss(cls,
x_blocks = torch.split(x, block_size, dim=-1)
xq_blocks = torch.split(xq, block_size, dim=-1)
else:
x_blocks = torch.split(x, block_size, dim=-3)
xq_blocks = torch.split(xq, block_size, dim=-3)
assert isinstance(quant_module, torch.nn.Conv2d)
groups = quant_module.groups
x_blocks = torch.split(x, block_size * groups, dim=-3)
xq_blocks = torch.split(xq, block_size * groups, dim=-3)

block_losses = []
for idx, x_block in enumerate(x_blocks):
Expand Down
32 changes: 32 additions & 0 deletions TrainingExtensions/torch/test/python/v2/test_seq_mse_.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,35 @@ def test_handle_grouped_block_quantizers(self):

assert torch.equal(out, out_3)
assert not torch.equal(out, out_2)

@pytest.mark.parametrize('kwargs', [
dict(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=2),
dict(in_channels=16, out_channels=16, kernel_size=(3, 3), padding=1),
dict(in_channels=16, out_channels=16, kernel_size=(3, 3), dilation=2),
dict(in_channels=16, out_channels=16, kernel_size=(3, 3), groups=16),
dict(in_channels=16, out_channels=16, kernel_size=(3, 3), groups=4),
])
def test_non_default_conv(self, kwargs):
"""
When: Run sequential MSE with conv2d with non-default arguments
(stride, padding, dilation, groups, ...)
Then: Shouldn't raise runtime error
"""
model = torch.nn.Sequential(
torch.nn.Conv2d(**kwargs),
)
dummy_input = torch.randn(1, 16, 100, 100)
data_loader = (dummy_input,) * 2
sim = QuantizationSimModel(model, dummy_input, default_param_bw=4,
quant_scheme=QuantScheme.post_training_tf)
qconv = sim.model[0]
qconv.param_quantizers['weight'].min.copy_(-1)
qconv.param_quantizers['weight'].max.copy_(1)
sim.compute_encodings(lambda m: m(dummy_input))

params = SeqMseParams(num_batches=2, inp_symmetry='asym', loss_fn='mse')
apply_seq_mse(model, sim, data_loader, params)

# sanity check
assert qconv.param_quantizers['weight'].min != -1
assert qconv.param_quantizers['weight'].max != 1

0 comments on commit 16a817c

Please sign in to comment.