diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py index ca2cff391e..263b419f1a 100644 --- a/lmdeploy/pytorch/backends/base.py +++ b/lmdeploy/pytorch/backends/base.py @@ -29,6 +29,8 @@ class OpType(Enum): SoftmaxTopK = auto() FusedMoE = auto() FusedMoEW8A8 = auto() + LinearBlockedF8 = auto() + FusedMoEBlockedF8 = auto() class OpsBackend(ABC): diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py new file mode 100644 index 0000000000..d79b41330c --- /dev/null +++ b/lmdeploy/pytorch/backends/blockedf8_modules.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class LinearBlockedF8Impl(ABC): + """linear BlockedF8 implementation api.""" + + def update_weights(self, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """update weights.""" + return weight, scale, bias + + @abstractmethod + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + raise NotImplementedError + + +class LinearBlockedF8Builder(ABC): + """linear BlockedF8 implementation builder.""" + + @staticmethod + @abstractmethod + def build(in_features: int, + out_features: int, + bias: bool = True, + dtype: torch.dtype = None): + """build.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py new file mode 100644 index 0000000000..8299ac2dfd --- /dev/null +++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Optional + +import torch +import torch.distributed as dist + +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8, + quant_fp8) + +from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl + + +class TritonLinearBlockedF8Impl(LinearBlockedF8Impl): + """triton linear blocked f8 implementation.""" + + def __init__(self, + in_features: int, + out_features: int, + block_size: int, + out_dtype: torch.dtype = torch.float16): + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.block_size = block_size + + def forward(self, + x, + weight: torch.Tensor, + scale: torch.Tensor, + bias: Optional[torch.Tensor] = None, + all_reduce: bool = False): + """forward.""" + x_shape = x.shape + x = x.flatten(0, -2) + input_quant, input_scale = quant_fp8(x, + self.block_size, + dtype=weight.dtype) + + out = blocked_gemm_fp8(input_quant, + input_scale, + weight.t(), + scale.t(), + out_dtype=x.dtype) + if bias is not None: + out += bias + + if all_reduce: + dist.all_reduce(out) + + out = out.unflatten(0, x_shape[:-1]) + return out + + +class TritonLinearBlockedF8Builder(LinearBlockedF8Builder): + """triton linear blocked f8 implementation builder.""" + + @staticmethod + def build(in_features: int, + out_features: int, + block_size: int = 128, + bias: bool = True, + dtype: torch.dtype = None): + """build.""" + return TritonLinearBlockedF8Impl(in_features, out_features, block_size, + dtype) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index b5f48fa557..a913ca82fb 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -5,11 +5,15 @@ import torch from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8 +from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \ + fused_moe_blocked_fp8 +from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \ per_token_quant_int8 from lmdeploy.pytorch.models.q_modules import QTensor -from ..moe import (FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder, +from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, + FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder, FusedMoEW8A8Impl) @@ -168,3 +172,95 @@ def build(top_k: int, num_experts=num_experts, renormalize=renormalize, out_dtype=out_dtype) + + +class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl): + """triton fused moe blocked f8 implementation.""" + + def __init__(self, + top_k: int, + num_experts: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.float16): + self.num_experts = num_experts + self.top_k = top_k + self.renormalize = renormalize + self.block_size = block_size + self.out_dtype = out_dtype + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + gate_up_weights = gate_up_weights.transpose(1, + 2).contiguous().transpose( + 1, 2) + down_weights = down_weights.transpose(1, + 2).contiguous().transpose(1, 2) + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return True + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + num_experts = self.num_experts + expert_per_rank = (num_experts + world_size - 1) // world_size + first_expert = rank * expert_per_rank + 
last_expert = min(first_expert + expert_per_rank, num_experts) + return list(range(first_expert, last_expert)) + + def forward(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: List[int] = None): + """forward.""" + input_size = hidden_states.shape + hidden_states = hidden_states.flatten(0, -2) + input_quant, input_scale = quant_fp8(hidden_states, + self.block_size, + dtype=gate_up_weights.dtype) + + expert_offset = 0 + num_experts = None + if expert_list is not None and len(expert_list) != self.num_experts: + expert_offset = expert_list[0] + num_experts = self.num_experts + output = fused_moe_blocked_fp8(input_quant, + input_scale, + gate_up_weights, + gate_up_scale, + down_weights, + down_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + topk=self.top_k, + out_dtype=hidden_states.dtype, + expert_offset=expert_offset, + num_experts=num_experts, + renormalize=self.renormalize) + output = output.unflatten(0, input_size[:-1]) + return output + + +class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder): + """triton fused moe blocked f8 builder.""" + + @staticmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + block_size: int = 128, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + return TritonFusedMoEBlockedF8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype) diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py index bc6063ec7b..7b2134aeef 100644 --- a/lmdeploy/pytorch/backends/cuda/op_backend.py +++ b/lmdeploy/pytorch/backends/cuda/op_backend.py @@ -59,6 +59,12 @@ def get_layer_impl_builder(cls, layer_type: OpType): elif layer_type == OpType.FusedMoEW8A8: from .moe import TritonFusedMoEW8A8Builder return TritonFusedMoEW8A8Builder + elif layer_type == OpType.FusedMoEBlockedF8: + from .moe import TritonFusedMoEBlockedF8Builder + return TritonFusedMoEBlockedF8Builder + elif layer_type == OpType.LinearBlockedF8: + from .blockedf8_modules import TritonLinearBlockedF8Builder + return TritonLinearBlockedF8Builder else: logger.debug( f'Op {layer_type} fallback to default implementation.') diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py index b5946eeefe..4501e52c0b 100644 --- a/lmdeploy/pytorch/backends/moe.py +++ b/lmdeploy/pytorch/backends/moe.py @@ -105,3 +105,48 @@ def build(top_k: int, out_dtype: torch.dtype = torch.float16): """build from mlp.""" raise NotImplementedError + + +class FusedMoEBlockedF8Impl(ABC): + """fused moe blocked f8 implementation.""" + + def update_weights(self, gate_up_weights: torch.Tensor, + down_weights: torch.Tensor, gate_up_scale: torch.Tensor, + down_scale: torch.Tensor): + """update weights.""" + return gate_up_weights, down_weights, gate_up_scale, down_scale + + def support_ep(self): + """support expert parallelism.""" + return False + + def ep_expert_list(self, world_size: int, rank: int): + """experts list of current rank.""" + raise NotImplementedError('Not Implemented.') + + @abstractmethod + def forward(self, + hidden_states: torch.Tensor, + input_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.LongTensor, + gate_up_weights: torch.Tensor, + gate_up_scale: torch.Tensor, + down_weights: torch.Tensor, + down_scale: torch.Tensor, + expert_list: 
List[int] = None): + """forward.""" + raise NotImplementedError + + +class FusedMoEBlockedF8Builder(ABC): + """fused moe blocked f8 builder.""" + + @staticmethod + @abstractmethod + def build(top_k: int, + num_experts: int, + renormalize: bool = False, + out_dtype: torch.dtype = torch.float16): + """build from mlp.""" + raise NotImplementedError diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py index d1f0844ad5..bf06ff0c33 100644 --- a/lmdeploy/pytorch/configurations/deepseek_v2.py +++ b/lmdeploy/pytorch/configurations/deepseek_v2.py @@ -9,7 +9,7 @@ class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder): @classmethod def condition(cls, hf_config): """config.""" - return hf_config.model_type == 'deepseek_v2' + return hf_config.model_type in ['deepseek_v3', 'deepseek_v2'] @classmethod def build(cls, hf_config, model_path: str = None, **kwargs): diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py new file mode 100644 index 0000000000..4907d92ac5 --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py @@ -0,0 +1,344 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# modify from: https://github.com/vllm-project/vllm +import torch +import triton +import triton.language as tl + +from .activation import silu_and_mul +from .blocked_gemm_fp8 import quant_fp8 +from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize + + +def get_cuda_autotune_config(): + return [ + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + }, + num_stages=4, + num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + }, + num_stages=4, + num_warps=4), + ] + + +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['N', 'K', 'M_NP2'], + warmup=10, + rep=25, +) +@triton.jit +def fused_moe_blocked_f8_kernel( + A, + A_scale, + B, + B_scale, + C, + SortedIdx, + ExpStart, + ExpEnd, + Weights, + N: tl.constexpr, + K: tl.constexpr, + group_ak: tl.constexpr, + group_bk: tl.constexpr, + group_bn: tl.constexpr, + stride_am: tl.constexpr, + stride_ak: tl.constexpr, + stride_asm, + stride_ask: tl.constexpr, + stride_be: tl.constexpr, + stride_bn: tl.constexpr, + stride_bk: tl.constexpr, + stride_bse: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + M_NP2: tl.constexpr, + ENABLE_WEIGHTS: tl.constexpr, + top_k: tl.constexpr, + expert_offset: tl.constexpr, + reindex_a: tl.constexpr, + reindex_c: tl.constexpr, +): + """fused moe kernel.""" + exp_id = tl.program_id(1) + pid = tl.program_id(0) + + exp_start = tl.load(ExpStart + exp_id + expert_offset) + exp_end = tl.load(ExpEnd + exp_id + expert_offset) + M = exp_end - exp_start + if M <= 0: + return + + num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + if GROUP_SIZE_M == 1: + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N: + return + + offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + mask_sid = 
offs_sid < exp_end + sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0) + + offs_k = tl.arange(0, BLOCK_SIZE_K) + if reindex_a: + offs_am = sid // top_k + else: + offs_am = offs_sid + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + as_ptrs = A_scale + offs_am + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) + + # deepseek has 160 experts, exp index would overflow int32 + exp_id = exp_id.to(tl.int64) + exp_off = stride_be * exp_id + b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + offs_bsn = pid_n * BLOCK_SIZE_N // group_bn + as_ptrs = A_scale + offs_am * stride_asm + bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn + + acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs) + acc_ratio = 1 / acc_scale + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # load scales + k_start = (k + 1) * BLOCK_SIZE_K + offs_ksa = k_start // group_ak + offs_ksb = k_start // group_bk + a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, + mask=k_start < K, + other=1.0) + b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, + mask=k_start < K, + other=1.0) + + # load ab + a = tl.load(a_ptrs, + mask=mask_sid[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + # mma + accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None]) + + # update scales and ratio + new_acc_scale = a_scale * b_scale + acc_ratio = acc_scale / new_acc_scale + acc_scale = new_acc_scale + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator * (acc_ratio * acc_scale)[:, None] + + if ENABLE_WEIGHTS: + weight = tl.load(Weights + sid, mask=mask_sid) + c = c * weight[:, None].to(c.dtype) + + c = c.to(C.dtype.element_ty) + + if reindex_c: + offs_cm = sid + else: + offs_cm = offs_sid + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :] + tl.store(c_ptrs, c, mask=mask_sid[:, None]) + + +def fused_moe_blocked_fp8_kernel_launcher( + A: torch.Tensor, + A_scale: torch.Tensor, + B: torch.Tensor, + B_scale: torch.Tensor, + C: torch.Tensor, + sorted_idx: torch.Tensor, + exp_start: torch.Tensor, + exp_end: torch.Tensor, + weights: torch.Tensor, + enable_weights: bool = False, + top_k: int = 1, + num_tokens: int = None, + expert_offset: int = 0, + reindex_a: bool = True, + reindex_c: bool = True, +): + """fused moe kernel launcher.""" + + if num_tokens is None: + num_tokens = A.size(0) + M_NP2 = triton.next_power_of_2(num_tokens) + M_NP2 = max(64, M_NP2) + E, N, K = B.shape + + assert A.dim() == 2 + assert A_scale.dim() == 2 + assert B.dim() == 3 + assert B_scale.dim() == 3 + + assert K % A_scale.size(1) == 0 + assert K % B_scale.size(2) == 0 + assert N % B_scale.size(1) == 0 + + group_ak = K // A_scale.size(1) + group_bk = K // B_scale.size(2) + group_bn = N // B_scale.size(1) + + def _grid_fn(META): + grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) * + triton.cdiv(N, META['BLOCK_SIZE_N']), E) + return grid + + A = A.flatten(0, -2) + C = C.flatten(0, -2) + + BLOCK_SIZE_K = group_bk + GROUP_SIZE_M = 8 + grid = _grid_fn + fused_moe_blocked_f8_kernel[grid]( + A, + A_scale, + B, + B_scale, + C, + sorted_idx, + exp_start, + exp_end, + weights, + N=N, + K=K, + group_ak=group_ak, + group_bk=group_bk, + group_bn=group_bn, + 
stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_asm=A_scale.stride(0), + stride_ask=A_scale.stride(1), + stride_be=B.stride(0), + stride_bn=B.stride(1), + stride_bk=B.stride(2), + stride_bse=B_scale.stride(0), + stride_bsn=B_scale.stride(1), + stride_bsk=B_scale.stride(2), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + ENABLE_WEIGHTS=enable_weights, + top_k=top_k, + expert_offset=expert_offset, + reindex_a=reindex_a, + reindex_c=reindex_c, + M_NP2=M_NP2, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + + +def fused_moe_blocked_fp8(input: torch.Tensor, + input_scale: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + topk: int, + out_dtype: torch.dtype = torch.float16, + expert_offset: int = 0, + num_experts: int = None, + renormalize: bool = False) -> torch.Tensor: + """fused moe.""" + device = input.device + M = input.size(0) + E, N, _ = w1.shape + if num_experts is None: + num_experts = E + full_exp = num_experts == E + group_size = input.size(-1) // input_scale.size(-1) + + topk_weights = _renormalize(topk_weights, renormalize) + sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts) + + intermediate_cache1 = _make_intermediate((M, topk, N), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # gate and up + fused_moe_blocked_fp8_kernel_launcher( + input, + input_scale, + w1, + w1_scale, + intermediate_cache1, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=False, + top_k=topk, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=True, + reindex_c=False, + ) + + # activate + intermediate_cache1 = intermediate_cache1.flatten(0, -2) + gate_cache = silu_and_mul(intermediate_cache1) + del intermediate_cache1 + gate_cache, gate_scale = quant_fp8(gate_cache, + group_size, + dtype=input.dtype) + + intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]), + dtype=out_dtype, + device=device, + zeros=not full_exp) + # down + fused_moe_blocked_fp8_kernel_launcher( + gate_cache, + gate_scale, + w2, + w2_scale, + intermediate_cache2, + sorted_idx=sorted_idx, + exp_start=exp_start, + exp_end=exp_end, + weights=topk_weights, + enable_weights=True, + top_k=1, + num_tokens=M, + expert_offset=expert_offset, + reindex_a=False, + reindex_c=True, + ) + + ret = intermediate_cache2.sum(dim=1) + return ret diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py new file mode 100644 index 0000000000..9f992bcfef --- /dev/null +++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py @@ -0,0 +1,237 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +import triton +import triton.language as tl +from torch import Tensor + + +@triton.jit +def _quant_fp8_kernel( + a_ptr, + out_ptr, + scale_ptr, + fp8_min: tl.constexpr, + fp8_max: tl.constexpr, + stride_am, + stride_ak: tl.constexpr, + stride_om, + stride_ok: tl.constexpr, + stride_sm, + stride_sg: tl.constexpr, + GROUP_SIZE: tl.constexpr, +): + """quant fp8 kernel.""" + group_id = tl.program_id(0) + m_id = tl.program_id(1) + + g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE) + + a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak + o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok + s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg + + rfp8_max = 1 / fp8_max + + a = tl.load(a_ptrs).to(tl.float32) + scale = tl.max(tl.abs(a)) * rfp8_max + out = a / scale + + out = tl.clamp(out, fp8_min, fp8_max) + out = out.to(out_ptr.dtype.element_ty) + + tl.store(o_ptrs, out) + tl.store(s_ptr, scale) + + +def quant_fp8(A: Tensor, + group_size: int, + dtype: torch.dtype = torch.float8_e4m3fn): + """quant online.""" + assert A.dim() == 2 + M, K = A.shape + assert K % group_size == 0 + num_groups = K // group_size + + finfo = torch.finfo(dtype) + fmin = finfo.min + fmax = finfo.max + + out = torch.empty_like(A, dtype=dtype) + scales = A.new_empty(M, num_groups, dtype=torch.float32) + grid = (num_groups, M) + num_warps = 4 + num_stages = 1 + _quant_fp8_kernel[grid]( + A, + out, + scales, + fp8_min=fmin, + fp8_max=fmax, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_om=out.stride(0), + stride_ok=out.stride(1), + stride_sm=scales.stride(0), + stride_sg=scales.stride(1), + GROUP_SIZE=group_size, + num_warps=num_warps, + num_stages=num_stages, + ) + + return out, scales + + +@triton.autotune(configs=[ + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 128, + }, num_stages=3, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 64, + }, num_stages=3, num_warps=4) +], + key=['N', 'K'], + warmup=5, + rep=10) +@triton.jit +def _gemm_fp8_kernel( + A, + a_scale_ptr, + B, + b_scale_ptr, + C, + M, + N: tl.constexpr, + K: tl.constexpr, + group_ak: tl.constexpr, + group_bk: tl.constexpr, + group_bn: tl.constexpr, + stride_am, + stride_ak: tl.constexpr, + stride_asm, + stride_ask: tl.constexpr, + stride_bk: tl.constexpr, + stride_bn: tl.constexpr, + stride_bsk: tl.constexpr, + stride_bsn: tl.constexpr, + stride_cm, + stride_cn: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + """gemm fp8 kernel.""" + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + offs_bsn = pid_n * BLOCK_N // group_bn + as_ptrs = a_scale_ptr + offs_am * stride_asm + bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn + + acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs) + acc_ratio = 1 / acc_scale + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # load scales + k_start 
= (k + 1) * BLOCK_K + offs_ksa = k_start // group_ak + offs_ksb = k_start // group_bk + a_scale = tl.load(as_ptrs + offs_ksa * stride_ask, + mask=k_start < K, + other=1.0) + b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk, + mask=k_start < K, + other=1.0) + + # load ab + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0) + + # mma + accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None]) + + # update scales and ratio + new_acc_scale = a_scale * b_scale + acc_ratio = acc_scale / new_acc_scale + acc_scale = new_acc_scale + + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = accumulator * (acc_ratio * acc_scale)[:, None] + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def blocked_gemm_fp8(A: Tensor, + A_scale: Tensor, + B: Tensor, + B_scale: torch.Tensor, + out_dtype: torch.dtype = torch.float16): + """gemm fp8.""" + + def grid(META): + return (triton.cdiv(M, META['BLOCK_M']) * + triton.cdiv(N, META['BLOCK_N']), ) + + assert A.dim() == 2 + assert A_scale.dim() == 2 + assert B.dim() == 2 + assert B_scale.dim() == 2 + + M, K = A.shape + _, N = B.shape + + group_ak = triton.cdiv(K, A_scale.size(1)) + group_bk = triton.cdiv(K, B_scale.size(0)) + group_bn = triton.cdiv(N, B_scale.size(1)) + + C = A.new_empty(M, N, dtype=out_dtype) + + BLOCK_K = max(group_ak, group_bk) + + _gemm_fp8_kernel[grid]( + A, + A_scale, + B, + B_scale, + C, + M=M, + N=N, + K=K, + group_ak=group_ak, + group_bk=group_bk, + group_bn=group_bn, + stride_am=A.stride(0), + stride_ak=A.stride(1), + stride_asm=A_scale.stride(0), + stride_ask=A_scale.stride(1), + stride_bk=B.stride(0), + stride_bn=B.stride(1), + stride_bsk=B_scale.stride(0), + stride_bsn=B_scale.stride(1), + stride_cm=C.stride(0), + stride_cn=C.stride(1), + BLOCK_K=BLOCK_K, + GROUP_M=8, + ) + + return C diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py index c58d1d8d4e..b69ae6650d 100644 --- a/lmdeploy/pytorch/models/deepseek_v2.py +++ b/lmdeploy/pytorch/models/deepseek_v2.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist +import torch.nn.functional as F from torch import nn from lmdeploy.pytorch.distributed import get_world_rank @@ -81,7 +82,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() - quantization_config = None + quantization_config = getattr(config, 'quantization_config', None) self.q_lora_rank = config.q_lora_rank self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -102,6 +103,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) else: self.q_a_proj = build_colwise_linear( @@ -111,6 +113,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=False, + quant_config=quantization_config, ) self.q_a_layernorm = RMSNorm(config.q_lora_rank, 1e-6, @@ -124,6 +127,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) self.kv_a_proj_with_mqa = build_colwise_linear( @@ -133,6 +137,7 @@ def __init__(self, dtype=dtype, device=device, is_tp=False, + quant_config=quantization_config, ) self.kv_a_layernorm = RMSNorm(config.kv_lora_rank, 1e-6, @@ -176,6 +181,7 @@ def __init__(self, 
dtype=dtype, device=device, is_tp=True, + quant_config=quantization_config, ) def _q_proj(self, hidden_states, num_heads: int, nope_size: int, @@ -272,6 +278,104 @@ def forward( return attn_output +class MoEGate(nn.Module): + """Deepseek Gate.""" + + def __init__(self, + config: Any, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.renormalize = self.top_k > 1 and self.norm_topk_prob + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim), + dtype=dtype, + device=device)) + if self.topk_method == 'noaux_tc': + self.e_score_correction_bias = nn.Parameter( + torch.empty((self.n_routed_experts, ), + dtype=dtype, + device=device)) + self.softmax_topk = SoftmaxTopK(self.top_k) + + def _compute_scores(self, logits: torch.Tensor): + """compute scores.""" + if self.scoring_func == 'softmax': + scores = logits.softmax(dim=-1, dtype=torch.float32) + elif self.scoring_func == 'sigmoid': + scores = logits.sigmoid() + else: + raise NotImplementedError('insupportable scoring function ' + f'for MoE gating: {self.scoring_func}') + return scores + + def forward(self, hidden_states: torch.Tensor): + """forward.""" + sequence_length, hidden_dim = hidden_states.shape + router_logits = F.linear(hidden_states, self.weight) + + if self.topk_method == 'greedy': + topk_weight, topk_idx = self.softmax_topk(router_logits) + elif self.topk_method == 'group_limited_greedy': + scores = self._compute_scores(router_logits) + grouped_logits = scores.unflatten(-1, (self.n_group, -1)) + group_scores = (grouped_logits.max(-1).values) + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + group_mask = ~group_mask.bool()[..., None] + grouped_logits = grouped_logits.masked_fill(group_mask, 0.0) + scores = grouped_logits.flatten(1, 2) + topk_weight, topk_idx = self.softmax_topk(scores) + elif self.topk_method == 'noaux_tc': + scores = self._compute_scores(router_logits) + scores_for_choice = scores.view( + sequence_length, -1) + self.e_score_correction_bias[None] + group_scores = (scores_for_choice.view( + sequence_length, self.n_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group] + group_idx = torch.topk(group_scores, + k=self.topk_group, + dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + sequence_length, self.n_group, + self.n_routed_experts // self.n_group).reshape( + sequence_length, -1)) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), + 0.0) # [n, e] + _, topk_idx = torch.topk(tmp_scores, + k=self.top_k, + dim=-1, + sorted=False) + topk_weight = scores.gather(1, topk_idx) + else: + raise RuntimeError(f'Unsupported topk_method: {self.topk_method}') + 
if not self.renormalize: + topk_weight = topk_weight * self.routed_scaling_factor + return topk_weight, topk_idx + + class DeepseekV2MoE(nn.Module): """Deepseek v2 MoE.""" @@ -280,6 +384,7 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None): super().__init__() + quantization_config = getattr(config, 'quantization_config', None) self.hidden_dim = config.hidden_size self.ffn_dim = config.moe_intermediate_size self.num_experts = config.n_routed_experts @@ -291,16 +396,7 @@ def __init__(self, self.n_group = config.n_group self.topk_group = config.topk_group - self.gate = build_rowwise_linear( - self.hidden_dim, - self.num_experts, - bias=False, - dtype=dtype, - device=device, - is_tp=False, - ) - - self.softmax_topk = SoftmaxTopK(self.top_k) + self.gate = MoEGate(config, dtype=dtype, device=device) self.experts = build_fused_moe( self.hidden_dim, @@ -311,6 +407,7 @@ def __init__(self, dtype=dtype, device=device, all_reduce=False, + quant_config=quantization_config, ) self.shared_experts = None @@ -335,27 +432,8 @@ def forward(self, hidden_states: torch.Tensor): """forward.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - router_logits = self.gate(hidden_states) + topk_weights, topk_ids = self.gate(hidden_states) - if self.topk_method == 'greedy': - topk_weights, topk_ids = self.softmax_topk(router_logits) - elif self.topk_method == 'group_limited_greedy': - grouped_logits = router_logits.unflatten(-1, (self.n_group, -1)) - group_scores = (grouped_logits.max(-1).values) - group_idx = torch.topk(group_scores, - k=self.topk_group, - dim=-1, - sorted=False)[1] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - group_mask = ~group_mask.bool()[..., None] - grouped_logits = grouped_logits.masked_fill(group_mask, 0.0) - router_logits = grouped_logits.flatten(1, 2) - topk_weights, topk_ids = self.softmax_topk(router_logits) - else: - raise RuntimeError(f'Unsupported topk_method: {self.topk_method}') - if not self.renormalize: - topk_weights = topk_weights * self.routed_scaling_factor out_states = self.experts( hidden_states, topk_weights, @@ -572,7 +650,6 @@ def forward( cos, sin = cos[0], sin[0] rotary_pos_emb = (cos, sin) for idx, decoder_layer in enumerate(self.layers): - past_key_value = past_key_values[idx] hidden_states, residual = decoder_layer( hidden_states, @@ -601,6 +678,8 @@ def __init__(self, device: torch.device = None): super().__init__() self.config = config + self.quantization_config = getattr(config, 'quantization_config', None) + self.dtype = dtype self.ctx_mgr = ctx_mgr self.model = DeepseekV2Model(config, dtype=dtype, device=device) # build lm_head @@ -609,6 +688,7 @@ def __init__(self, bias=False, dtype=dtype, device=device) + self._load_buffers = dict() def forward( self, @@ -692,40 +772,99 @@ def __update_pe(weight, head_dim: int, pe_dim_offset: int): weight = weight.flatten(0, 1) return weight + def __load_kcvc(name: str, weight: torch.Tensor): + """load kc and vc from weight.""" + config = self.config + v_head_dim = config.v_head_dim + qk_nope_head_dim = config.qk_nope_head_dim + w_kc, w_vc = weight.unflatten( + 0, (-1, qk_nope_head_dim + v_head_dim)).split( + [qk_nope_head_dim, v_head_dim], dim=1) + w_vc = w_vc.transpose(1, 2).contiguous() + kc_param_name = name.replace('.kv_b_proj', '.kc') + param_kc = params_dict[kc_param_name] + load_weight(param_kc, w_kc) + vc_param_name = 
name.replace('.kv_b_proj', '.vc') + param_vc = params_dict[vc_param_name] + load_weight(param_vc, w_vc) + + def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor, + dtype: torch.dtype): + """dequant weight.""" + dim_w0, dim_w1 = weight.shape + dim_s0, dim_s1 = scale.shape + assert dim_w0 % dim_s0 == 0 + assert dim_w1 % dim_s1 == 0 + group0 = dim_w0 // dim_s0 + group1 = dim_w1 // dim_s1 + weight = weight.reshape(dim_s0, group0, dim_s1, group1) + scale = scale.reshape(dim_s0, 1, dim_s1, 1) + weight = weight.to(scale.dtype) * scale + weight = weight.to(dtype) + weight = weight.reshape(dim_w0, dim_w1) + return weight + + def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor): + """dequant weight.""" + if name.endswith('.weight'): + weight_name = name + scale_name = name.replace('.weight', '.scale') + elif name.endswith('.scale'): + weight_name = name.replace('.scale', '.weight') + scale_name = name + self._load_buffers[name] = loaded_weight + if (weight_name in self._load_buffers + and scale_name in self._load_buffers): + weight = self._load_buffers.pop(weight_name) + scale = self._load_buffers.pop(scale_name) + kc_param_name = weight_name.replace('.kv_b_proj', '.kc') + dtype = params_dict[kc_param_name].dtype + weight = __dequant_weight(weight, scale, dtype) + __load_kcvc(weight_name, weight) + for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping: if mod_name not in name: continue - weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) + if name.endswith('.scale'): + weight = loaded_weight + else: + weight = __update_pe(loaded_weight, head_dim, pe_dim_offset) param = params_dict[name] load_weight(param, weight) break else: if '.kv_b_proj' in name: - config = self.config - v_head_dim = config.v_head_dim - qk_nope_head_dim = config.qk_nope_head_dim - w_kc, w_vc = loaded_weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim)).split( - [qk_nope_head_dim, v_head_dim], dim=1) - w_vc = w_vc.transpose(1, 2).contiguous() - kc_param_name = name.replace('.kv_b_proj', '.kc') - param_kc = params_dict[kc_param_name] - load_weight(param_kc, w_kc) - vc_param_name = name.replace('.kv_b_proj', '.vc') - param_vc = params_dict[vc_param_name] - load_weight(param_vc, w_vc) + quantization_config = self.quantization_config + quant_method = None + if quantization_config is not None: + quant_method = quantization_config.get('quant_method') + + if quant_method == 'fp8': + # update blocked fp8 weight + __load_kcvc_blocked_fp8(name, loaded_weight) + else: + __load_kcvc(name, loaded_weight) else: param = params_dict[name] load_weight(param, loaded_weight) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """load weights.""" + + def __skip_nextn(name, nextn_keys): + for nextn_key in nextn_keys: + if nextn_key in name: + return True + return False + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ('.gate_up_proj', '.gate_proj', 0), ('.gate_up_proj', '.up_proj', 1), ] + scale_suffix = '.weight_scale_inv' + config = self.config qk_rope_head_dim = config.qk_rope_head_dim kv_lora_rank = config.kv_lora_rank @@ -747,6 +886,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): exp_id, 'down') expert_params_mapping += [gate_param, up_param, down_param] + num_hidden_layers = self.config.num_hidden_layers + + num_nextn_predict_layers = getattr(self.config, + 'num_nextn_predict_layers', 1) + nextn_keys = [ + f'.layers.{num_hidden_layers+i}' + for i in range(num_nextn_predict_layers) + ] + params_dict = dict(self.named_parameters()) for 
name, loaded_weight in weights: if 'rotary_emb.inv_freq' in name: @@ -754,8 +902,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): continue + if '.layers' in name: + # skip nextn + if __skip_nextn(name, nextn_keys): + continue if self.config.tie_word_embeddings and 'lm_head.weight' in name: continue + if name.endswith(scale_suffix): + name = name[:-len(scale_suffix)] + '.scale' if '.experts' in name: self._load_weight_experts( name, diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index e7b460026a..c1b62736f7 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -82,6 +82,12 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' }) +# deepseek-v3 +MODULE_MAP.update({ + 'DeepseekV3ForCausalLM': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM' +}) + # llava MODULE_MAP.update( { diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py index bc2f4b3591..73d0ef918d 100644 --- a/lmdeploy/pytorch/nn/linear.py +++ b/lmdeploy/pytorch/nn/linear.py @@ -12,7 +12,7 @@ from ..backends import OpType, get_backend from ..backends.lora import AdapterInfo -from .utils import chunk_aligned, get_distribute_size +from .utils import chunk_aligned, div_up, get_distribute_size logger = get_logger('lmdeploy') @@ -152,6 +152,239 @@ def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_r.copy_(loaded_weight.t()) +class BlockedF8Linear(nn.Module): + """blocked f8 linear.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + colwise: bool = True, + is_tp: bool = False, + all_reduce: bool = True, + ): + super().__init__() + if device is None: + device = torch.device('cpu') + if dtype is None: + dtype = torch.float16 + if is_tp: + in_features, out_features = self._get_io_features( + in_features, out_features, colwise) + impl_builder = get_backend().get_layer_impl_builder( + OpType.LinearBlockedF8) + self.impl = impl_builder.build(in_features, + out_features, + block_size=128, + bias=bias is not None, + dtype=dtype) + self.block_size = 128 + self.fp8_dtype = fp8_dtype + weight, scale, bias = self.create_weights(in_features, out_features, + bias, dtype, device) + weight = torch.nn.Parameter(weight, requires_grad=False) + weight.weight_loader = self.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) + + self.in_features = in_features + self.out_features = out_features + self.lora_adapters = nn.ModuleDict() + self.is_tp = is_tp + self.colwise = colwise + self.all_reduce = all_reduce + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + world_size, rank = get_world_rank() + if colwise: + out_features = get_distribute_size(out_features, world_size, rank) + else: + in_features = get_distribute_size(in_features, world_size, rank) + return in_features, out_features + + def _weight_loader_tp_colwise(self, param: torch.nn.Parameter, + loaded_weight: 
torch.Tensor, rank: int, + world_size: int): + """weight loader for colwise linear.""" + weight = loaded_weight.chunk(world_size, 0)[rank] + return default_weight_loader(param, weight) + + def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, rank: int, + world_size: int): + """weight loader for rowwise linear.""" + if loaded_weight.dim() == 2: + weight = loaded_weight.chunk(world_size, 1)[rank] + return default_weight_loader(param, weight) + else: + # bias + if rank != 0: + loaded_weight = torch.zeros_like(loaded_weight) + return default_weight_loader(param, loaded_weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + """weight loader.""" + if not self.is_tp: + return default_weight_loader(param, loaded_weight) + + world_size, rank = get_world_rank() + if self.colwise: + return self._weight_loader_tp_colwise(param, loaded_weight, rank, + world_size) + else: + return self._weight_loader_tp_rowwise(param, loaded_weight, rank, + world_size) + + def create_weights(self, in_features: int, out_features: int, bias: bool, + dtype: torch.dtype, device: torch.device): + """create weights.""" + weight = torch.empty((out_features, in_features), + dtype=self.fp8_dtype, + device=device) + scale = torch.empty( + (div_up(out_features, + self.block_size), div_up(in_features, self.block_size)), + dtype=torch.float32, + device=device) + if bias: + bias = torch.empty((out_features, ), dtype=dtype, device=device) + else: + bias = None + return weight, scale, bias + + def update_weights(self): + """update weights.""" + weight, scale, bias = self.impl.update_weights(self.weight, self.scale, + self.bias) + weight = torch.nn.Parameter(weight, requires_grad=False) + self.weight.weight_loader = self.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + self.scale.weight_loader = self.weight_loader + if bias is not None: + bias = torch.nn.Parameter(bias, requires_grad=False) + self.bias.weight_loader = self.weight_loader + self.register_parameter('weight', weight) + self.register_parameter('scale', scale) + self.register_parameter('bias', bias) + + def forward(self, x): + """forward of blocked fp8 linear.""" + all_reduce = False if self.colwise else self.is_tp + all_reduce = all_reduce and self.all_reduce + if len(self.lora_adapters) == 0: + return self.impl.forward(x, self.weight, self.scale, self.bias, + all_reduce) + + out = self.impl.forward(x, self.weight, self.scale, self.bias, False) + for lora_adapter in self.lora_adapters.values(): + out = lora_adapter(x, out) + if all_reduce: + dist.all_reduce(out) + return out + + +class MergedBlockedF8Linear(BlockedF8Linear): + """merged blocked fp8 linear.""" + + def __init__(self, + in_features: int, + all_out_features: List[int], + bias: bool, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + replicate: Optional[List[bool]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + is_tp: bool = True, + out_names: Optional[List[int]] = None): + if replicate is None: + replicate = tuple(False for _ in all_out_features) + self.block_size = 128 + self.split_section = all_out_features + self.scale_split_section = [ + section // self.block_size for section in self.split_section + ] + all_out_features = self._update_all_out_features( + all_out_features, replicate) + self.all_out_features = all_out_features + self.replicate = replicate + if out_names is None: + out_names = torch.arange(len(self.all_out_features)).tolist() + assert 
len(out_names) == len(self.all_out_features) + self.out_names_map = dict( + (name, idx) for idx, name in enumerate(out_names)) + out_features = sum(all_out_features) + super().__init__(in_features, + out_features, + bias, + dtype, + device, + fp8_dtype=fp8_dtype, + colwise=True, + is_tp=is_tp) + self.weight.weight_loader = self.weight_loader + self.scale.weight_loader = self.weight_loader + self.weight.weight_spliter = self.weight_spliter + self.scale.weight_spliter = self.weight_spliter + if self.bias is not None: + self.bias.weight_loader = self.weight_loader + self.bias.weight_spliter = self.weight_spliter + + def _get_io_features(self, in_features: int, out_features: int, + colwise: bool): + """get io features.""" + return in_features, out_features + + def _update_all_out_features(self, all_out_features: List[int], + replicate: Optional[List[bool]]): + """update all out features.""" + world_size, rank = get_world_rank() + new_all_out_features = [] + for out_feat, rep in zip(all_out_features, replicate): + if rep: + new_all_out_features.append(out_feat) + new_out_feat = get_distribute_size(out_feat, world_size, rank) + new_all_out_features.append(new_out_feat) + return new_all_out_features + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, shard_id: Any): + """weight loader.""" + world_size, rank = get_world_rank() + shard_idx = self.out_names_map[shard_id] + if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32: + all_out_features = [ + feats // self.block_size for feats in self.all_out_features + ] + param_w = param.data.split(all_out_features, 0)[shard_idx] + else: + param_w = param.data.split(self.all_out_features, 0)[shard_idx] + if not self.replicate[shard_idx]: + loaded_weight = loaded_weight.chunk(world_size, 0)[rank] + param_w.copy_(loaded_weight) + + def weight_spliter(self, loaded_weight: torch.Tensor): + """weight spliter.""" + if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32: + return loaded_weight.split(self.scale_split_section, dim=0) + return loaded_weight.split(self.split_section, dim=0) + + def weight_spliter_lora_b(self, loaded_weight: torch.Tensor): + return loaded_weight.split(self.split_section, dim=0) + + class AwqLinear(nn.Module): """w4a16 linear.""" @@ -1223,6 +1456,25 @@ def build_linear(in_features: int, is_tp=is_tp, all_reduce=all_reduce, quant_dtype=quant_dtype) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return BlockedF8Linear( + in_features, + out_features, + bias=bias, + fp8_dtype=fp8_dtype, + dtype=dtype, + device=device, + colwise=colwise, + is_tp=is_tp, + all_reduce=all_reduce, + ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') @@ -1322,6 +1574,24 @@ def build_merged_colwise_linear( is_tp=is_tp, out_names=out_names, quant_dtype=quant_dtype) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return MergedBlockedF8Linear( + in_features=in_features, + all_out_features=all_out_features, + bias=bias, + fp8_dtype=fp8_dtype, + dtype=dtype, + device=device, + is_tp=is_tp, + out_names=out_names, + ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git 
a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py index 5cda26bb15..4921825c9a 100644 --- a/lmdeploy/pytorch/nn/moe.py +++ b/lmdeploy/pytorch/nn/moe.py @@ -8,6 +8,7 @@ from lmdeploy.pytorch.distributed import get_world_rank from ..backends import OpType, get_backend +from .utils import div_up class SoftmaxTopK(nn.Module): @@ -336,6 +337,160 @@ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, return ret +class LinearWeightsBlockedF8(LinearWeights): + """fused moe linear blocked fp8 weights.""" + + def __init__(self, + num_experts: int, + in_features: int, + out_features: int, + weight_type: str, + block_size: int, + dtype: torch.dtype, + device: torch.device, + expert_list: List[int] = None, + ep: bool = False): + super().__init__( + num_experts=num_experts, + in_features=in_features, + out_features=out_features, + weight_type=weight_type, + dtype=dtype, + device=device, + expert_list=expert_list, + ep=ep, + ) + self.block_size = block_size + scale = torch.empty((num_experts, div_up( + out_features, block_size), div_up(in_features, block_size)), + dtype=torch.float32, + device=device) + scale = torch.nn.Parameter(scale, requires_grad=False) + self.register_parameter('scale', scale) + + if self.ep: + self.scale.weight_loader = self.weight_loader_ep + else: + self.scale.weight_loader = self.weight_loader_scale_tp + + def update_weight(self, weight: torch.Tensor, scale: torch.Tensor): + """update weight.""" + super().update_weight(weight=weight) + weight_loader = self.scale.weight_loader + scale = torch.nn.Parameter(scale, requires_grad=False) + scale.weight_loader = weight_loader + self.register_parameter('scale', scale) + + def weight_loader_scale_tp(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int, + shard_id: str): + """weight loader scale tp.""" + world_size, rank = get_world_rank() + block_size = self.block_size + half_out = self.half_out // block_size + if shard_id == 'gate': + param_data = param.data[expert_id, :half_out] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'up': + param_data = param.data[expert_id, half_out:] + weight = loaded_weight.chunk(world_size, dim=0)[rank] + elif shard_id == 'down': + param_data = param.data[expert_id] + weight = loaded_weight.chunk(world_size, dim=1)[rank] + else: + raise RuntimeError(f'Unknown shard_id: {shard_id}') + param_data.copy_(weight) + + +class FusedMoEBlockedF8(nn.Module): + """fused moe blocked f8.""" + + def __init__(self, + hidden_dim: int, + ffn_dim: int, + num_experts: int, + top_k: int, + renormalize: bool = False, + fp8_dtype: torch.dtype = torch.float8_e4m3fn, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + all_reduce: bool = True, + enable_ep: bool = False): + super().__init__() + if device is None: + device = torch.device('cpu') + dtype = torch.float16 if dtype is None else dtype + self.block_size = 128 + impl_builder = get_backend().get_layer_impl_builder( + OpType.FusedMoEBlockedF8) + self.impl = impl_builder.build(top_k, + num_experts, + renormalize, + block_size=self.block_size, + out_dtype=dtype) + + enable_ep = enable_ep and self.impl.support_ep() + if enable_ep: + world_size, rank = get_world_rank() + expert_list = self.impl.ep_expert_list(world_size, rank) + num_experts = len(expert_list) + else: + hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim) + expert_list = None + self.expert_list = expert_list + + self.gate_up = LinearWeightsBlockedF8(num_experts, + hidden_dim, + ffn_dim * 2, + 
weight_type='gate_up', + block_size=self.block_size, + dtype=fp8_dtype, + device=device, + expert_list=expert_list, + ep=enable_ep) + self.down = LinearWeightsBlockedF8( + num_experts, + ffn_dim, + hidden_dim, + weight_type='down', + block_size=self.block_size, + dtype=fp8_dtype, + device=device, + expert_list=expert_list, + ep=enable_ep, + ) + + self.hidden_dim = hidden_dim + self.ffn_dim = ffn_dim + self.num_experts = num_experts + self.dtype = dtype + self.device = device + world_size, _ = get_world_rank() + if world_size == 1: + all_reduce = False + self.all_reduce = all_reduce + + def update_weights(self): + """update weights.""" + (gate_up_weights, down_weights, gate_up_scale, + down_scale) = self.impl.update_weights(self.gate_up.weight, + self.down.weight, + self.gate_up.scale, + self.down.scale) + self.gate_up.update_weight(gate_up_weights, gate_up_scale) + self.down.update_weight(down_weights, down_scale) + + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.LongTensor): + ret = self.impl.forward(hidden_states, topk_weights, topk_ids, + self.gate_up.weight, self.gate_up.scale, + self.down.weight, self.down.scale, + self.expert_list) + if self.all_reduce: + dist.all_reduce(ret) + return ret + + def build_fused_moe( hidden_dim: int, ffn_dim: int, @@ -376,5 +531,25 @@ def build_fused_moe( all_reduce=all_reduce, enable_ep=enable_ep, ) + elif quant_method == 'fp8': + fmt = quant_config.get('fmt', 'e4m3') + if fmt == 'e4m3': + fp8_dtype = torch.float8_e4m3fn + elif fmt == 'e5m2': + fp8_dtype = torch.float8_e5m2 + else: + raise TypeError(f'Unsupported fp8 fmt: {fmt}') + return FusedMoEBlockedF8( + hidden_dim=hidden_dim, + ffn_dim=ffn_dim, + num_experts=num_experts, + top_k=top_k, + renormalize=renormalize, + fp8_dtype=fp8_dtype, + dtype=dtype, + device=device, + all_reduce=all_reduce, + enable_ep=enable_ep, + ) else: raise RuntimeError(f'Unsupported quant method: {quant_method}') diff --git a/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py new file mode 100644 index 0000000000..bb165658dd --- /dev/null +++ b/tests/pytorch/kernel/test_fuse_moe_blocked_fp8.py @@ -0,0 +1,231 @@ +import pytest +import torch + + +def _make_A(M, K, group_size, out_dtype, device='cuda'): + quant_A = torch.rand(M, + K // group_size, + group_size, + dtype=torch.float32, + device=device) + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.rand(M, K // group_size, dtype=torch.float32, device=device) + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + return A, quant_A, scale + + +def _make_B(E, K, N, group_size, out_dtype, device='cuda'): + quant_B = torch.rand(E, + N // group_size, + group_size, + K // group_size, + group_size, + dtype=torch.float32, + device=device) + quant_B = quant_B * 2 - 1 + + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_B.abs().amax((2, 4), keepdim=True) + quant_B *= scaling + quant_B = quant_B.to(out_dtype).to(torch.float32) + + scale = torch.rand(E, + N // group_size, + 1, + K // group_size, + 1, + dtype=torch.float32, + device=device) + scale /= fmax + + B = quant_B * scale + + B = B.reshape(E, N, K) + quant_B = quant_B.reshape(E, N, 
K).to(out_dtype) + scale = scale.reshape(E, N // group_size, K // group_size) + return B, quant_B, scale + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, + reason='require device with cc>=9.0') +class TestFusedMoeBlockedFP8: + + @pytest.fixture + def dtype(self): + yield torch.float16 + + @pytest.fixture + def quant_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def device(self): + yield torch.device('cuda') + + @pytest.fixture + def in_size(self): + yield 512 + + @pytest.fixture + def seq_len(seq_len): + yield 128 + + @pytest.fixture + def hidden_size(self): + yield 2048 + + @pytest.fixture + def out_size(self): + yield 1024 + + @pytest.fixture + def num_experts(self): + yield 4 + + @pytest.fixture + def top_k(self): + yield 2 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def renormalize(self): + yield True + + @pytest.fixture + def build_hidden_states(self, seq_len, in_size, group_size, quant_dtype, + device): + yield _make_A(seq_len, + in_size, + group_size=group_size, + out_dtype=quant_dtype, + device=device) + + @pytest.fixture + def hidden_states(self, build_hidden_states, dtype): + yield build_hidden_states[0].to(dtype) + + @pytest.fixture + def states_quanted(self, build_hidden_states): + yield build_hidden_states[1] + + @pytest.fixture + def states_scale(self, build_hidden_states): + yield build_hidden_states[2] + + @pytest.fixture + def build_w1(self, num_experts, hidden_size, in_size, group_size, + quant_dtype, device): + yield _make_B(num_experts, + in_size, + hidden_size, + group_size=group_size, + out_dtype=quant_dtype, + device=device) + + @pytest.fixture + def w1(self, build_w1, dtype): + yield build_w1[0].to(dtype) + + @pytest.fixture + def w1_quant(self, build_w1): + yield build_w1[1] + + @pytest.fixture + def w1_scale(self, build_w1): + yield build_w1[2] + + @pytest.fixture + def build_w2(self, num_experts, out_size, hidden_size, group_size, + quant_dtype, device): + yield _make_B(num_experts, + hidden_size // 2, + out_size, + group_size=group_size, + out_dtype=quant_dtype, + device=device) + + @pytest.fixture + def w2(self, build_w2, dtype): + yield build_w2[0].to(dtype) + + @pytest.fixture + def w2_quant(self, build_w2): + yield build_w2[1] + + @pytest.fixture + def w2_scale(self, build_w2): + yield build_w2[2] + + @pytest.fixture + def router_logits(self, seq_len, num_experts, dtype, device): + yield torch.rand(seq_len, num_experts, dtype=dtype, device=device) + + @pytest.fixture + def topk_logits(self, router_logits, top_k): + routing_weights = torch.softmax(router_logits, + dim=-1, + dtype=torch.float32) + yield torch.topk(routing_weights, top_k, dim=-1) + + @pytest.fixture + def topk_weights(self, topk_logits): + yield topk_logits[0] + + @pytest.fixture + def topk_idx(self, topk_logits): + yield topk_logits[1] + + @pytest.fixture + def gt(self, hidden_states, w1, w2, topk_weights, topk_idx, top_k, + renormalize): + from lmdeploy.pytorch.kernels.cuda.fused_moe import fused_moe + output = fused_moe(hidden_states, + w1, + w2, + topk_weights, + topk_idx, + topk=top_k, + renormalize=renormalize) + yield output + + @torch.inference_mode() + def test_fused_moe(self, states_quanted, states_scale, w1_quant, w1_scale, + w2_quant, w2_scale, topk_weights, topk_idx, top_k, + renormalize, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \ + fused_moe_blocked_fp8 + output = fused_moe_blocked_fp8(states_quanted, + states_scale, + w1_quant, + w1_scale, + w2_quant, + w2_scale, + topk_weights, + 
topk_idx, + topk=top_k, + renormalize=renormalize) + out_max = output.abs().max() + gt_max = gt.abs().max() + assert (out_max - gt_max).abs() / out_max < 0.05 + + norm_out = output / out_max + norm_gt = gt / gt_max + torch.testing.assert_close(norm_out, norm_gt, atol=0.05, rtol=1e-3) diff --git a/tests/pytorch/kernel/test_gemm_fp8.py b/tests/pytorch/kernel/test_gemm_fp8.py new file mode 100644 index 0000000000..242a2db581 --- /dev/null +++ b/tests/pytorch/kernel/test_gemm_fp8.py @@ -0,0 +1,193 @@ +import pytest +import torch + + +def _make_A(M, K, group_size, out_dtype): + quant_A = torch.rand(M, + K // group_size, + group_size, + dtype=torch.float32, + device='cuda') + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.rand(M, K // group_size, dtype=torch.float32, device='cuda') + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + return A, quant_A, scale + + +def _aligned_size(a, b): + return (a + b - 1) // b * b + + +def _make_B(K, N, group_size, out_dtype): + K_aligned = _aligned_size(K, group_size) + N_aligned = _aligned_size(N, group_size) + + quant_B = torch.rand(K_aligned // group_size, + group_size, + N_aligned // group_size, + group_size, + dtype=torch.float32, + device='cuda') + quant_B = quant_B * 2 - 1 + + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True) + quant_B *= scaling + quant_B = quant_B.to(out_dtype).to(torch.float32) + + scale = torch.rand(K_aligned // group_size, + 1, + N_aligned // group_size, + 1, + dtype=torch.float32, + device='cuda') + scale /= fmax + + B = quant_B * scale + + B = B.reshape(K_aligned, N_aligned)[:K, :N] + quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N] + scale = scale.reshape(K_aligned // group_size, N_aligned // group_size) + return B, quant_B, scale + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, + reason='require device with cc>=9.0') +class TestQuantFP8: + + @pytest.fixture + def M(self): + yield 256 + + @pytest.fixture + def K(self): + yield 512 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def out_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def build_A(self, M, K, group_size, out_dtype): + return _make_A(M, K, group_size, out_dtype) + + @pytest.fixture + def A(self, build_A): + return build_A[0] + + @pytest.fixture + def quant_A(self, build_A): + return build_A[1] + + @pytest.fixture + def scale(self, build_A): + return build_A[2] + + @pytest.fixture + def gt(self, quant_A, scale): + yield quant_A, scale + + def test_quant_fp8(self, A, group_size, out_dtype, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 + quant_A_gt, scale_gt = gt + + quant_A, scale = quant_fp8(A, group_size=group_size, dtype=out_dtype) + torch.testing.assert_close(scale, scale_gt) + diff = (quant_A.to(torch.float16) - quant_A_gt.to(torch.float16)).abs() + diff_count = (diff > 1e-5).count_nonzero() + assert diff_count / diff.numel() < 1e-4 + + +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, + reason='require device with cc>=9.0') +class TestGemmFP8: + + @pytest.fixture + def M(self): + yield 256 + + @pytest.fixture + def N(self): + # test 
non-aligned + yield 1024 + 64 + + @pytest.fixture + def K(self): + yield 512 + + @pytest.fixture + def group_size(self): + yield 128 + + @pytest.fixture + def quant_dtype(self): + yield torch.float8_e4m3fn + + @pytest.fixture + def out_dtype(self): + yield torch.float16 + + @pytest.fixture + def build_A(self, M, K, group_size, quant_dtype): + return _make_A(M, K, group_size, quant_dtype) + + @pytest.fixture + def A(self, build_A, out_dtype): + return build_A[0].to(out_dtype) + + @pytest.fixture + def quant_A(self, build_A): + return build_A[1] + + @pytest.fixture + def scale_A(self, build_A): + return build_A[2] + + @pytest.fixture + def build_B(self, K, N, group_size, quant_dtype): + return _make_B(K, N, group_size, quant_dtype) + + @pytest.fixture + def B(self, build_B, out_dtype): + return build_B[0].to(out_dtype) + + @pytest.fixture + def quant_B(self, build_B): + return build_B[1] + + @pytest.fixture + def scale_B(self, build_B): + return build_B[2] + + @pytest.fixture + def gt(self, A, B): + yield A @ B + + def test_gemm_fp8(self, quant_A, scale_A, quant_B, scale_B, out_dtype, gt): + from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import \ + blocked_gemm_fp8 + C = blocked_gemm_fp8(quant_A, + scale_A, + quant_B, + scale_B, + out_dtype=out_dtype) + torch.testing.assert_close(C, gt, atol=0.5, rtol=1e-4)
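
For reference, a minimal sketch of exercising the new quant_fp8 / blocked_gemm_fp8 kernels introduced in this patch. Shapes and the group size of 128 mirror tests/pytorch/kernel/test_gemm_fp8.py; the random weight and scale tensors are illustrative placeholders rather than real checkpoint data, and a CUDA device with compute capability >= 9.0 is assumed.

import torch
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8,
                                                            quant_fp8)

M, K, N, group_size = 256, 512, 1024, 128
device = torch.device('cuda')

# Online per-token, per-128-group quantization of the activation.
x = torch.randn(M, K, dtype=torch.float16, device=device)
x_q, x_scale = quant_fp8(x, group_size, dtype=torch.float8_e4m3fn)
# x_q: (M, K) fp8, x_scale: (M, K // group_size) fp32

# Blocked-FP8 weight of shape (K, N) with one fp32 scale per 128x128 block.
# Placeholder values only; real scales come from the checkpoint's
# '.weight_scale_inv' tensors, which the loader renames to '.scale'.
w_q = torch.randn(K, N, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
w_scale = torch.rand(K // group_size, N // group_size,
                     dtype=torch.float32, device=device)

out = blocked_gemm_fp8(x_q, x_scale, w_q, w_scale, out_dtype=torch.float16)
# out: (M, N) fp16

At the module level these kernels are normally not called directly: build_linear, build_merged_colwise_linear and build_fused_moe select BlockedF8Linear, MergedBlockedF8Linear and FusedMoEBlockedF8 automatically when the model's quantization_config reports quant_method == 'fp8', which is how the DeepSeek-V3 path in deepseek_v2.py picks up the blocked-FP8 implementations.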