diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
index b260228767b..3dfc2fbc770 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/experimental/onnx/_export.py
@@ -46,6 +46,7 @@ from torch.onnx import is_in_onnx_export, symbolic_helper
 
 from aimet_torch.v2.utils import patch_attr
+from aimet_torch.quantization.base import QuantizerBase
 
 aimet_opset = onnxscript.values.Opset(domain="aimet", version=1)
 
 
@@ -204,6 +205,9 @@ def wrapper(*args, **kwargs):
 
 
 def export(model: torch.nn.Module, *args, **kwargs):
+    """
+    Export a torch model to ONNX with precomputed scale and offset.
+    """
     if not isinstance(model, torch.nn.Module):
         raise NotImplementedError
 
@@ -215,8 +219,6 @@ def export(model: torch.nn.Module, *args, **kwargs):
 
 @contextmanager
 def _precompute_encodings(model: torch.nn.Module):
-    from aimet_torch.quantization.base import QuantizerBase
-
     with ExitStack() as stack:
         for q in model.modules():
             if isinstance(q, QuantizerBase):