[Model] Reuse RoPE positions for Deepseek-v2 model (#3084)
This PR updates the Deepseek-v2 model implementation to reuse the RoPE position array across layers. Prior to this PR, the RoPE positions were queried from the KV cache once per layer, while in fact the same array can be reused, so a single query per forward pass is sufficient.
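For illustration, the change boils down to the hoisting pattern sketched below. This is a minimal, self-contained sketch with hypothetical stand-ins (FakeKVCache, FakeLayer, and model_forward are illustrative, not mlc_llm classes): the positions are queried from the cache once per forward pass, and the same array is passed to every layer.

from typing import List


class FakeKVCache:
    """Stand-in for the paged KV cache; counts how often positions are queried."""

    def __init__(self) -> None:
        self.num_queries = 0

    def get_query_positions(self, total_length: int) -> List[int]:
        # In the real model this produces the RoPE position array for the batch.
        self.num_queries += 1
        return list(range(total_length))


class FakeLayer:
    """Stand-in for a decoder layer that consumes precomputed positions."""

    def forward(self, hidden_states, kv_cache, layer_id, query_positions):
        # The layer uses query_positions directly instead of calling
        # kv_cache.get_query_positions itself (the pre-PR behavior).
        return hidden_states


def model_forward(hidden_states, kv_cache, layers, seq_len):
    # Query the RoPE positions once ...
    query_positions = kv_cache.get_query_positions(seq_len)
    # ... and reuse the same array in every layer.
    for layer_id, layer in enumerate(layers):
        hidden_states = layer.forward(hidden_states, kv_cache, layer_id, query_positions)
    return hidden_states


if __name__ == "__main__":
    cache = FakeKVCache()
    model_forward([0.0], cache, [FakeLayer() for _ in range(4)], seq_len=8)
    print(cache.num_queries)  # prints 1, regardless of the number of layers

In the actual diff below, paged_kv_cache.get_query_positions supplies the array and query_positions becomes an extra argument on each forward method.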
MasterJH5574 authored Jan 8, 2025
1 parent bf70bea commit cf7ae82
Showing 1 changed file with 19 additions and 5 deletions.
python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py: 19 additions & 5 deletions
@@ -229,7 +229,13 @@ def __init__(self, config: DeepseekV2Config):
                 self.softmax_scale = self.softmax_scale * mscale * mscale
         self.rotary_emb = DeepseekV2YarnRotaryEmbedding(config)

-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         b, s, _ = hidden_states.shape

         if self.q_lora_rank is None:
@@ -260,7 +266,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
             kv, [self.qk_nope_head_dim], axis=-1
         )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, v_head_dim)

-        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, paged_kv_cache.get_query_positions(s))
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)

         @T.prim_func
         def inplace_q(var_q: T.handle, var_pe: T.handle):
@@ -471,9 +477,15 @@ def _set(layer, hint):
         self.tensor_parallel_shards = config.tensor_parallel_shards
         _set_tp()

-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         out = self.input_layernorm(hidden_states)
-        out = self.self_attn(out, paged_kv_cache, layer_id)
+        out = self.self_attn(out, paged_kv_cache, layer_id, query_positions)
         hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
@@ -499,8 +511,10 @@ def __init__(self, config: DeepseekV2Config):

     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
+        print(f"inputs.shape = {inputs.shape}")
+        query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
+            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
         hidden_states = self.norm(hidden_states)
         return hidden_states

