Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu committed Jan 15, 2025
1 parent 48cb4df commit f92a749
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions TrainingExtensions/torch/test/python/v1/test_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from torchvision import models

import aimet_common.libpymo as libpymo
from aimet_common.aimet_tensor_quantizer import AimetTensorQuantizer

from aimet_common.defs import QuantScheme, QuantizationDataType, MAP_ROUND_MODE_TO_PYMO
from aimet_common.quantsim_config.utils import get_path_for_per_channel_config
Expand Down Expand Up @@ -5362,3 +5363,23 @@ def forward_pass(model, args):
sim.compute_encodings(forward_pass, None)

assert sim.model.compare.output_quantizers[0].enabled == False


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_histogram(device):
if device == 'cuda' and not torch.cuda.is_available():
pytest.skip()

"""
libpymo histogram should produce the same histogram as torch.histc
"""
x = torch.arange(-512, 513, dtype=torch.float, device=device)

# NOTE: libpymo histogram is hard-coded to use 3x of the range of the first input
ground_truth = torch.histc(x, bins=512, min=x.min() * 3, max=x.max() * 3)

q = AimetTensorQuantizer(libpymo.QuantizationMode.QUANTIZATION_TF_ENHANCED)
q.updateStats(x, x.is_cuda)
prob = torch.tensor([prob for _, prob in q.getStatsHistogram()], device=device)

assert torch.equal(prob * 1025, ground_truth)

0 comments on commit f92a749

Please sign in to comment.