
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 9, 2024
1 parent 4da06c8 · commit 3374284
Showing 40 changed files with 577 additions and 340 deletions.
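The changes below are mechanical style fixes rather than functional edits: import blocks are regrouped and alphabetized, trailing commas are added after the last element of multi-line calls, long expressions are wrapped in parentheses, and top-level definitions are separated by two blank lines. As a rough sketch of that convention (hypothetical identifiers, not code from this repository, and assuming black/isort-style hooks), the formatting target looks like:

```python
# Illustrative sketch only: hypothetical identifiers, assuming black/isort-style hooks.
# Standard-library imports first, then third-party, each group alphabetized.
import functools

import torch


@functools.lru_cache(maxsize=None)
def default_dtype():
    # Two blank lines separate top-level definitions.
    return torch.bfloat16


# Multi-line call sites keep a trailing comma after the last argument,
# which keeps the arguments laid out one per line on future edits.
CONFIG = dict(
    device="hpu",
    dtype=default_dtype(),
)


def pick_scale(scale_inv):
    # Long conditional expressions are wrapped in parentheses instead of
    # being split with backslash continuations.
    return (
        torch.mul(scale_inv[0], scale_inv[1]) if isinstance(scale_inv, list) else scale_inv
    )
```

The hunks that follow apply the same kinds of rewrites to the fp8_quant sources.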
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/fp8_quant/__init__.py
@@ -37,6 +37,6 @@
ModuleConfig,
ModuleInfo,
ModuleType,
ModuleExtraConfig
ModuleExtraConfig,
)
from neural_compressor.torch.algorithms.fp8_quant.save_load import save, load
20 changes: 14 additions & 6 deletions neural_compressor/torch/algorithms/fp8_quant/_core/common.py
@@ -20,18 +20,20 @@
import numpy as np
import torch

from .._quant_common.helper_modules import *
from .._quant_common.quant_config import get_hqt_config
from ..utils.logger import logger
from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
ModuleInfo,
ModuleConfig,
ModuleType,
ModuleExtraConfig,
ModuleInfo,
ModuleType,
get_patched_module_table,
get_patched_module_type_table,
)
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .._quant_common.helper_modules import *
from .._quant_common.quant_config import get_hqt_config
from ..utils.logger import logger

deepspeed_exists = False
if importlib.util.find_spec("deepspeed"): # check if deepspeed is installed
deepspeed_exists = True
@@ -227,11 +229,14 @@ def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype, dev
}
)


@functools.lru_cache(maxsize=None)
def _import_hpu_modules():
from neural_compressor.torch.algorithms.fp8_quant.patched_module_base import (
PATCHED_MODULE_TABLE, PATCHED_MODULE_TYPES_TABLE
PATCHED_MODULE_TABLE,
PATCHED_MODULE_TYPES_TABLE,
)

cur_accelerator = auto_detect_accelerator()
if not cur_accelerator.current_device_name().startswith("hpu"):
return
@@ -244,9 +249,11 @@ def _import_hpu_modules():
mod_default_dict = get_patched_module_table()
mod_types = get_patched_module_type_table()


def get_white_list():
return list(mod_default_dict.keys())


class ModInstInfo:
def __init__(self, name, parent):
self.name = name
@@ -264,6 +271,7 @@ def create_mod_info_recursion(parent):

create_mod_info_recursion(model)


def get_device_type_for_scales(mod):
config = get_hqt_config(mod).cfg
return config["device_for_scales"]
@@ -16,9 +16,11 @@
import habana_frameworks.torch.utils.experimental as htexp
import torch

from .common import ModuleConfig
from .quant_dequant import cast_to_fp8_fcn, cast_fcn, descale_fcn, scale_fcn
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .common import ModuleConfig
from .quant_dequant import cast_fcn, cast_to_fp8_fcn, descale_fcn, scale_fcn

cur_accelerator = auto_detect_accelerator()

GAUDI2 = htexp.synDeviceType.synDeviceGaudi2
37 changes: 19 additions & 18 deletions neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
@@ -14,22 +14,19 @@

import json
import os
from abc import abstractmethod

import habana_frameworks.torch.core as htcore
import numpy as np
import torch

from abc import abstractmethod
from neural_compressor.torch.algorithms.fp8_quant.model_configs import IMOD_DICT, OBSERVER_PARAMS, OBSERVER_TYPES
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .._quant_common.quant_config import MeasureExclude, QuantMode, ScaleMethod, get_hqt_config, set_hqt_config
from ..utils.logger import logger
from .common import *
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
from neural_compressor.torch.algorithms.fp8_quant.model_configs import (
OBSERVER_TYPES,
OBSERVER_PARAMS,
IMOD_DICT,
)

cur_accelerator = auto_detect_accelerator()


@@ -139,16 +136,20 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
)
patched_types.add(type(mod))

set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
mod_extra_config = init_measure_object(
mod,
name,
observer_class,
mod_types[mod_type],
skip_outputs_measurements,
(d_shapes[name] if ((d_shapes is not None) and (name in d_shapes)) else None),
params,
) if mod_default_dict[mod_type_str].should_measure_and_quant else None
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
mod_extra_config = (
init_measure_object(
mod,
name,
observer_class,
mod_types[mod_type],
skip_outputs_measurements,
(d_shapes[name] if ((d_shapes is not None) and (name in d_shapes)) else None),
params,
)
if mod_default_dict[mod_type_str].should_measure_and_quant
else None
)
pmod = patch_module_measure(mod, mod_extra_config, mod_default_dict)
if pmod._mod_extra_config:
for param_name in pmod._mod_extra_config.params:
@@ -290,4 +291,4 @@ def save_module(mod):
folder_name = os.path.join(mod.config["dump_stats_base_path"], "tensors")
os.makedirs(folder_name, exist_ok=True)
file_base_name = os.path.join(folder_name, IMOD_DICT[mod] + "_module.pt")
torch.save(mod.state_dict(), file_base_name)
torch.save(mod.state_dict(), file_base_name)
54 changes: 31 additions & 23 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn
import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from abc import abstractmethod

import habana_frameworks.torch.core as htcore
import torch
import torch.nn as nn
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

cur_accelerator = auto_detect_accelerator()

from .._core.scale_handler import create_scale_tensor, get_scale_dtype
@@ -29,10 +31,14 @@
cast_fcn = lambda x, dtype: x.to(dtype=dtype)
cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0]
cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype)
quant_to_fp8_fcn = lambda x, scale, zero_point, quant_min, quant_max, dtype=None: \
torch.ops.quantized_decomposed.quantize_per_tensor(x, scale, zero_point, quant_min, quant_max, dtype=dtype)
dequant_from_fp8_fcn = lambda x, scale, zero_point, quant_min, quant_max, dtype, out_dtype=None: \
torch.ops.quantized_decomposed.dequantize_per_tensor(x, scale, zero_point, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype)
quant_to_fp8_fcn = (
lambda x, scale, zero_point, quant_min, quant_max, dtype=None: torch.ops.quantized_decomposed.quantize_per_tensor(
x, scale, zero_point, quant_min, quant_max, dtype=dtype
)
)
dequant_from_fp8_fcn = lambda x, scale, zero_point, quant_min, quant_max, dtype, out_dtype=None: torch.ops.quantized_decomposed.dequantize_per_tensor(
x, scale, zero_point, quant_min, quant_max, dtype=dtype, out_dtype=out_dtype
)


class QuantDequantBase(nn.Module):
@@ -80,18 +86,20 @@ def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs):

def forward(self, x):
# create PCQ inv scale as tmp local variable since its size/mem-usage is equal to the module weight
scale_inv = torch.mul(self.scale_inv[0], self.scale_inv[1]) if isinstance(self.scale_inv, list) else self.scale_inv
scale_inv = (
torch.mul(self.scale_inv[0], self.scale_inv[1]) if isinstance(self.scale_inv, list) else self.scale_inv
)
return cast_to_fp8_fcn(x, self.lp_dtype, scale_inv)

def forward_qdq(self, x):
return quant_to_fp8_fcn(
x,
scale=self.scale,
zero_point=0,
quant_min=self.quant_min,
quant_max=self.quant_max,
dtype=self.lp_dtype,
)
x,
scale=self.scale,
zero_point=0,
quant_min=self.quant_min,
quant_max=self.quant_max,
dtype=self.lp_dtype,
)

def extra_repr(self) -> str:
repr = super(QuantInput, self).extra_repr()
@@ -110,14 +118,14 @@ def forward(self, x):

def forward_qdq(self, x):
return dequant_from_fp8_fcn(
x,
scale=self.scale,
zero_point=0,
quant_min=self.quant_min,
quant_max=self.quant_max,
dtype=self.lp_dtype,
out_dtype=self.hp_dtype,
)
x,
scale=self.scale,
zero_point=0,
quant_min=self.quant_min,
quant_max=self.quant_max,
dtype=self.lp_dtype,
out_dtype=self.hp_dtype,
)

def extra_repr(self) -> str:
repr = super(DequantOutput, self).extra_repr()
51 changes: 30 additions & 21 deletions neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py
@@ -13,20 +13,21 @@
# limitations under the License.

import gc

import habana_frameworks.torch.core as htcore
import numpy as np
import torch
import torch.nn as nn
import numpy as np

from .._quant_common.quant_config import QuantMode
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

from .._quant_common.helper_modules import PatchedUnmeasuredModule
from .._quant_common.quant_config import get_hqt_config, set_hqt_config
from .._quant_common.quant_config import QuantMode, get_hqt_config, set_hqt_config
from ..utils.logger import logger
from .common import generate_model_info, mod_default_dict, parent_child_mod_dict, \
save_scales, load_scales
from .common import generate_model_info, load_scales, mod_default_dict, parent_child_mod_dict, save_scales
from .measure import load_measurements
from .scale import scale_method_mapping, scaling_methods, convert_scales_to_tensors_dict, load_layer_scales
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
from .scale import convert_scales_to_tensors_dict, load_layer_scales, scale_method_mapping, scaling_methods

cur_accelerator = auto_detect_accelerator()


@@ -82,7 +83,7 @@ def quantize_params(mod, mod_extra_config):

def convert_fp16_to_bf16(model):
"""Convert all float16 parameters and buffers in the model to bfloat16 after FP8 quantization.
Args:
model (torch.nn.Module): The PyTorch model that needs to be converted.
"""
@@ -91,7 +92,7 @@ def convert_fp16_to_bf16(model):
if param.dtype == torch.float16:
param.data = param.data.to(torch.bfloat16)
logger.debug("Convert FP16 to BF16, parameter name: %s", name)

# convert buffers
for name, buffer in model.named_buffers():
if buffer.dtype == torch.float16:
@@ -117,9 +118,7 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method, scal
recalc_scales = config.cfg["recalc_scales"]
scales_file_format = np.ndarray
scales_obj = (
load_scales(scale_file + ".npz", scales_file_format)
if (scale_file is not None) and not recalc_scales
else {}
load_scales(scale_file + ".npz", scales_file_format) if (scale_file is not None) and not recalc_scales else {}
)
scales = convert_scales_to_tensors_dict(scales_obj, scales_file_format, scale_config["hp_dtype"])
save_file = False
Expand All @@ -139,19 +138,27 @@ def prepare_model(model, mod_list, measurement, scale_file, scaling_method, scal
apply_hf_hook(mod)
if name in mod_list:
set_hqt_config(mod, config) # set config in the module, as it consumed by the patched module
mod_extra_config, save_file = load_layer_scales(mod, name, config,
mod_type_str, measurement,
scales, scale_file,
scales_file_format,
scales_obj, scaling_method,
scale_config, save_file)
mod_extra_config, save_file = load_layer_scales(
mod,
name,
config,
mod_type_str,
measurement,
scales,
scale_file,
scales_file_format,
scales_obj,
scaling_method,
scale_config,
save_file,
)
if not config.cfg["fake_quant"] and mod_default_dict[mod_type_str].should_measure_and_quant:
quantize_params(mod, mod_extra_config)
patch_module(mod, mod_extra_config, mod_default_dict)
patched_modules.append(name)
patched_module_types.add(type(mod))
logger.debug("Patched module name: %s", name)
if save_file: # cache calculated scales
if save_file: # cache calculated scales
save_scales(model, scales_obj, scales_file_format, scale_file + ".npz")
save_scales(model, scales_obj, scales_file_format, scale_file + ".json")
logger.debug("Patched module types: %s", patched_module_types)
@@ -203,13 +210,15 @@ def prepare_model_with_dummy_measurement(model, mod_list, scaling_method, scale_
mod_config.params,
dummy_mod_scales,
scale_config,
)
)
# replace bf16 meta weights with FP8 meta weights for loading
if not config.cfg["fake_quant"] and mod_default_dict[mod_type_str].should_measure_and_quant:
for param_name in mod_info.param_names:
if param_name == "weight": # only weight is quantized now
raw_param = getattr(mod, param_name)
param = torch.ones(raw_param.shape, dtype=scale_config["lp_dtype"], device="meta") # meta tensor
param = torch.ones(
raw_param.shape, dtype=scale_config["lp_dtype"], device="meta"
) # meta tensor
delattr(mod, param_name)
setattr(mod, param_name, nn.Parameter(param))
patch_module(mod, dummy_mod_extra_config, mod_default_dict)