diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
index bc994012f..045b55e1b 100644
--- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
@@ -4,8 +4,6 @@
 import triton.language as tl
 from torch import Tensor
 
-from .triton_utils import get_kernel_meta, wrap_jit_func
-
 
 @triton.jit
 def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
@@ -18,15 +16,6 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
     return out
 
 
-@wrap_jit_func(type_hint=dict(
-    input=Tensor,
-    weight=Tensor,
-    output=Tensor,
-    input_row_stride=int,
-    eps=float,
-    N_COLS=torch.int32,
-    BLOCK_N=torch.int32,
-))
 @triton.jit
 def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
                     eps: tl.constexpr, N_COLS: tl.constexpr,
@@ -45,18 +34,6 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
     tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
 
 
-@wrap_jit_func(type_hint=dict(
-    input=Tensor,
-    weight=Tensor,
-    residual=Tensor,
-    output=Tensor,
-    out_residual=Tensor,
-    input_row_stride=int,
-    residual_row_stride=int,
-    eps=float,
-    N_COLS=torch.int32,
-    BLOCK_N=torch.int32,
-))
 @triton.jit
 def add_rms_norm_kernel(input, weight, residual, output, out_residual,
                         input_row_stride: tl.constexpr,
@@ -95,6 +72,7 @@ def rms_norm(hidden_states: Tensor,
         hidden_states = hidden_states.contiguous()
 
     feat_size = weight.shape[0]
+    assert hidden_states.size(-1) == feat_size
     seq_len = hidden_states.numel() // hidden_states.size(-1)
     input_stride = hidden_states.stride(-2)
 
@@ -103,39 +81,40 @@ def rms_norm(hidden_states: Tensor,
 
     if out is None:
         out = torch.empty_like(hidden_states)
 
-    kernel_meta = get_kernel_meta(hidden_states)
     grid = (seq_len, )
     if residual is None:
-        rms_norm_kernel[grid](hidden_states,
-                              weight,
-                              out,
-                              input_row_stride=input_stride,
-                              eps=eps,
-                              N_COLS=feat_size,
-                              BLOCK_N=BLOCK_N,
-                              num_warps=4,
-                              num_stages=2,
-                              **kernel_meta)
+        rms_norm_kernel[grid](
+            hidden_states,
+            weight,
+            out,
+            input_row_stride=input_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            num_warps=4,
+            num_stages=2,
+        )
         return out
     else:
         if out_residual is None:
             out_residual = torch.empty_like(hidden_states)
         res_stride = residual.stride(-2)
-        add_rms_norm_kernel[grid](hidden_states,
-                                  weight,
-                                  residual,
-                                  out,
-                                  out_residual,
-                                  input_row_stride=input_stride,
-                                  residual_row_stride=res_stride,
-                                  eps=eps,
-                                  N_COLS=feat_size,
-                                  BLOCK_N=BLOCK_N,
-                                  num_warps=4,
-                                  num_stages=2,
-                                  **kernel_meta)
+        add_rms_norm_kernel[grid](
+            hidden_states,
+            weight,
+            residual,
+            out,
+            out_residual,
+            input_row_stride=input_stride,
+            residual_row_stride=res_stride,
+            eps=eps,
+            N_COLS=feat_size,
+            BLOCK_N=BLOCK_N,
+            num_warps=4,
+            num_stages=2,
+        )
         return out, out_residual
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 1059569a0..5fccd627e 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -124,12 +124,16 @@ def __init__(self,
             eps=config.layer_norm_eps,
             dtype=dtype,
             device=device,
+            tp=True,
+            align=self.head_dim,
         )
         self.k_norm = RMSNorm(
             self.embed_dim,
             eps=config.layer_norm_eps,
             dtype=dtype,
             device=device,
+            tp=True,
+            align=self.head_dim,
         )
         self.scale = self.head_dim**-0.5
diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py
index a84d0ec3e..bc2f4b359 100644
--- a/lmdeploy/pytorch/nn/linear.py
+++ b/lmdeploy/pytorch/nn/linear.py
@@ -12,7 +12,7 @@
 
 from ..backends import OpType, get_backend
 from ..backends.lora import AdapterInfo
-from .utils import get_distribute_size
+from .utils import chunk_aligned, get_distribute_size
 
 logger = get_logger('lmdeploy')
@@ -25,20 +25,7 @@ def _check_qkv_split_layout(layout: str):
                          f'but get: {layout}')
 
 
-def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
-    """chunk aligned."""
-    if align == 1:
-        return weight.chunk(chunks, dim=dim)
-    size = weight.size(dim)
-    assert size % align == 0
-    aligned_size = size // align
-
-    # try best to evenly split chunks
-    align_per_chunk = aligned_size // chunks
-    remain = aligned_size % chunks
-    sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
-    sections = [sec * align for sec in sections]
-    return weight.split(sections, dim=dim)
+_chunk_align = chunk_aligned
 
 
 class QKVMixin:
diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py
index ba565263c..7e2c82039 100644
--- a/lmdeploy/pytorch/nn/norm.py
+++ b/lmdeploy/pytorch/nn/norm.py
@@ -4,7 +4,10 @@
 import torch
 from torch import nn
 
+from lmdeploy.pytorch.distributed import get_world_rank
+
 from ..backends import OpType, get_backend
+from .utils import chunk_aligned, get_distribute_size
 
 
 def _is_w8a8(quant_config: Any):
@@ -28,7 +31,9 @@ def __init__(self,
                  eps: float = 1e-6,
                  dtype: torch.dtype = None,
                  device: torch.device = None,
-                 quant_config: Any = None):
+                 quant_config: Any = None,
+                 tp: bool = False,
+                 align: int = 1):
         super().__init__()
 
         backend = get_backend()
@@ -37,6 +42,14 @@ def __init__(self,
             builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)
         else:
             builder = backend.get_layer_impl_builder(OpType.RMSNorm)
+
+        if tp:
+            world_size, rank = get_world_rank()
+            hidden_size = get_distribute_size(hidden_size,
+                                              world_size,
+                                              rank,
+                                              align=align)
+
         self.register_parameter('weight',
                                 self.create_weight(hidden_size, dtype, device))
         if w8a8_flag:
@@ -46,6 +59,17 @@ def __init__(self,
         else:
             self.impl = builder.build(hidden_size, eps)
 
+        if tp:
+            self.weight.weight_loader = self.weight_loader
+            self.align = align
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        """weight loader."""
+        world_size, rank = get_world_rank()
+        loaded_weight = chunk_aligned(loaded_weight, world_size, 0,
+                                      self.align)[rank]
+        param.copy_(loaded_weight)
+
     @staticmethod
     def create_weight(hidden_size: int,
                       dtype: torch.dtype = None,
diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py
index 3b60ca21d..085b12c3e 100644
--- a/lmdeploy/pytorch/nn/utils.py
+++ b/lmdeploy/pytorch/nn/utils.py
@@ -1,4 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
 def div_up(a: int, b: int):
     """div up."""
     return (a + b - 1) // b
@@ -18,3 +21,19 @@ def get_distribute_size(feature_size: int,
     if rank < aligned_size % world_size:
         updated_aligned_size += 1
     return updated_aligned_size * align
+
+
+def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):
+    """chunk aligned."""
+    if align == 1:
+        return weight.chunk(chunks, dim=dim)
+    size = weight.size(dim)
+    assert size % align == 0
+    aligned_size = size // align
+
+    # try best to evenly split chunks
+    align_per_chunk = aligned_size // chunks
+    remain = aligned_size % chunks
+    sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
+    sections = [sec * align for sec in sections]
+    return weight.split(sections, dim=dim)
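
Review note: the core of this patch is `chunk_aligned`, moved from `linear.py` (where it was `_chunk_align`) into `nn/utils.py` so that `norm.py` can reuse it. It splits `weight` along `dim` into `chunks` sections whose sizes are all multiples of `align`, handing any remainder out one aligned unit at a time to the lowest-numbered chunks. A minimal standalone sketch of the expected behavior; the sizes are illustrative, not taken from the PR:

import torch

from lmdeploy.pytorch.nn.utils import chunk_aligned

# 10 heads of head_dim=128, sharded over 4 ranks:
# 10 // 4 = 2 heads per rank, the 2 leftover heads go to ranks 0 and 1.
w = torch.empty(10 * 128)
parts = chunk_aligned(w, chunks=4, dim=0, align=128)
print([p.numel() // 128 for p in parts])  # -> [3, 3, 2, 2]

Every section stays a whole number of heads, which is what lets a per-head parameter such as a q/k norm weight be sharded the same way as the attention heads themselves.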
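
Review note: the `tp=True` path added to `RMSNorm` has to stay consistent with that sharding. Each rank allocates a weight of size `get_distribute_size(hidden_size, world_size, rank, align=align)`, and `weight_loader` copies exactly `chunk_aligned(full_weight, world_size, 0, align)[rank]`. The two helpers agree by construction (both hand `remain` aligned units to the lowest ranks); a quick consistency check, with made-up shapes and no distributed runtime needed:

import torch

from lmdeploy.pytorch.nn.utils import chunk_aligned, get_distribute_size

hidden_size, head_dim, world_size = 3200, 128, 4  # illustrative only
full_weight = torch.ones(hidden_size)
shards = chunk_aligned(full_weight, world_size, 0, head_dim)
for rank in range(world_size):
    # the slice each rank loads matches the parameter it allocated
    assert shards[rank].numel() == get_distribute_size(
        hidden_size, world_size, rank, align=head_dim)

If the two ever diverged, `param.copy_(loaded_weight)` in `weight_loader` would fail with a shape mismatch, so this invariant is load-bearing for the InternVL `q_norm`/`k_norm` changes above.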