Support DeepseekV3 fp8 (#2967)

* support moe w8a8

* refactor moe

* fix tp

* first

* done

* optimize

* stage =3

* pingpong

* skip on cc<9.0

* fused moe blocked gemm

* lint

* stages 4

* support newds

* new format

* remove scale=scale
grimoire authored Jan 8, 2025
1 parent de2050d commit ac509e8
Showing 15 changed files with 1,912 additions and 49 deletions.
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/base.py
@@ -29,6 +29,8 @@ class OpType(Enum):
SoftmaxTopK = auto()
FusedMoE = auto()
FusedMoEW8A8 = auto()
LinearBlockedF8 = auto()
FusedMoEBlockedF8 = auto()


class OpsBackend(ABC):
39 changes: 39 additions & 0 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -0,0 +1,39 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class LinearBlockedF8Impl(ABC):
"""linear BlockedF8 implementation api."""

def update_weights(self,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""update weights."""
return weight, scale, bias

@abstractmethod
def forward(self,
x,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
"""forward."""
raise NotImplementedError


class LinearBlockedF8Builder(ABC):
"""linear BlockedF8 implementation builder."""

@staticmethod
@abstractmethod
def build(in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None):
"""build."""
raise NotImplementedError
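
A minimal eager-mode sketch of the contract above may help readers; it is not part of the commit. It dequantizes a block-wise scaled fp8 weight and falls back to a plain matmul, assuming one scale per 128x128 weight block laid out as (out_blocks, in_blocks); the Triton path added in the next file instead keeps the weight in fp8 and passes the scales into the GEMM kernel.

# Illustrative reference only: a naive LinearBlockedF8Impl that dequantizes the
# blocked fp8 weight and runs a plain matmul.  The 128x128 block size and the
# (out_blocks, in_blocks) scale layout are assumptions, not taken from the diff.
from typing import Optional

import torch

from lmdeploy.pytorch.backends.blockedf8_modules import LinearBlockedF8Impl


class EagerLinearBlockedF8Impl(LinearBlockedF8Impl):
    """dequantize-then-matmul reference, written for readability not speed."""

    def __init__(self,
                 block_size: int = 128,
                 out_dtype: torch.dtype = torch.float16):
        self.block_size = block_size
        self.out_dtype = out_dtype

    def forward(self,
                x,
                weight: torch.Tensor,
                scale: torch.Tensor,
                bias: Optional[torch.Tensor] = None,
                all_reduce: bool = False):
        """forward."""
        blk = self.block_size
        w = weight.to(self.out_dtype)
        # broadcast one scale per (blk x blk) weight block back to elements
        s = scale.to(self.out_dtype)
        s = s.repeat_interleave(blk, 0)[:w.shape[0]]
        s = s.repeat_interleave(blk, 1)[:, :w.shape[1]]
        out = x.to(self.out_dtype) @ (w * s).t()
        if bias is not None:
            out = out + bias
        # all_reduce is ignored in this single-device sketch
        return out
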
65 changes: 65 additions & 0 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.distributed as dist

from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8,
quant_fp8)

from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""triton linear blocked f8 implementation."""

def __init__(self,
in_features: int,
out_features: int,
block_size: int,
out_dtype: torch.dtype = torch.float16):
self.in_features = in_features
self.out_features = out_features
self.out_dtype = out_dtype
self.block_size = block_size

def forward(self,
x,
weight: torch.Tensor,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
"""forward."""
x_shape = x.shape
x = x.flatten(0, -2)
input_quant, input_scale = quant_fp8(x,
self.block_size,
dtype=weight.dtype)

out = blocked_gemm_fp8(input_quant,
input_scale,
weight.t(),
scale.t(),
out_dtype=x.dtype)
if bias is not None:
out += bias

if all_reduce:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
return out


class TritonLinearBlockedF8Builder(LinearBlockedF8Builder):
"""triton linear blocked f8 implementation builder."""

@staticmethod
def build(in_features: int,
out_features: int,
block_size: int = 128,
bias: bool = True,
dtype: torch.dtype = None):
"""build."""
return TritonLinearBlockedF8Impl(in_features, out_features, block_size,
dtype)
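
A companion sketch for the activation side (also not from the commit): quant_fp8 above is called with a per-token group size of block_size, so each row of x is conceptually split into 128-channel groups that each receive one fp8 scale. The plain-PyTorch approximation below assumes the e4m3 dtype and amax-based scaling; the Triton kernel's exact layout and rounding may differ.

# Rough reference for what the quant_fp8 call above is expected to produce:
# per-token, per-block_size-channel-group fp8 values plus one float scale per
# group.  Dtype, scale layout and rounding are assumptions, not the kernel's
# exact behaviour.  Requires a PyTorch build with float8 dtypes (>= 2.1).
import torch


def quant_fp8_reference(x: torch.Tensor,
                        block_size: int = 128,
                        dtype: torch.dtype = torch.float8_e4m3fn):
    tokens, channels = x.shape
    assert channels % block_size == 0, 'pad channels to a multiple of 128'
    groups = channels // block_size
    xg = x.float().view(tokens, groups, block_size)
    # one scale per (token, group), chosen so the group fits the fp8 range
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-12)
    scale = amax / torch.finfo(dtype).max
    x_quant = (xg / scale).to(dtype).view(tokens, channels)
    return x_quant, scale.squeeze(-1)  # scales: (tokens, groups)


if __name__ == '__main__':
    x = torch.randn(4, 256, dtype=torch.float16)
    x_quant, x_scale = quant_fp8_reference(x)
    print(x_quant.shape, x_quant.dtype, x_scale.shape)  # (4, 256), e4m3, (4, 2)
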
98 changes: 97 additions & 1 deletion lmdeploy/pytorch/backends/cuda/moe.py
@@ -5,11 +5,15 @@
 import torch
 
 from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
+from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \
+    fused_moe_blocked_fp8
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
 from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
     per_token_quant_int8
 from lmdeploy.pytorch.models.q_modules import QTensor
 
-from ..moe import (FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
+from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl,
+                   FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
                    FusedMoEW8A8Impl)


@@ -168,3 +172,95 @@ def build(top_k: int,
num_experts=num_experts,
renormalize=renormalize,
out_dtype=out_dtype)


class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
"""triton fused moe blocked f8 implementation."""

def __init__(self,
top_k: int,
num_experts: int,
renormalize: bool = False,
block_size: int = 128,
out_dtype: torch.dtype = torch.float16):
self.num_experts = num_experts
self.top_k = top_k
self.renormalize = renormalize
self.block_size = block_size
self.out_dtype = out_dtype

def update_weights(self, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
down_scale: torch.Tensor):
gate_up_weights = gate_up_weights.transpose(1,
2).contiguous().transpose(
1, 2)
down_weights = down_weights.transpose(1,
2).contiguous().transpose(1, 2)
return gate_up_weights, down_weights, gate_up_scale, down_scale

def support_ep(self):
"""support expert parallelism."""
return True

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
num_experts = self.num_experts
expert_per_rank = (num_experts + world_size - 1) // world_size
first_expert = rank * expert_per_rank
last_expert = min(first_expert + expert_per_rank, num_experts)
return list(range(first_expert, last_expert))

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
gate_up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
input_size = hidden_states.shape
hidden_states = hidden_states.flatten(0, -2)
input_quant, input_scale = quant_fp8(hidden_states,
self.block_size,
dtype=gate_up_weights.dtype)

expert_offset = 0
num_experts = None
if expert_list is not None and len(expert_list) != self.num_experts:
expert_offset = expert_list[0]
num_experts = self.num_experts
output = fused_moe_blocked_fp8(input_quant,
input_scale,
gate_up_weights,
gate_up_scale,
down_weights,
down_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk=self.top_k,
out_dtype=hidden_states.dtype,
expert_offset=expert_offset,
num_experts=num_experts,
renormalize=self.renormalize)
output = output.unflatten(0, input_size[:-1])
return output


class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
"""triton fused moe blocked f8 builder."""

@staticmethod
def build(top_k: int,
num_experts: int,
renormalize: bool = False,
block_size: int = 128,
out_dtype: torch.dtype = torch.float16):
"""build from mlp."""
return TritonFusedMoEBlockedF8Impl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize,
block_size=block_size,
out_dtype=out_dtype)
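
The expert-parallel split in ep_expert_list is a plain contiguous partition, which is easy to sanity-check in isolation; the 256-expert / 8-rank numbers below are only an illustration, not a statement about any particular deployment.

# Stand-alone copy of the partitioning arithmetic used by ep_expert_list above.
def ep_expert_list(num_experts: int, world_size: int, rank: int):
    expert_per_rank = (num_experts + world_size - 1) // world_size
    first_expert = rank * expert_per_rank
    last_expert = min(first_expert + expert_per_rank, num_experts)
    return list(range(first_expert, last_expert))


print(ep_expert_list(256, 8, 0))  # experts 0..31
print(ep_expert_list(256, 8, 7))  # experts 224..255
print(ep_expert_list(10, 4, 3))   # [9] -- the last rank absorbs the remainder
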
6 changes: 6 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -59,6 +59,12 @@ def get_layer_impl_builder(cls, layer_type: OpType):
elif layer_type == OpType.FusedMoEW8A8:
from .moe import TritonFusedMoEW8A8Builder
return TritonFusedMoEW8A8Builder
elif layer_type == OpType.FusedMoEBlockedF8:
from .moe import TritonFusedMoEBlockedF8Builder
return TritonFusedMoEBlockedF8Builder
elif layer_type == OpType.LinearBlockedF8:
from .blockedf8_modules import TritonLinearBlockedF8Builder
return TritonLinearBlockedF8Builder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
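
A hedged sketch of how these branches are reached from the outside: a layer asks the CUDA backend for a builder by OpType and then builds an implementation. Only OpType, get_layer_impl_builder and the two build signatures come from the diffs above; the CudaOpsBackend class name and the example sizes are assumptions.

# Sketch only: resolve the new blocked-f8 builders through the dispatcher.
import torch

from lmdeploy.pytorch.backends.base import OpType
# class name assumed; substitute the actual CUDA backend class if it differs
from lmdeploy.pytorch.backends.cuda.op_backend import CudaOpsBackend

linear_builder = CudaOpsBackend.get_layer_impl_builder(OpType.LinearBlockedF8)
linear_impl = linear_builder.build(in_features=7168, out_features=2048,
                                   block_size=128, dtype=torch.float16)

moe_builder = CudaOpsBackend.get_layer_impl_builder(OpType.FusedMoEBlockedF8)
moe_impl = moe_builder.build(top_k=8, num_experts=256, renormalize=True,
                             block_size=128, out_dtype=torch.bfloat16)
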
45 changes: 45 additions & 0 deletions lmdeploy/pytorch/backends/moe.py
@@ -105,3 +105,48 @@ def build(top_k: int,
out_dtype: torch.dtype = torch.float16):
"""build from mlp."""
raise NotImplementedError


class FusedMoEBlockedF8Impl(ABC):
"""fused moe blocked f8 implementation."""

def update_weights(self, gate_up_weights: torch.Tensor,
down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
down_scale: torch.Tensor):
"""update weights."""
return gate_up_weights, down_weights, gate_up_scale, down_scale

def support_ep(self):
"""support expert parallelism."""
return False

def ep_expert_list(self, world_size: int, rank: int):
"""experts list of current rank."""
raise NotImplementedError('Not Implemented.')

@abstractmethod
def forward(self,
hidden_states: torch.Tensor,
input_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.LongTensor,
gate_up_weights: torch.Tensor,
gate_up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
raise NotImplementedError


class FusedMoEBlockedF8Builder(ABC):
"""fused moe blocked f8 builder."""

@staticmethod
@abstractmethod
def build(top_k: int,
num_experts: int,
renormalize: bool = False,
out_dtype: torch.dtype = torch.float16):
"""build from mlp."""
raise NotImplementedError
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -9,7 +9,7 @@ class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
     @classmethod
     def condition(cls, hf_config):
         """config."""
-        return hf_config.model_type == 'deepseek_v2'
+        return hf_config.model_type in ['deepseek_v3', 'deepseek_v2']
 
     @classmethod
     def build(cls, hf_config, model_path: str = None, **kwargs):