fix internvl2 qk norm (#2987)
grimoire authored Jan 7, 2025
1 parent c6c25ae commit de2050d
Showing 5 changed files with 76 additions and 63 deletions.
73 changes: 26 additions & 47 deletions lmdeploy/pytorch/kernels/cuda/rms_norm.py
@@ -4,8 +4,6 @@
import triton.language as tl
from torch import Tensor

from .triton_utils import get_kernel_meta, wrap_jit_func


@triton.jit
def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
@@ -18,15 +16,6 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
return out


@wrap_jit_func(type_hint=dict(
input=Tensor,
weight=Tensor,
output=Tensor,
input_row_stride=int,
eps=float,
N_COLS=torch.int32,
BLOCK_N=torch.int32,
))
@triton.jit
def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
eps: tl.constexpr, N_COLS: tl.constexpr,
@@ -45,18 +34,6 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)


@wrap_jit_func(type_hint=dict(
input=Tensor,
weight=Tensor,
residual=Tensor,
output=Tensor,
out_residual=Tensor,
input_row_stride=int,
residual_row_stride=int,
eps=float,
N_COLS=torch.int32,
BLOCK_N=torch.int32,
))
@triton.jit
def add_rms_norm_kernel(input, weight, residual, output, out_residual,
input_row_stride: tl.constexpr,
@@ -95,6 +72,7 @@ def rms_norm(hidden_states: Tensor,
hidden_states = hidden_states.contiguous()

feat_size = weight.shape[0]
assert hidden_states.size(-1) == feat_size
seq_len = hidden_states.numel() // hidden_states.size(-1)
input_stride = hidden_states.stride(-2)

@@ -103,39 +81,40 @@
if out is None:
out = torch.empty_like(hidden_states)

kernel_meta = get_kernel_meta(hidden_states)
grid = (seq_len, )

if residual is None:
rms_norm_kernel[grid](hidden_states,
weight,
out,
input_row_stride=input_stride,
eps=eps,
N_COLS=feat_size,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=2,
**kernel_meta)
rms_norm_kernel[grid](
hidden_states,
weight,
out,
input_row_stride=input_stride,
eps=eps,
N_COLS=feat_size,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=2,
)
return out
else:
if out_residual is None:
out_residual = torch.empty_like(hidden_states)

res_stride = residual.stride(-2)
add_rms_norm_kernel[grid](hidden_states,
weight,
residual,
out,
out_residual,
input_row_stride=input_stride,
residual_row_stride=res_stride,
eps=eps,
N_COLS=feat_size,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=2,
**kernel_meta)
add_rms_norm_kernel[grid](
hidden_states,
weight,
residual,
out,
out_residual,
input_row_stride=input_stride,
residual_row_stride=res_stride,
eps=eps,
N_COLS=feat_size,
BLOCK_N=BLOCK_N,
num_warps=4,
num_stages=2,
)
return out, out_residual


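For reference, a minimal pure-PyTorch sketch of what the Triton kernels above are assumed to compute (standard RMSNorm with fp32 accumulation; the real kernels fuse these steps on the GPU and handle strides and blocks explicitly). This is an illustrative reference, not part of the commit:

import torch
from torch import Tensor


def rms_norm_ref(x: Tensor, weight: Tensor, eps: float = 1e-6) -> Tensor:
    """Reference RMSNorm: x * rsqrt(mean(x^2) + eps) * weight."""
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight


def add_rms_norm_ref(x: Tensor, weight: Tensor, residual: Tensor,
                     eps: float = 1e-6):
    """Reference fused add + RMSNorm: returns (normed, x + residual)."""
    new_residual = x + residual
    return rms_norm_ref(new_residual, weight, eps), new_residual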
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/models/internvl.py
@@ -124,12 +124,16 @@ def __init__(self,
eps=config.layer_norm_eps,
dtype=dtype,
device=device,
tp=True,
align=self.head_dim,
)
self.k_norm = RMSNorm(
self.embed_dim,
eps=config.layer_norm_eps,
dtype=dtype,
device=device,
tp=True,
align=self.head_dim,
)

self.scale = self.head_dim**-0.5
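The tp=True and align=self.head_dim arguments added above partition the q/k norm weight across tensor-parallel ranks in whole-head chunks, so each rank's norm weight lines up with its local slice of the q/k projections. A small sketch with hypothetical sizes (the real InternVL config differs), using get_distribute_size as imported elsewhere in this diff:

from lmdeploy.pytorch.nn.utils import get_distribute_size

embed_dim, head_dim, world_size = 1024, 64, 2  # hypothetical sizes
for rank in range(world_size):
    local_size = get_distribute_size(embed_dim, world_size, rank, align=head_dim)
    print(rank, local_size)  # each rank holds 512 features, i.e. 8 whole heads of 64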
17 changes: 2 additions & 15 deletions lmdeploy/pytorch/nn/linear.py
Expand Up @@ -12,7 +12,7 @@

from ..backends import OpType, get_backend
from ..backends.lora import AdapterInfo
from .utils import get_distribute_size
from .utils import chunk_aligned, get_distribute_size

logger = get_logger('lmdeploy')

@@ -25,20 +25,7 @@ def _check_qkv_split_layout(layout: str):
f'but get: {layout}')


def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
"""chunk aligned."""
if align == 1:
return weight.chunk(chunks, dim=dim)
size = weight.size(dim)
assert size % align == 0
aligned_size = size // align

# try best to evenly split chunks
align_per_chunk = aligned_size // chunks
remain = aligned_size % chunks
sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
sections = [sec * align for sec in sections]
return weight.split(sections, dim=dim)
_chunk_align = chunk_aligned


class QKVMixin:
26 changes: 25 additions & 1 deletion lmdeploy/pytorch/nn/norm.py
@@ -4,7 +4,10 @@
import torch
from torch import nn

from lmdeploy.pytorch.distributed import get_world_rank

from ..backends import OpType, get_backend
from .utils import chunk_aligned, get_distribute_size


def _is_w8a8(quant_config: Any):
@@ -28,7 +31,9 @@ def __init__(self,
eps: float = 1e-6,
dtype: torch.dtype = None,
device: torch.device = None,
quant_config: Any = None):
quant_config: Any = None,
tp: bool = False,
align: int = 1):
super().__init__()
backend = get_backend()

@@ -37,6 +42,14 @@
builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)
else:
builder = backend.get_layer_impl_builder(OpType.RMSNorm)

if tp:
world_size, rank = get_world_rank()
hidden_size = get_distribute_size(hidden_size,
world_size,
rank,
align=align)

self.register_parameter('weight',
self.create_weight(hidden_size, dtype, device))
if w8a8_flag:
@@ -46,6 +59,17 @@
else:
self.impl = builder.build(hidden_size, eps)

if tp:
self.weight.weight_loader = self.weight_loader
self.align = align

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"""weight loader."""
world_size, rank = get_world_rank()
loaded_weight = chunk_aligned(loaded_weight, world_size, 0,
self.align)[rank]
param.copy_(loaded_weight)

@staticmethod
def create_weight(hidden_size: int,
dtype: torch.dtype = None,
19 changes: 19 additions & 0 deletions lmdeploy/pytorch/nn/utils.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def div_up(a: int, b: int):
"""div up."""
return (a + b - 1) // b
@@ -18,3 +21,19 @@ def get_distribute_size(feature_size: int,
if rank < aligned_size % world_size:
updated_aligned_size += 1
return updated_aligned_size * align


def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):
"""chunk aligned."""
if align == 1:
return weight.chunk(chunks, dim=dim)
size = weight.size(dim)
assert size % align == 0
aligned_size = size // align

# try best to evenly split chunks
align_per_chunk = aligned_size // chunks
remain = aligned_size % chunks
sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
sections = [sec * align for sec in sections]
return weight.split(sections, dim=dim)
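
A small usage sketch of chunk_aligned with hypothetical shapes, mirroring how RMSNorm.weight_loader above slices the full norm weight into head-aligned per-rank shards:

import torch

from lmdeploy.pytorch.nn.utils import chunk_aligned

full_weight = torch.ones(1024)  # hypothetical full q/k norm weight
world_size, head_dim = 2, 64
shards = chunk_aligned(full_weight, world_size, dim=0, align=head_dim)
assert [s.numel() for s in shards] == [512, 512]  # whole heads per rank
rank = 0
local_weight = shards[rank]  # the slice copied into this rank's parameter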
