From 31857ba88ea62cfd484dae5f654e83874fcc462d Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Wed, 18 Dec 2024 12:52:39 -0800 Subject: [PATCH 1/2] Switch default API to aimet_torch.v2 Signed-off-by: Kyunggeun Lee --- .../torch/src/python/aimet_torch/utils.py | 2 +- .../torch/test/python/test_import.py | 132 ++++++++++-------- 2 files changed, 77 insertions(+), 57 deletions(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py index e93fd4e20b9..e2b930ab793 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py @@ -1307,7 +1307,7 @@ def _get_metadata_and_state_dict(safetensor_file_path: str) -> [dict, dict]: def _get_default_api() -> Union[Literal["v1"], Literal["v2"]]: - default_api = os.getenv("AIMET_DEFAULT_API", "v1").lower() + default_api = os.getenv("AIMET_DEFAULT_API", "v2").lower() if default_api not in ("v1", "v2"): raise RuntimeError("Invalid value specified for environment variable AIMET_DEFAULT_API. " diff --git a/TrainingExtensions/torch/test/python/test_import.py b/TrainingExtensions/torch/test/python/test_import.py index 60f34737bae..1da4529b514 100644 --- a/TrainingExtensions/torch/test/python/test_import.py +++ b/TrainingExtensions/torch/test/python/test_import.py @@ -43,104 +43,124 @@ import pytest import aimet_torch +from aimet_torch.utils import _get_default_api def test_default_import(): + if _get_default_api() == "v1": + from aimet_torch.v1 import quantsim as default_quantsim + from aimet_torch.v1.quantsim import QuantizationSimModel as default_QuantizationSimModel + from aimet_torch.v1.adaround import adaround_weight as default_adaround_weight + from aimet_torch.v1.adaround.adaround_weight import Adaround as default_Adaround + from aimet_torch.v1 import seq_mse as default_seq_mse + from aimet_torch.v1.seq_mse import apply_seq_mse as default_apply_seq_mse + from aimet_torch.v1.nn.modules import custom as default_custom + from aimet_torch.v1.nn.modules.custom import Add as default_Add + from aimet_torch.v1 import auto_quant as default_auto_quant + from aimet_torch.v1.auto_quant import AutoQuant as default_AutoQuant + from aimet_torch.v1 import quant_analyzer as default_quant_analyzer + from aimet_torch.v1.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer + from aimet_torch.v1 import batch_norm_fold as default_batch_norm_fold + from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale + from aimet_torch.v1 import mixed_precision as default_mixed_precision + from aimet_torch.v1.mixed_precision import choose_mixed_precision as default_choose_mixed_precision + else: + from aimet_torch.v2 import quantsim as default_quantsim + from aimet_torch.v2.quantsim import QuantizationSimModel as default_QuantizationSimModel + from aimet_torch.v2.adaround import adaround_weight as default_adaround_weight + from aimet_torch.v2.adaround.adaround_weight import Adaround as default_Adaround + from aimet_torch.v2 import seq_mse as default_seq_mse + from aimet_torch.v2.seq_mse import apply_seq_mse as default_apply_seq_mse + from aimet_torch.v2.nn.modules import custom as default_custom + from aimet_torch.v2.nn.modules.custom import Add as default_Add + from aimet_torch.v2 import auto_quant as default_auto_quant + from aimet_torch.v2.auto_quant import AutoQuant as default_AutoQuant + from aimet_torch.v2 import quant_analyzer as default_quant_analyzer + from aimet_torch.v2.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer + from aimet_torch.v2 import batch_norm_fold as default_batch_norm_fold + from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale + from aimet_torch.v2 import mixed_precision as default_mixed_precision + from aimet_torch.v2.mixed_precision import choose_mixed_precision as default_choose_mixed_precision + """ When: Import from aimet_torch.quantsim - Then: Import should be redirected to aimet_torch.v1.quantsim + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.quantsim """ - from aimet_torch import quantsim - from aimet_torch.v1 import quantsim as v1_quantsim - assert quantsim.QuantizationSimModel is v1_quantsim.QuantizationSimModel + from aimet_torch import quantsim + assert quantsim.QuantizationSimModel is default_quantsim.QuantizationSimModel - from aimet_torch.quantsim import QuantizationSimModel - from aimet_torch.v1.quantsim import QuantizationSimModel as v1_QuantizationSimModel - assert QuantizationSimModel is v1_QuantizationSimModel + from aimet_torch.quantsim import QuantizationSimModel + assert QuantizationSimModel is default_QuantizationSimModel """ When: Import from aimet_torch.adaround - Then: Import should be redirected to aimet_torch.v1.adaround + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.adaround """ - from aimet_torch.adaround import adaround_weight - from aimet_torch.v1.adaround import adaround_weight as v1_adaround_weight - assert adaround_weight.Adaround is v1_adaround_weight.Adaround + from aimet_torch.adaround import adaround_weight + assert adaround_weight.Adaround is default_adaround_weight.Adaround - from aimet_torch.adaround.adaround_weight import Adaround - from aimet_torch.v1.adaround.adaround_weight import Adaround as v1_Adaround - assert Adaround is v1_Adaround + from aimet_torch.adaround.adaround_weight import Adaround + assert Adaround is default_Adaround """ When: Import from aimet_torch.seq_mse - Then: Import should be redirected to aimet_torch.v1.seq_mse + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.seq_mse """ - from aimet_torch import seq_mse - from aimet_torch.v1 import seq_mse as v1_seq_mse - assert seq_mse.apply_seq_mse is v1_seq_mse.apply_seq_mse + from aimet_torch import seq_mse + assert seq_mse.apply_seq_mse is default_seq_mse.apply_seq_mse - from aimet_torch.seq_mse import apply_seq_mse - from aimet_torch.v1.seq_mse import apply_seq_mse as v1_apply_seq_mse - assert apply_seq_mse is v1_apply_seq_mse + from aimet_torch.seq_mse import apply_seq_mse + assert apply_seq_mse is default_apply_seq_mse """ When: Import from aimet_torch.nn - Then: Import should be redirected to aimet_torch.v1.nn + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.nn """ - from aimet_torch.nn.modules import custom - from aimet_torch.v1.nn.modules import custom as v1_custom - assert custom.Add is v1_custom.Add + from aimet_torch.nn.modules import custom + assert custom.Add is default_custom.Add - from aimet_torch.nn.modules.custom import Add - from aimet_torch.v1.nn.modules.custom import Add as v1_Add - assert Add is v1_Add + from aimet_torch.nn.modules.custom import Add + assert Add is default_Add """ When: Import from aimet_torch.auto_quant - Then: Import should be redirected to aimet_torch.v1.auto_quant + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.auto_quant """ - from aimet_torch import auto_quant - from aimet_torch.v1 import auto_quant as v1_auto_quant - assert auto_quant.AutoQuant is v1_auto_quant.AutoQuant + from aimet_torch import auto_quant + assert auto_quant.AutoQuant is default_auto_quant.AutoQuant - from aimet_torch.auto_quant import AutoQuant - from aimet_torch.v1.auto_quant import AutoQuant as v1_AutoQuant - assert AutoQuant is v1_AutoQuant + from aimet_torch.auto_quant import AutoQuant + assert AutoQuant is default_AutoQuant """ When: Import from aimet_torch.quant_analyzer - Then: Import should be redirected to aimet_torch.v1.quant_analyzer + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.quant_analyzer """ - from aimet_torch import quant_analyzer - from aimet_torch.v1 import quant_analyzer as v1_auto_quant - assert quant_analyzer.QuantAnalyzer is v1_auto_quant.QuantAnalyzer + from aimet_torch import quant_analyzer + assert quant_analyzer.QuantAnalyzer is default_quant_analyzer.QuantAnalyzer - from aimet_torch.quant_analyzer import QuantAnalyzer - from aimet_torch.v1.quant_analyzer import QuantAnalyzer as v1_QuantAnalyzer - assert QuantAnalyzer is v1_QuantAnalyzer + from aimet_torch.quant_analyzer import QuantAnalyzer + assert QuantAnalyzer is default_QuantAnalyzer """ When: Import from aimet_torch.batch_norm_fold - Then: Import should be redirected to aimet_torch.v1.batch_norm_fold + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.batch_norm_fold """ - from aimet_torch import batch_norm_fold - from aimet_torch.v1 import batch_norm_fold as v1_batch_norm_fold - assert batch_norm_fold.fold_all_batch_norms_to_scale is v1_batch_norm_fold.fold_all_batch_norms_to_scale + from aimet_torch import batch_norm_fold + assert batch_norm_fold.fold_all_batch_norms_to_scale is default_batch_norm_fold.fold_all_batch_norms_to_scale - from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale - from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as v1_fold_all_batch_norms_to_scale - assert fold_all_batch_norms_to_scale is v1_fold_all_batch_norms_to_scale + from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale + assert fold_all_batch_norms_to_scale is default_fold_all_batch_norms_to_scale """ When: Import from aimet_torch.mixed_precision - Then: Import should be redirected to aimet_torch.v1.mixed_precision + Then: Import should be redirected to aimet_torch.v1 or aimet_torch.v2.mixed_precision """ - from aimet_torch import mixed_precision - from aimet_torch.v1 import mixed_precision as v1_mixed_precision - assert mixed_precision.choose_mixed_precision is v1_mixed_precision.choose_mixed_precision + from aimet_torch import mixed_precision + assert mixed_precision.choose_mixed_precision is default_mixed_precision.choose_mixed_precision - from aimet_torch.mixed_precision import choose_mixed_precision - from aimet_torch.v1.mixed_precision import choose_mixed_precision as v1_choose_mixed_precision - assert choose_mixed_precision is v1_choose_mixed_precision + from aimet_torch.mixed_precision import choose_mixed_precision + assert choose_mixed_precision is default_choose_mixed_precision def _get_all_modules(): From ce3c0c1afeacfff200298c5a6aba32251dc47675 Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Thu, 19 Dec 2024 14:48:10 -0800 Subject: [PATCH 2/2] Update import statement Signed-off-by: Kyunggeun Lee --- .../torch/test/python/v1/test_auto_quant_with_amp.py | 2 +- TrainingExtensions/torch/test/python/v1/test_bn_fold.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/TrainingExtensions/torch/test/python/v1/test_auto_quant_with_amp.py b/TrainingExtensions/torch/test/python/v1/test_auto_quant_with_amp.py index 73d4368a13e..986403d0ca6 100644 --- a/TrainingExtensions/torch/test/python/v1/test_auto_quant_with_amp.py +++ b/TrainingExtensions/torch/test/python/v1/test_auto_quant_with_amp.py @@ -54,7 +54,7 @@ from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo from aimet_common.defs import QuantizationDataType from aimet_torch import utils -from aimet_torch.quantsim import QuantizationSimModel +from aimet_torch.v1.quantsim import QuantizationSimModel from aimet_torch.save_utils import SaveUtils diff --git a/TrainingExtensions/torch/test/python/v1/test_bn_fold.py b/TrainingExtensions/torch/test/python/v1/test_bn_fold.py index 7ea93e6ce70..758747bc727 100644 --- a/TrainingExtensions/torch/test/python/v1/test_bn_fold.py +++ b/TrainingExtensions/torch/test/python/v1/test_bn_fold.py @@ -43,7 +43,7 @@ import torch from torchvision import models -from aimet_torch.batch_norm_fold import fold_given_batch_norms +from aimet_torch.v1.batch_norm_fold import fold_given_batch_norms from ..models.test_models import TransposedConvModel from aimet_torch.model_preparer import prepare_model from aimet_common.defs import QuantScheme