Commit fa0a742
1 parent: 086481e
Showing 4 changed files with 170 additions and 4 deletions. Two of the changed files are reproduced below.
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Optional

import torch
import torch.distributed as dist

from lmdeploy.pytorch.kernels.dlinfer.w8a8_kernels import (
    linear_w8a8, per_token_quant_int8, rms_norm_w8a8)
from lmdeploy.pytorch.models.q_modules import QTensor

from ..qmodules import (LinearW8A8Builder, LinearW8A8Impl, RMSNormW8A8Builder,
                        RMSNormW8A8Impl)


class DlinferLinearW8A8Impl(LinearW8A8Impl):
    """dlinfer linear w8a8 implementation."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 out_dtype: torch.dtype = torch.float16,
                 quant_dtype: torch.dtype = torch.int8):
        self.in_features = in_features
        self.out_features = out_features
        self.out_dtype = out_dtype
        self.quant_dtype = quant_dtype

    def update_weights(self,
                       weight: torch.Tensor,
                       scale: torch.Tensor,
                       bias: Optional[torch.Tensor] = None):
        """update weights."""
        # Optionally store the weight transposed (the "nn layout"), for
        # backends that prefer an (in_features, out_features) weight.
        if os.getenv('DLINER_LINEAR_USE_NN_LAYOUT', '0') == '1':
            weight = weight.data.t().contiguous()
        return weight, scale, bias

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        """forward."""
        if isinstance(x, torch.Tensor):
            # Plain activations: quantize per token on the fly.
            input_quant, input_scale = per_token_quant_int8(x)
        else:
            # Already quantized upstream (e.g. by the fused RMSNorm below).
            assert isinstance(x, QTensor)
            input_quant, input_scale = x.tensor, x.scale

        out = linear_w8a8(input_quant, weight, input_scale, scale,
                          self.out_dtype, self.quant_dtype, bias)
        if all_reduce:
            dist.all_reduce(out)
        return out


class DlinferLinearW8A8Builder(LinearW8A8Builder):
    """dlinfer linear w8a8 implementation builder."""

    @staticmethod
    def build(in_features: int,
              out_features: int,
              bias: bool = True,
              dtype: Optional[torch.dtype] = None,
              quant_dtype: torch.dtype = torch.int8):
        """build."""
        return DlinferLinearW8A8Impl(in_features, out_features, dtype,
                                     quant_dtype)


class DlinferRMSNormW8A8Impl(RMSNormW8A8Impl):
    """dlinfer RMS norm w8a8 implementation."""

    def __init__(self,
                 hidden_size: int,
                 eps: float = 1e-6,
                 quant_dtype: torch.dtype = torch.int8):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.quant_dtype = quant_dtype

    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                residual: Optional[torch.Tensor] = None):
        """forward."""
        if residual is None:
            x, rms_scale = rms_norm_w8a8(x, weight, self.eps,
                                         self.quant_dtype)
            return QTensor(x, rms_scale)
        x, rms_scale, residual = rms_norm_w8a8(x, weight, self.eps,
                                               self.quant_dtype, residual)
        return QTensor(x, rms_scale), residual


class DlinferRMSNormW8A8Builder(RMSNormW8A8Builder):
    """dlinfer RMS norm w8a8 implementation builder."""

    @staticmethod
    def build(hidden_size: int,
              eps: float = 1e-6,
              quant_dtype: torch.dtype = torch.int8):
        """build."""
        return DlinferRMSNormW8A8Impl(hidden_size, eps, quant_dtype)
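Taken together, DlinferLinearW8A8Impl.forward quantizes floating-point activations per token (or reuses the scale already carried by a QTensor), runs an int8 GEMM via linear_w8a8, and dequantizes into out_dtype. The sketch below spells out those presumed semantics in plain PyTorch. It is illustrative only: per_token_quant_int8_ref and w8a8_linear_ref are hypothetical helpers, and the assumption that dequantization multiplies the per-token activation scale by the per-channel weight scale is inferred from the call signature, not stated by the diff.

import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    """Symmetric int8 quantization with one scale per row (token)."""
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7)
    scale = absmax / 127.0
    q = (x / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale


def w8a8_linear_ref(q_x, x_scale, q_w, w_scale, out_dtype, bias=None):
    """int8 x int8 GEMM, dequantized by the product of both scales."""
    # A real kernel accumulates in int32; float32 is exact here because
    # the partial sums stay far below 2**24.
    acc = q_x.float() @ q_w.float().t()
    out = acc * x_scale * w_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)


x = torch.randn(4, 128)
w = torch.randn(256, 128)                    # (out_features, in_features)
q_w, w_scale = per_token_quant_int8_ref(w)   # per-output-channel scales
q_x, x_scale = per_token_quant_int8_ref(x)   # per-token scales
y = w8a8_linear_ref(q_x, x_scale, q_w, w_scale, torch.float16)
print((y.float() - x @ w.t()).abs().max())   # small quantization error

Note that the all_reduce branch in forward sums the already-dequantized outputs across tensor-parallel ranks, which is why dequantization has to happen before the reduction.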
@@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import dlinfer.ops as ext_ops
import torch
from torch import Tensor


def per_token_quant_int8(x):
    """Perform per-token quantization on the input tensor `x`.

    Converts the tensor values into signed 8-bit integers and returns the
    quantized tensor along with the scaling factors used for quantization.
    """
    input_quant, input_scale = ext_ops.per_token_quant_int8(x)
    return input_quant, input_scale


def linear_w8a8(
    a: Tensor,
    b: Tensor,
    rms_scale: float,
    linear_scale: float,
    out_dtype: torch.dtype,
    quant_dtype: torch.dtype,
    bias=None,
):
    """Matrix multiplication with dynamic quantization.

    Multiplies the quantized input tensors `a` and `b`, applies the
    `rms_scale` and `linear_scale` dequantization factors, and optionally
    adds a `bias`. The output is returned in the specified `out_dtype`.
    """
    return ext_ops.linear_w8a8(a, b, rms_scale, linear_scale, out_dtype,
                               quant_dtype, bias)


def rms_norm_w8a8(
    hidden_states: Tensor,
    weight: Tensor,
    epsilon: float,
    quant_dtype: torch.dtype = torch.int8,
    residual: Optional[Tensor] = None,
):
    """RMS norm w8a8 kernel, with an optional fused residual add."""
    if residual is None:
        return ext_ops.rms_norm_w8a8(hidden_states, weight, epsilon,
                                     quant_dtype)
    return ext_ops.add_rms_norm_w8a8(hidden_states, residual, weight,
                                     epsilon, quant_dtype)
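The wrappers above only dispatch to dlinfer, so the fused ops themselves are opaque here. As a reading aid, the following pure-PyTorch sketch shows the semantics the DlinferRMSNormW8A8Impl wrapper appears to rely on: an optional residual add, a standard RMSNorm, then per-token symmetric quantization. rms_norm_w8a8_ref is a hypothetical name, and details such as the clamp floor on the scale are assumptions, not dlinfer's documented behavior.

import torch


def rms_norm_w8a8_ref(hidden_states: torch.Tensor,
                      weight: torch.Tensor,
                      epsilon: float,
                      quant_dtype: torch.dtype = torch.int8,
                      residual: torch.Tensor = None):
    """Presumed semantics of (add_)rms_norm_w8a8: normalize, then quantize."""
    if residual is not None:
        hidden_states = hidden_states + residual
        residual = hidden_states             # fused op returns the new residual
    # Standard RMSNorm over the hidden dimension.
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + epsilon)
    normed = normed * weight.float()
    # Per-token symmetric quantization, mirroring per_token_quant_int8.
    qmax = torch.iinfo(quant_dtype).max
    scale = normed.abs().amax(-1, keepdim=True).clamp(min=1e-7) / qmax
    quant = (normed / scale).round().clamp(-qmax, qmax).to(quant_dtype)
    if residual is None:
        return quant, scale
    return quant, scale, residual

The two- and three-element returns match how DlinferRMSNormW8A8Impl.forward unpacks the results before wrapping them in a QTensor, so quantized activations can flow straight into DlinferLinearW8A8Impl.forward without being re-quantized.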