diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py index ee067ea97c5..27e5a30abca 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/quantizer.py @@ -437,7 +437,10 @@ def get_offset(self, dtype=None) -> Optional[torch.Tensor]: dtype = dtype or torch.float32 if self.symmetric: - offset = torch.zeros_like(self.min, requires_grad=False, dtype=dtype) + offset = torch.full_like(self.min, + fill_value=-round((self.qmin + self.qmax) / 2), + requires_grad=False, + dtype=dtype) else: offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin diff --git a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py index 5165b2b3c75..defe4e9ba6b 100644 --- a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py +++ b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py @@ -1543,3 +1543,47 @@ def test_parse_args_error(): Then: Create quantizer normally """ Quantize((1, 10), -128, 127, True) + + +@torch.no_grad() +@pytest.mark.parametrize('symmetric', [True, False]) +def test_signed_doesnt_affect_output(symmetric): + """ + When: Quantize/Dequantize the same tensor with signed and unsigned quantizers + Then: + 1) The quantized outputs should be equal with proper shifting + 2) The quantize-dequantized outputs should be equal + """ + q_int8 = Quantize(shape=(), bitwidth=8, symmetric=symmetric) + q_int8.signed = True + q_uint8 = Quantize(shape=(), bitwidth=8, symmetric=symmetric) + q_uint8.signed = False + + x = torch.arange(-10.0, 6.0) + + with q_int8.compute_encodings(), \ + q_uint8.compute_encodings(): + _ = q_int8(x) + _ = q_uint8(x) + + out_int8 = q_int8(x) + out_uint8 = q_uint8(x) + assert torch.equal(out_int8, out_uint8 - 128) + assert torch.equal(out_int8.dequantize(), out_uint8.dequantize()) + + qdq_int8 = QuantizeDequantize(shape=(), bitwidth=8, symmetric=symmetric) + qdq_int8.signed = True + qdq_uint8 = QuantizeDequantize(shape=(), bitwidth=8, symmetric=symmetric) + qdq_uint8.signed = False + + x = torch.arange(-10.0, 6.0) + + with qdq_int8.compute_encodings(), \ + qdq_uint8.compute_encodings(): + _ = qdq_int8(x) + _ = qdq_uint8(x) + + out_int8 = qdq_int8(x) + out_uint8 = qdq_uint8(x) + assert torch.equal(out_int8, out_uint8) + assert torch.equal(out_int8.quantize(), out_uint8.quantize() - 128)