Skip to content

Commit

Permalink
add dlinfer w8a8 support.
Browse files Browse the repository at this point in the history
  • Loading branch information
Reinerzhou committed Jan 13, 2025
1 parent 086481e commit fa0a742
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 4 deletions.
9 changes: 5 additions & 4 deletions lmdeploy/lite/apis/smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,15 @@ def smooth_quant(model: str,
setattr(parent, child_name, q_norm)
norm.to('cpu')

quant_dtype_s = str(quant_dtype).split('.')[1]
model.config.update(
dict(quantization_config=dict(quant_method='smooth_quant',
quant_dtype=f'{quant_dtype_s}')))

if vl_model:
from .auto_awq import save_vl_model
save_vl_model(vl_model, model_path, work_dir)
else:
quant_dtype_s = str(quant_dtype).split('.')[1]
model.config.update(
dict(quantization_config=dict(quant_method='smooth_quant',
quant_dtype=f'{quant_dtype_s}')))
model.save_pretrained(work_dir,
max_shard_size='2GB',
safe_serialization=False)
Expand Down
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.RMSNorm:
from .norm import DlinferRMSNormBuilder
return DlinferRMSNormBuilder
elif layer_type == OpType.LinearW8A8:
from .qmodules import DlinferLinearW8A8Builder
return DlinferLinearW8A8Builder
elif layer_type == OpType.RMSNormW8A8:
from .qmodules import DlinferRMSNormW8A8Builder
return DlinferRMSNormW8A8Builder
elif layer_type == OpType.SoftmaxTopK:
from .moe import DlinferSoftmaxTopKBuilder
return DlinferSoftmaxTopKBuilder
Expand Down
110 changes: 110 additions & 0 deletions lmdeploy/pytorch/backends/dlinfer/qmodules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch
import torch.distributed as dist

from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import (
linear_w8a8, per_token_quant_int8, rms_norm_w8a8)
from lmdeploy.pytorch.models.q_modules import QTensor

from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder,
RMSNormW8A8Impl)


class DlinferLinearW8A8Impl(LinearW8A8Impl):
    """dlinfer linear w8a8 implementation.

    Stores the layer configuration and forwards the quantized matmul to the
    dlinfer ``linear_w8a8`` kernel wrapper.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 out_dtype: torch.dtype = torch.float16,
                 quant_dtype: torch.dtype = torch.int8):
        """Record the layer configuration.

        Args:
            in_features: stored as ``self.in_features``.
            out_features: stored as ``self.out_features``.
            out_dtype: dtype handed to the kernel as the output dtype.
            quant_dtype: dtype handed to the kernel as the quantized dtype.
        """
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.quant_dtype = quant_dtype

    def update_weights(self,
                       weight: torch.Tensor,
                       scale: torch.Tensor,
                       bias: Optional[torch.Tensor] = None):
        """update weights.

        When the environment variable ``DLINER_LINEAR_USE_NN_LAYOUT`` is set
        to ``'1'`` the weight is transposed and made contiguous; otherwise it
        is returned untouched. ``scale`` and ``bias`` always pass through.

        Returns:
            The ``(weight, scale, bias)`` triple.
        """
        # NOTE(review): the variable name is spelled "DLINER" (not "DLINFER");
        # kept byte-for-byte so it stays in sync with wherever it is set.
        if os.getenv('DLINER_LINEAR_USE_NN_LAYOUT', '0') == '1':
            weight = weight.data.t().contiguous()
        return weight, scale, bias

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        """forward.

        Args:
            x: either a ``torch.Tensor``, which is quantized here with
                ``per_token_quant_int8``, or a ``QTensor`` whose ``tensor``
                and ``scale`` attributes are used directly.
            weight: weight passed to the kernel.
            scale: weight scale passed to the kernel.
            bias: optional bias passed to the kernel.
            all_reduce: when true, ``dist.all_reduce`` is called on the
                output before it is returned.

        Returns:
            The output of ``linear_w8a8``.

        Raises:
            TypeError: if ``x`` is neither a ``torch.Tensor`` nor a
                ``QTensor``.
        """
        if isinstance(x, torch.Tensor):
            input_quant, input_scale = per_token_quant_int8(x)
        elif isinstance(x, QTensor):
            input_quant, input_scale = x.tensor, x.scale
        else:
            # Explicit raise instead of ``assert``: asserts are removed under
            # ``python -O``, which would defer the failure to an unrelated
            # AttributeError on ``x.tensor``.
            raise TypeError('x must be a torch.Tensor or QTensor, '
                            f'got {type(x).__name__}')

        out = linear_w8a8(input_quant, weight, input_scale, scale,
                          self.out_dtype, self.quant_dtype, bias)
        if all_reduce:
            dist.all_reduce(out)
        return out


class DlinferLinearW8A8Builder(LinearW8A8Builder):
    """Factory for :class:`DlinferLinearW8A8Impl`."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: torch.dtype = None,
              quant_dtype: torch.dtype = torch.int8):
        """Create the dlinfer w8a8 linear implementation.

        ``dtype`` is forwarded as the implementation's ``out_dtype``. ``bias``
        is accepted but not used when constructing the implementation.
        """
        # NOTE(review): when ``dtype`` is left at None it overrides the
        # implementation's float16 default with None - confirm this is
        # intended.
        impl = DlinferLinearW8A8Impl(in_features=in_features,
                                     out_features=out_features,
                                     out_dtype=dtype,
                                     quant_dtype=quant_dtype)
        return impl


