Skip to content

Commit

Permalink
Redefine symmetry as offset=-round((qmin + qmax) / 2) (quic#3718)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Jan 8, 2025
1 parent b92aaeb commit 314cbb4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def get_offset(self, dtype=None) -> Optional[torch.Tensor]:
dtype = dtype or torch.float32

if self.symmetric:
offset = torch.zeros_like(self.min, requires_grad=False, dtype=dtype)
offset = torch.full_like(self.min,
fill_value=-round((self.qmin + self.qmax) / 2),
requires_grad=False,
dtype=dtype)
else:
offset = ste_round(self.min.to(dtype) / self.get_scale(dtype)) - self.qmin

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1543,3 +1543,47 @@ def test_parse_args_error():
Then: Create quantizer normally
"""
Quantize((1, 10), -128, 127, True)


@torch.no_grad()
@pytest.mark.parametrize('symmetric', [True, False])
def test_signed_doesnt_affect_output(symmetric):
    """
    Given: A signed and an unsigned 8-bit quantizer built with the same
           ``symmetric`` setting and calibrated on the same input
    When: Both quantizers process that input
    Then:
      1) Quantize: integer outputs match once the unsigned output is
         shifted down by 128, and the dequantized outputs are equal
      2) QuantizeDequantize: outputs are equal, and re-quantizing them
         gives integer outputs that match after the same shift of 128
    """
    # Distance between the unsigned and signed 8-bit integer grids.
    shift = 128
    inp = torch.arange(-10.0, 6.0)

    # --- Quantize -----------------------------------------------------------
    signed_q = Quantize(shape=(), bitwidth=8, symmetric=symmetric)
    signed_q.signed = True
    unsigned_q = Quantize(shape=(), bitwidth=8, symmetric=symmetric)
    unsigned_q.signed = False

    # Calibrate both quantizers on the identical tensor.
    with signed_q.compute_encodings(), unsigned_q.compute_encodings():
        signed_q(inp)
        unsigned_q(inp)

    signed_out = signed_q(inp)
    unsigned_out = unsigned_q(inp)
    assert torch.equal(signed_out, unsigned_out - shift)
    assert torch.equal(signed_out.dequantize(), unsigned_out.dequantize())

    # --- QuantizeDequantize -------------------------------------------------
    signed_qdq = QuantizeDequantize(shape=(), bitwidth=8, symmetric=symmetric)
    signed_qdq.signed = True
    unsigned_qdq = QuantizeDequantize(shape=(), bitwidth=8, symmetric=symmetric)
    unsigned_qdq.signed = False

    with signed_qdq.compute_encodings(), unsigned_qdq.compute_encodings():
        signed_qdq(inp)
        unsigned_qdq(inp)

    signed_out = signed_qdq(inp)
    unsigned_out = unsigned_qdq(inp)
    assert torch.equal(signed_out, unsigned_out)
    assert torch.equal(signed_out.quantize(), unsigned_out.quantize() - shift)

0 comments on commit 314cbb4

Please sign in to comment.