Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BREAKING_CHANGE: switch aimet_torch default API from v1 to v2 #3689

Merged
merged 2 commits into the base branch on Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/src/python/aimet_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ def _get_metadata_and_state_dict(safetensor_file_path: str) -> [dict, dict]:


def _get_default_api() -> Union[Literal["v1"], Literal["v2"]]:
default_api = os.getenv("AIMET_DEFAULT_API", "v1").lower()
default_api = os.getenv("AIMET_DEFAULT_API", "v2").lower()

if default_api not in ("v1", "v2"):
raise RuntimeError("Invalid value specified for environment variable AIMET_DEFAULT_API. "
Expand Down
132 changes: 76 additions & 56 deletions TrainingExtensions/torch/test/python/test_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,104 +43,124 @@
import pytest

import aimet_torch
from aimet_torch.utils import _get_default_api


def test_default_import():
if _get_default_api() == "v1":
from aimet_torch.v1 import quantsim as default_quantsim
from aimet_torch.v1.quantsim import QuantizationSimModel as default_QuantizationSimModel
from aimet_torch.v1.adaround import adaround_weight as default_adaround_weight
from aimet_torch.v1.adaround.adaround_weight import Adaround as default_Adaround
from aimet_torch.v1 import seq_mse as default_seq_mse
from aimet_torch.v1.seq_mse import apply_seq_mse as default_apply_seq_mse
from aimet_torch.v1.nn.modules import custom as default_custom
from aimet_torch.v1.nn.modules.custom import Add as default_Add
from aimet_torch.v1 import auto_quant as default_auto_quant
from aimet_torch.v1.auto_quant import AutoQuant as default_AutoQuant
from aimet_torch.v1 import quant_analyzer as default_quant_analyzer
from aimet_torch.v1.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer
from aimet_torch.v1 import batch_norm_fold as default_batch_norm_fold
from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale
from aimet_torch.v1 import mixed_precision as default_mixed_precision
from aimet_torch.v1.mixed_precision import choose_mixed_precision as default_choose_mixed_precision
else:
from aimet_torch.v2 import quantsim as default_quantsim
from aimet_torch.v2.quantsim import QuantizationSimModel as default_QuantizationSimModel
from aimet_torch.v2.adaround import adaround_weight as default_adaround_weight
from aimet_torch.v2.adaround.adaround_weight import Adaround as default_Adaround
from aimet_torch.v2 import seq_mse as default_seq_mse
from aimet_torch.v2.seq_mse import apply_seq_mse as default_apply_seq_mse
from aimet_torch.v2.nn.modules import custom as default_custom
from aimet_torch.v2.nn.modules.custom import Add as default_Add
from aimet_torch.v2 import auto_quant as default_auto_quant
from aimet_torch.v2.auto_quant import AutoQuant as default_AutoQuant
from aimet_torch.v2 import quant_analyzer as default_quant_analyzer
from aimet_torch.v2.quant_analyzer import QuantAnalyzer as default_QuantAnalyzer
from aimet_torch.v2 import batch_norm_fold as default_batch_norm_fold
from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms_to_scale as default_fold_all_batch_norms_to_scale
from aimet_torch.v2 import mixed_precision as default_mixed_precision
from aimet_torch.v2.mixed_precision import choose_mixed_precision as default_choose_mixed_precision

"""
When: Import from aimet_torch.quantsim
Then: Import should be redirected to aimet_torch.v1.quantsim
Then: Import should be redirected to aimet_torch.v1.quantsim or aimet_torch.v2.quantsim
"""
from aimet_torch import quantsim
from aimet_torch.v1 import quantsim as v1_quantsim
assert quantsim.QuantizationSimModel is v1_quantsim.QuantizationSimModel
from aimet_torch import quantsim
assert quantsim.QuantizationSimModel is default_quantsim.QuantizationSimModel

from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import QuantizationSimModel as v1_QuantizationSimModel
assert QuantizationSimModel is v1_QuantizationSimModel
from aimet_torch.quantsim import QuantizationSimModel
assert QuantizationSimModel is default_QuantizationSimModel

"""
When: Import from aimet_torch.adaround
Then: Import should be redirected to aimet_torch.v1.adaround
Then: Import should be redirected to aimet_torch.v1.adaround or aimet_torch.v2.adaround
"""
from aimet_torch.adaround import adaround_weight
from aimet_torch.v1.adaround import adaround_weight as v1_adaround_weight
assert adaround_weight.Adaround is v1_adaround_weight.Adaround
from aimet_torch.adaround import adaround_weight
assert adaround_weight.Adaround is default_adaround_weight.Adaround

from aimet_torch.adaround.adaround_weight import Adaround
from aimet_torch.v1.adaround.adaround_weight import Adaround as v1_Adaround
assert Adaround is v1_Adaround
from aimet_torch.adaround.adaround_weight import Adaround
assert Adaround is default_Adaround

"""
When: Import from aimet_torch.seq_mse
Then: Import should be redirected to aimet_torch.v1.seq_mse
Then: Import should be redirected to aimet_torch.v1.seq_mse or aimet_torch.v2.seq_mse
"""
from aimet_torch import seq_mse
from aimet_torch.v1 import seq_mse as v1_seq_mse
assert seq_mse.apply_seq_mse is v1_seq_mse.apply_seq_mse
from aimet_torch import seq_mse
assert seq_mse.apply_seq_mse is default_seq_mse.apply_seq_mse

from aimet_torch.seq_mse import apply_seq_mse
from aimet_torch.v1.seq_mse import apply_seq_mse as v1_apply_seq_mse
assert apply_seq_mse is v1_apply_seq_mse
from aimet_torch.seq_mse import apply_seq_mse
assert apply_seq_mse is default_apply_seq_mse

"""
When: Import from aimet_torch.nn
Then: Import should be redirected to aimet_torch.v1.nn
Then: Import should be redirected to aimet_torch.v1.nn or aimet_torch.v2.nn
"""
from aimet_torch.nn.modules import custom
from aimet_torch.v1.nn.modules import custom as v1_custom
assert custom.Add is v1_custom.Add
from aimet_torch.nn.modules import custom
assert custom.Add is default_custom.Add

from aimet_torch.nn.modules.custom import Add
from aimet_torch.v1.nn.modules.custom import Add as v1_Add
assert Add is v1_Add
from aimet_torch.nn.modules.custom import Add
assert Add is default_Add

"""
When: Import from aimet_torch.auto_quant
Then: Import should be redirected to aimet_torch.v1.auto_quant
Then: Import should be redirected to aimet_torch.v1.auto_quant or aimet_torch.v2.auto_quant
"""
from aimet_torch import auto_quant
from aimet_torch.v1 import auto_quant as v1_auto_quant
assert auto_quant.AutoQuant is v1_auto_quant.AutoQuant
from aimet_torch import auto_quant
assert auto_quant.AutoQuant is default_auto_quant.AutoQuant

from aimet_torch.auto_quant import AutoQuant
from aimet_torch.v1.auto_quant import AutoQuant as v1_AutoQuant
assert AutoQuant is v1_AutoQuant
from aimet_torch.auto_quant import AutoQuant
assert AutoQuant is default_AutoQuant

"""
When: Import from aimet_torch.quant_analyzer
Then: Import should be redirected to aimet_torch.v1.quant_analyzer
Then: Import should be redirected to aimet_torch.v1.quant_analyzer or aimet_torch.v2.quant_analyzer
"""
from aimet_torch import quant_analyzer
from aimet_torch.v1 import quant_analyzer as v1_auto_quant
assert quant_analyzer.QuantAnalyzer is v1_auto_quant.QuantAnalyzer
from aimet_torch import quant_analyzer
assert quant_analyzer.QuantAnalyzer is default_quant_analyzer.QuantAnalyzer

from aimet_torch.quant_analyzer import QuantAnalyzer
from aimet_torch.v1.quant_analyzer import QuantAnalyzer as v1_QuantAnalyzer
assert QuantAnalyzer is v1_QuantAnalyzer
from aimet_torch.quant_analyzer import QuantAnalyzer
assert QuantAnalyzer is default_QuantAnalyzer

"""
When: Import from aimet_torch.batch_norm_fold
Then: Import should be redirected to aimet_torch.v1.batch_norm_fold
Then: Import should be redirected to aimet_torch.v1.batch_norm_fold or aimet_torch.v2.batch_norm_fold
"""
from aimet_torch import batch_norm_fold
from aimet_torch.v1 import batch_norm_fold as v1_batch_norm_fold
assert batch_norm_fold.fold_all_batch_norms_to_scale is v1_batch_norm_fold.fold_all_batch_norms_to_scale
from aimet_torch import batch_norm_fold
assert batch_norm_fold.fold_all_batch_norms_to_scale is default_batch_norm_fold.fold_all_batch_norms_to_scale

from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms_to_scale as v1_fold_all_batch_norms_to_scale
assert fold_all_batch_norms_to_scale is v1_fold_all_batch_norms_to_scale
from aimet_torch.batch_norm_fold import fold_all_batch_norms_to_scale
assert fold_all_batch_norms_to_scale is default_fold_all_batch_norms_to_scale

"""
When: Import from aimet_torch.mixed_precision
Then: Import should be redirected to aimet_torch.v1.mixed_precision
Then: Import should be redirected to aimet_torch.v1.mixed_precision or aimet_torch.v2.mixed_precision
"""
from aimet_torch import mixed_precision
from aimet_torch.v1 import mixed_precision as v1_mixed_precision
assert mixed_precision.choose_mixed_precision is v1_mixed_precision.choose_mixed_precision
from aimet_torch import mixed_precision
assert mixed_precision.choose_mixed_precision is default_mixed_precision.choose_mixed_precision

from aimet_torch.mixed_precision import choose_mixed_precision
from aimet_torch.v1.mixed_precision import choose_mixed_precision as v1_choose_mixed_precision
assert choose_mixed_precision is v1_choose_mixed_precision
from aimet_torch.mixed_precision import choose_mixed_precision
assert choose_mixed_precision is default_choose_mixed_precision


def _get_all_modules():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from aimet_torch.amp.mixed_precision_algo import GreedyMixedPrecisionAlgo
from aimet_common.defs import QuantizationDataType
from aimet_torch import utils
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v1.quantsim import QuantizationSimModel
from aimet_torch.save_utils import SaveUtils


Expand Down
2 changes: 1 addition & 1 deletion TrainingExtensions/torch/test/python/v1/test_bn_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import torch
from torchvision import models

from aimet_torch.batch_norm_fold import fold_given_batch_norms
from aimet_torch.v1.batch_norm_fold import fold_given_batch_norms
from ..models.test_models import TransposedConvModel
from aimet_torch.model_preparer import prepare_model
from aimet_common.defs import QuantScheme
Expand Down
Loading