class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
    """RMS norm w8a8 implementation backed by the dlinfer kernels."""

    def __init__(self,
                 hidden_size: int,
                 eps: float = 1e-6,
                 quant_dtype: torch.dtype = torch.int8):
        """Store ``hidden_size``, ``eps`` and ``quant_dtype`` on the instance."""
        super().__init__()
        self.quant_dtype = quant_dtype
        self.eps = eps
        self.hidden_size = hidden_size

    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                residual: torch.Tensor = None):
        """Run ``rms_norm_w8a8`` and wrap its result in a ``QTensor``.

        Returns:
            A ``QTensor`` when ``residual`` is None; otherwise a
            ``(QTensor, residual)`` pair, where ``residual`` is the value
            returned by the kernel.
        """
        if residual is not None:
            # Residual variant: the kernel returns three values.
            quantized, rms_scale, residual = rms_norm_w8a8(
                x, weight, self.eps, self.quant_dtype, residual)
            return QTensor(quantized, rms_scale), residual

        quantized, rms_scale = rms_norm_w8a8(x, weight, self.eps,
                                             self.quant_dtype)
        return QTensor(quantized, rms_scale)


class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):
    """Factory for :class:`DlinferRMSNormW8A8Impl`."""

    @staticmethod
    def build(hidden_size: int,
              eps: float = 1e-6,
              quant_dtype: torch.dtype = torch.int8):
        """Create the dlinfer w8a8 RMS norm implementation."""
        impl = DlinferRMSNormW8A8Impl(hidden_size=hidden_size,
                                      eps=eps,
                                      quant_dtype=quant_dtype)
        return impl
49 changes: 49 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/w8a8_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def per_token_quant_int8(x):
    """Quantize ``x`` per token via dlinfer's ``per_token_quant_int8`` op.

    Returns:
        A ``(quantized, scale)`` pair: the two values produced by the
        dlinfer op, in that order.
    """
    quantized, scale = ext_ops.per_token_quant_int8(x)
    return quantized, scale


def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype,
    bias=None,
):
    """Quantized matrix multiplication via dlinfer's ``linear_w8a8`` op.

    All arguments are forwarded positionally, unchanged and in signature
    order: the two operands ``a`` and ``b``, their scales ``rms_scale`` and
    ``linear_scale``, the requested ``out_dtype`` and ``quant_dtype``, and
    the optional ``bias``.

    NOTE(review): the scales are annotated as ``float`` but appear to be
    tensors in practice (e.g. the scale produced by per-token quantization) -
    confirm and update the annotations.
    """
    result = ext_ops.linear_w8a8(a, b, rms_scale, linear_scale, out_dtype,
                                 quant_dtype, bias)
    return result


def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
    residual: Tensor = None,
):
    """RMS norm w8a8 kernel dispatch.

    With a ``residual`` the call goes to dlinfer's ``add_rms_norm_w8a8``;
    without one it goes to ``rms_norm_w8a8``. The op's return value is passed
    back unchanged.
    """
    if residual is not None:
        # Note the argument order: residual comes before weight here.
        return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight,
                                         epsilon, quant_dtype)
    return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon, quant_dtype)

0 comments on commit fa0a742

Please sign in to comment.