diff --git a/examples/neural_compressor/language-modeling/run_clm.py b/examples/neural_compressor/language-modeling/run_clm.py
index 7e81072194..55f79b2185 100644
--- a/examples/neural_compressor/language-modeling/run_clm.py
+++ b/examples/neural_compressor/language-modeling/run_clm.py
@@ -215,6 +215,10 @@ class OptimizationArguments:
         default="sym",
         metadata={"help": "Scheme for weight only quantization. Choose from 'sym' and 'asym'."},
     )
+    use_layer_wise: bool = field(
+        default=False,
+        metadata={"help": "Whether to use layer-wise quantization to save memory."},
+    )
     quantization_methodology: str = field(
         default="rtn",
         metadata={"help": "Quantization methodology for weight only quantization. Choose from 'rtn' and 'gptq'."},
@@ -659,6 +663,7 @@ def compute_metrics(eval_preds):
             "bits": optim_args.bits,
             "sym": optim_args.weight_only_scheme == "sym",
             "group_size": optim_args.group_size,
+            "use_layer_wise": optim_args.use_layer_wise,
         }
 
         if optim_args.quantization_methodology == "gptq":
@@ -666,6 +671,7 @@ def compute_metrics(eval_preds):
                 damp_percent=optim_args.damp_percent,
                 nsamples=optim_args.num_calibration_samples,
                 blocksize=optim_args.gptq_block_size,
+                tokenizer=tokenizer,
                 **algorithm_args,
             )
         else:
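For reference, the flow this run_clm.py change enables can be sketched in isolation. This is a minimal illustration and not part of the patch: the literal values stand in for the parsed `OptimizationArguments` fields, and the tiny test model is borrowed from the test file below.

    from transformers import AutoTokenizer
    from neural_compressor.transformers import GPTQConfig

    # Stand-in for the tokenizer the script loads earlier in main().
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-GPTNeoForCausalLM")

    # Mirrors the algorithm_args dict assembled above; values are illustrative.
    algorithm_args = {
        "bits": 4,
        "sym": True,
        "group_size": 128,
        "use_layer_wise": True,  # the new option, exposed as --use_layer_wise
    }

    # GPTQ now additionally receives the tokenizer, presumably so it can
    # tokenize its calibration samples.
    quantization_config = GPTQConfig(
        damp_percent=0.01,
        nsamples=128,
        blocksize=128,
        tokenizer=tokenizer,
        **algorithm_args,
    )
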
diff --git a/optimum/intel/neural_compressor/quantization.py b/optimum/intel/neural_compressor/quantization.py
index 6ca9fd661d..92e7fc57b9 100644
--- a/optimum/intel/neural_compressor/quantization.py
+++ b/optimum/intel/neural_compressor/quantization.py
@@ -374,22 +374,21 @@ def _weight_only_quantization(
     }
 
     low_cpu_mem_usage = True
-    if use_xpu:
-        try:
-            # TODO: if low_cpu_mem_uasge is True, gptj will have accuracy issue on CPU device.
-            model = model_class.from_pretrained(
-                model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
-            )
-        except NotImplementedError:
-            logger.info(
-                "Failed to load models with `low_cpu_mem_usage=True`, will fall to traditional load method resulting in higher memory consumption."
-            )
-            low_cpu_mem_usage = False
-            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
-        quantization_config.update(**{"device": "xpu"})
-        quantization_config.post_init_xpu()
+
+    if getattr(quantization_config, "use_layer_wise", False):
+        if is_neural_compressor_version(">=", "3.2"):
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, cls=model_class, **loading_kwargs)
+        else:
+            raise ValueError("INC version must be >= 3.2 when use_layer_wise is set to True in quantization_config.")
     else:
         model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+
+    if use_xpu:
+        quantization_config.update(**{"device": "xpu"})
+        quantization_config.post_init_xpu()
+    else:
         quantization_config.post_init_cpu()
 
     model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage})
diff --git a/tests/neural_compressor/test_optimization.py b/tests/neural_compressor/test_optimization.py
index 75f2845c78..6b01baf705 100644
--- a/tests/neural_compressor/test_optimization.py
+++ b/tests/neural_compressor/test_optimization.py
@@ -45,7 +45,7 @@
     set_seed,
 )
 from utils_tests import MODEL_NAMES, SEED, INCTestMixin, _generate_dataset
-from optimum.intel.utils.import_utils import is_torch_version
+from optimum.intel.utils.import_utils import is_neural_compressor_version
 
 from optimum.intel import (
     INCConfig,
@@ -467,12 +467,16 @@ def _compute_metrics(pred):
 
 class WeightOnlyQuantizationTest(INCTestMixin):
     WEIGHT_ONLY_CONFIG = (
-        ("rtn", 4),
-        ("gptq", 4),
+        ("rtn", 4, False),
+        ("rtn", 4, True),
+        ("gptq", 4, False),
+        ("gptq", 4, True),
     )
 
     @parameterized.expand(WEIGHT_ONLY_CONFIG)
-    def test_weight_only_quantization(self, methodology, bits):
+    def test_weight_only_quantization(self, methodology, bits, use_layer_wise):
+        if use_layer_wise and is_neural_compressor_version("<", "3.2"):
+            self.skipTest("INC version < 3.2 doesn't support the layer-wise feature.")
         from neural_compressor.transformers import GPTQConfig, RtnConfig
 
         model_name = "hf-internal-testing/tiny-random-GPTNeoForCausalLM"
@@ -489,9 +493,10 @@ def test_weight_only_quantization(self, methodology, bits):
                 batch_size=5,
                 seq_len=32,
                 block_size=16,
+                use_layer_wise=use_layer_wise,
             )
         else:
-            quantization_config = RtnConfig(bits=bits, group_size=8)
+            quantization_config = RtnConfig(bits=bits, group_size=8, use_layer_wise=use_layer_wise)
 
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
@@ -503,6 +508,7 @@ def test_weight_only_quantization(self, methodology, bits):
             with torch.no_grad():
                 quantizer_outputs = quantized_model(**tokens)
             quantized_model.save_pretrained(tmp_dir)
+
             loaded_model = INCModelForCausalLM.from_pretrained(tmp_dir)
             with torch.no_grad():
                 loaded_outputs = loaded_model(**tokens)
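Taken together, the path exercised by the new test cases can be sketched end to end as below. The quantize-on-load call via `from_pretrained(..., quantization_config=...)` follows the existing optimum-intel weight-only API; the output directory name is illustrative.

    from neural_compressor.transformers import RtnConfig
    from optimum.intel import INCModelForCausalLM

    # With use_layer_wise=True, the model is first materialized as an empty (meta)
    # model via load_empty_model and quantized layer by layer, keeping peak memory
    # low; this requires neural-compressor >= 3.2.
    quantization_config = RtnConfig(bits=4, group_size=8, use_layer_wise=True)
    model = INCModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-GPTNeoForCausalLM",
        quantization_config=quantization_config,
    )
    model.save_pretrained("./rtn-4bit-layer-wise")

    # The quantized checkpoint reloads standalone, as the round trip in the test verifies.
    loaded_model = INCModelForCausalLM.from_pretrained("./rtn-4bit-layer-wise")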