From 2f76988dfd40ca04fb6e9dcbd0b02361e2b8946b Mon Sep 17 00:00:00 2001 From: Kyunggeun Lee Date: Wed, 18 Dec 2024 11:06:32 -0800 Subject: [PATCH] Deprecate quantsim util APIs Signed-off-by: Kyunggeun Lee --- .../src/python/aimet_torch/_base/quantsim.py | 2 ++ .../aimet_torch/v2/quantsim/quantsim.py | 23 ++++++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py index ce37439162f..5c6f4738973 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/_base/quantsim.py @@ -1735,6 +1735,7 @@ def run_modules_for_traced_custom_marker(self, module_list: List[torch.nn.Module +@deprecated("Use pickle.dump instead") def save_checkpoint(quant_sim_model: _QuantizationSimModelInterface, file_path: str): """ This API provides a way for the user to save a checkpoint of the quantized model which can @@ -1749,6 +1750,7 @@ def save_checkpoint(quant_sim_model: _QuantizationSimModelInterface, file_path: pickle.dump(quant_sim_model, file) +@deprecated("Use pickle.load instead") def load_checkpoint(file_path: str) -> _QuantizationSimModelInterface: """ Load the quantized model diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py index 36d9977e05e..51d509e3d29 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantsim/quantsim.py @@ -532,7 +532,17 @@ def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclu cls._remove_quantization_wrappers(module, list_of_modules_to_exclude) -@deprecated("Use QuantizationSimModel.load_encodings instead.") +@deprecated(""" +Use QuantizationSimModel.load_encodings with the following keyword arguments instead: +``` +sim.load_encodings(encoding_path + strict=True, + partial=False, + requires_grad=None, + allow_overwrite=None) +``` +""" +) def load_encodings_to_sim(quant_sim_model: _QuantizationSimModelBase, pytorch_encoding_path: str): """ Loads the saved encodings to quant sim model. The encoding filename to load should end in _torch.encodings, @@ -549,6 +559,17 @@ def load_encodings_to_sim(quant_sim_model: _QuantizationSimModelBase, pytorch_en allow_overwrite=None) +@deprecated(r""" +Use aimet_torch.nn.compute_encodings contextmanager on each sim.model instead. For example: +``` +with torch.no_grad(), \ + aimet_torch.v2.nn.compute_encodings(sim_0.model), \ + aimet_torch.v2.nn.compute_encodings(sim_1.model), \ + aimet_torch.v2.nn.compute_encodings(sim_2.model): + # Run forward pass with calibration dataset +``` +""" +) def compute_encodings_for_sims(sim_list: Sequence[QuantizationSimModel], forward_pass_callback: Callable, forward_pass_callback_args: Any): """