diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index 8d67535bc..f2b7a20de 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -112,14 +112,15 @@ def smooth_quant(model: str,
         setattr(parent, child_name, q_norm)
         norm.to('cpu')
 
+    quant_dtype_s = str(quant_dtype).split('.')[1]
+    model.config.update(
+        dict(quantization_config=dict(quant_method='smooth_quant',
+                                      quant_dtype=f'{quant_dtype_s}')))
+
     if vl_model:
         from .auto_awq import save_vl_model
         save_vl_model(vl_model, model_path, work_dir)
     else:
-        quant_dtype_s = str(quant_dtype).split('.')[1]
-        model.config.update(
-            dict(quantization_config=dict(quant_method='smooth_quant',
-                                          quant_dtype=f'{quant_dtype_s}')))
         model.save_pretrained(work_dir,
                               max_shard_size='2GB',
                               safe_serialization=False)
diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
index a0f04f34b..afbb10d89 100644
--- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -37,6 +37,12 @@ def get_layer_impl_builder(cls, layer_type: OpType):
         elif layer_type == OpType.RMSNorm:
             from .norm import DlinferRMSNormBuilder
             return DlinferRMSNormBuilder
+        elif layer_type == OpType.LinearW8A8:
+            from .qmodules import DlinferLinearW8A8Builder
+            return DlinferLinearW8A8Builder
+        elif layer_type == OpType.RMSNormW8A8:
+            from .qmodules import DlinferRMSNormW8A8Builder
+            return DlinferRMSNormW8A8Builder
         elif layer_type == OpType.SoftmaxTopK:
             from .moe import DlinferSoftmaxTopKBuilder
             return DlinferSoftmaxTopKBuilder
diff --git a/lmdeploy/pytorch/backends/dlinfer/qmodules.py b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
new file mode 100644
index 000000000..817c3d938
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/qmodules.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import (
+    linear_w8a8, per_token_quant_int8, rms_norm_w8a8)
+from lmdeploy.pytorch.models.q_modules import QTensor
+
+from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder,
+                        RMSNormW8A8Impl)
+
+
+class DlinferLinearW8A8Impl(LinearW8A8Impl):
+    """dlinfer linear w8a8 implementation."""
+
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 out_dtype: torch.dtype = torch.float16,
+                 quant_dtype: torch.dtype = torch.int8):
+        self.in_features = in_features
+        self.out_features = out_features
+        self.out_dtype = out_dtype
+        self.quant_dtype = quant_dtype
+
+    def update_weights(self,
+                       weight: torch.Tensor,
+                       scale: torch.Tensor,
+                       bias: Optional[torch.Tensor] = None):
+        """update weights."""
+        if os.getenv('DLINER_LINEAR_USE_NN_LAYOUT', '0') == '1':
+            weight = weight.data.t().contiguous()
+        return weight, scale, bias
+
+    def forward(self,
+                x,
+                weight: torch.Tensor,
+                scale: torch.Tensor,
+                bias: Optional[torch.Tensor] = None,
+                all_reduce: bool = False):
+        """forward."""
+        if isinstance(x, torch.Tensor):
+            input_quant, input_scale = per_token_quant_int8(x)
+        else:
+            assert isinstance(x, QTensor)
+            input_quant, input_scale = x.tensor, x.scale
+
+        out = linear_w8a8(input_quant, weight, input_scale, scale,
+                          self.out_dtype, self.quant_dtype, bias)
+        if all_reduce:
+            dist.all_reduce(out)
+        return out
+
+
+class DlinferLinearW8A8Builder(LinearW8A8Builder):
+    """dlinfer linear w8a8 implementation builder."""
+
+    @staticmethod
+    def build(in_features: int,
+              out_features: int,
+              bias: bool = True,
+              dtype: torch.dtype = None,
+              quant_dtype: torch.dtype = torch.int8):
+        """build."""
+        return DlinferLinearW8A8Impl(in_features, out_features, dtype,
+                                     quant_dtype)
+
+
+class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
+    """dlinfer RMS norm w8a8 implementation api."""
+
+    def __init__(self,
+                 hidden_size: int,
+                 eps: float = 1e-6,
+                 quant_dtype: torch.dtype = torch.int8):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.quant_dtype = quant_dtype
+
+    def forward(self,
+                x: torch.Tensor,
+                weight: torch.Tensor,
+                residual: torch.Tensor = None):
+        """forward."""
+        if residual is None:
+            (x, rms_scale) = rms_norm_w8a8(x, weight, self.eps,
+                                           self.quant_dtype)
+            x = QTensor(x, rms_scale)
+            return x
+        else:
+            (x, rms_scale, residual) = rms_norm_w8a8(x, weight, self.eps,
+                                                     self.quant_dtype,
+                                                     residual)
+            x = QTensor(x, rms_scale)
+            return x, residual
+
+
+class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):
+    """dlinfer RMS norm w8a8 implementation builder."""
+
+    @staticmethod
+    def build(hidden_size: int,
+              eps: float = 1e-6,
+              quant_dtype: torch.dtype = torch.int8):
+        """build."""
+        return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
new file mode 100644
index 000000000..08fa51fae
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import dlinfer.ops as ext_ops
+import torch
+from torch import Tensor
+
+
+def per_token_quant_int8(x):
+    """Function to perform per-token quantization on an input tensor `x`.
+
+    It converts the tensor values into signed 8-bit integers and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    """
+    input_quant, input_scale = ext_ops.per_token_quant_int8(x)
+    return input_quant, input_scale
+
+
+def linear_w8a8(
+    a: Tensor,
+    b: Tensor,
+    rms_scale: float,
+    linear_scale: float,
+    out_dtype: torch.dtype,
+    quant_dtype: torch.dtype,
+    bias=None,
+):
+    """This function performs matrix multiplication with dynamic quantization.
+
+    It takes two input tensors `a` and `b`, scales them with `rms_scale` and
+    `linear_scale`, and optionally adds a `bias`. The output is returned in the
+    specified `out_dtype`.
+    """
+    return ext_ops.linear_w8a8(a, b, rms_scale, linear_scale, out_dtype,
+                               quant_dtype, bias)
+
+
+def rms_norm_w8a8(
+    hidden_states: Tensor,
+    weight: Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype = torch.int8,
+    residual: Tensor = None,
+):
+    """rms norm kernel."""
+    if residual is None:
+        return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon,
+                                     quant_dtype)
+    else:
+        return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight,
+                                         epsilon, quant_dtype)
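The `ext_ops.*` calls in `w8a8_kernels.py` are thin wrappers over the dlinfer device kernels, so their numerics are not spelled out in the diff. As a rough, self-contained sketch of the per-token W8A8 scheme they implement (symmetric int8 quantization with one scale per token, then a matmul dequantized with the activation and weight scales), here is a plain-PyTorch reference; the `ref_*` helpers are illustrative only and are not part of dlinfer's API:

```python
import torch


def ref_per_token_quant_int8(x: torch.Tensor):
    """Symmetric int8 quantization with one scale per row (token)."""
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)
    scale = absmax / 127.0
    x_q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return x_q, scale


def ref_linear_w8a8(a_q, w_q, a_scale, w_scale, out_dtype, bias=None):
    """Dequantizing matmul.

    Real kernels accumulate in int32; float32 is exact here because the
    summed int8 products stay well below 2**24.
    """
    acc = a_q.to(torch.float32) @ w_q.to(torch.float32).t()
    out = acc * a_scale * w_scale.reshape(1, -1)
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


x = torch.randn(4, 64)                      # activations: (tokens, in_features)
w = torch.randn(128, 64)                    # weight: (out_features, in_features)
w_q, w_scale = ref_per_token_quant_int8(w)  # per-output-channel weight scales
x_q, x_scale = ref_per_token_quant_int8(x)  # per-token activation scales
y = ref_linear_w8a8(x_q, w_q, x_scale, w_scale, torch.float16)
print(y.shape)  # torch.Size([4, 128])
```

`rms_norm_w8a8` fuses the same per-token quantization step with the RMSNorm itself, which is why it returns the quantized hidden states together with an `rms_scale`.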
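One design note on the new `qmodules.py`: `DlinferRMSNormW8A8Impl.forward` wraps its already-quantized output and scale in a `QTensor`, and `DlinferLinearW8A8Impl.forward` checks for that wrapper so it only calls `per_token_quant_int8` when handed a plain tensor. A minimal sketch of that handoff (the tensors below are dummies; only the `QTensor` usage mirrors the diff):

```python
import torch

from lmdeploy.pytorch.models.q_modules import QTensor

# Pretend output of rms_norm_w8a8: int8 hidden states plus per-token scales.
hidden_q = torch.randint(-127, 128, (2, 8), dtype=torch.int8)
rms_scale = torch.rand(2, 1)
x = QTensor(hidden_q, rms_scale)

# The consumer-side branch used in DlinferLinearW8A8Impl.forward:
if isinstance(x, torch.Tensor):
    pass  # plain activations would be quantized on the fly
else:
    assert isinstance(x, QTensor)
    input_quant, input_scale = x.tensor, x.scale  # reuse the norm's scales
```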
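Finally, hoisting the `model.config.update(...)` call above the `if vl_model:` branch in `smooth_quant.py` means both the VL and plain-LLM export paths now record the quantization scheme. A quick illustration of the existing logic (not new behaviour), assuming `quant_dtype=torch.int8`:

```python
import torch

quant_dtype = torch.int8
quant_dtype_s = str(quant_dtype).split('.')[1]  # 'torch.int8' -> 'int8'

quantization_config = dict(quant_method='smooth_quant',
                           quant_dtype=f'{quant_dtype_s}')
print(quantization_config)
# {'quant_method': 'smooth_quant', 'quant_dtype': 'int8'}
```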