From cf7ae82ce0cfebf6c5805f641463bcc8d58e297f Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Wed, 8 Jan 2025 16:19:29 +0800
Subject: [PATCH] [Model] Reuse RoPE positions for Deepseek-v2 model (#3084)

This PR updates the Deepseek-v2 model implementation to reuse the RoPE
position arrays across layers. Prior to this PR, we queried the RoPE
positions for every single layer, while in fact these arrays can be
reused, so a single query is sufficient.
---
 .../model/deepseek_v2/deepseek_v2_model.py | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
index c2cecc3621..e144dd77db 100644
--- a/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
+++ b/python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py
@@ -229,7 +229,13 @@ def __init__(self, config: DeepseekV2Config):
                 self.softmax_scale = self.softmax_scale * mscale * mscale
         self.rotary_emb = DeepseekV2YarnRotaryEmbedding(config)
 
-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         b, s, _ = hidden_states.shape
 
         if self.q_lora_rank is None:
@@ -260,7 +266,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
             kv, [self.qk_nope_head_dim], axis=-1
         )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, v_head_dim)
 
-        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, paged_kv_cache.get_query_positions(s))
+        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)
 
         @T.prim_func
         def inplace_q(var_q: T.handle, var_pe: T.handle):
@@ -471,9 +477,15 @@ def _set(layer, hint):
         self.tensor_parallel_shards = config.tensor_parallel_shards
         _set_tp()
 
-    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
+    def forward(
+        self,
+        hidden_states: Tensor,
+        paged_kv_cache: PagedKVCache,
+        layer_id: int,
+        query_positions: Tensor,
+    ):
         out = self.input_layernorm(hidden_states)
-        out = self.self_attn(out, paged_kv_cache, layer_id)
+        out = self.self_attn(out, paged_kv_cache, layer_id, query_positions)
         hidden_states = self._apply_residual(out, residual=hidden_states)
         out = self.post_attention_layernorm(hidden_states)
         out = self.mlp(out)  # type: ignore[operator]
@@ -499,8 +511,10 @@ def __init__(self, config: DeepseekV2Config):
 
     def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
         hidden_states = inputs
+        print(f"inputs.shape = {inputs.shape}")
+        query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])
         for layer_id, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)
+            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, query_positions)
         hidden_states = self.norm(hidden_states)
         return hidden_states
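Below is a minimal, self-contained sketch of the pattern this patch applies: hoist the `get_query_positions` lookup out of the per-layer code path and thread the result through every layer's `forward`. The `FakePagedKVCache` and `FakeDecoderLayer` classes are hypothetical stand-ins rather than the real mlc_llm types; only the `get_query_positions` method name and the way the positions argument is threaded through mirror the diff above.

# Minimal sketch, assuming stand-in classes rather than the real mlc_llm API.
from typing import List


class FakePagedKVCache:
    """Hypothetical stand-in for PagedKVCache; counts position lookups."""

    def __init__(self) -> None:
        self.num_queries = 0

    def get_query_positions(self, total_length: int) -> List[int]:
        # The positions depend only on the total token count, so the result
        # is identical for every layer within one forward pass.
        self.num_queries += 1
        return list(range(total_length))


class FakeDecoderLayer:
    """Hypothetical stand-in layer that consumes precomputed positions."""

    def forward(self, hidden_states, paged_kv_cache, layer_id, query_positions):
        # A real layer would apply RoPE and attention here; this sketch
        # only passes the activations through.
        return hidden_states


def forward_per_layer_lookup(hidden_states, kv_cache, layers, total_length):
    """Old pattern: every layer queries the positions itself."""
    for layer_id, layer in enumerate(layers):
        positions = kv_cache.get_query_positions(total_length)
        hidden_states = layer.forward(hidden_states, kv_cache, layer_id, positions)
    return hidden_states


def forward_reused_positions(hidden_states, kv_cache, layers, total_length):
    """New pattern (as in this patch): query once, reuse for all layers."""
    positions = kv_cache.get_query_positions(total_length)
    for layer_id, layer in enumerate(layers):
        hidden_states = layer.forward(hidden_states, kv_cache, layer_id, positions)
    return hidden_states


if __name__ == "__main__":
    layers = [FakeDecoderLayer() for _ in range(4)]

    cache = FakePagedKVCache()
    forward_per_layer_lookup([[0.0]], cache, layers, total_length=8)
    print("lookups with per-layer queries:", cache.num_queries)  # 4

    cache = FakePagedKVCache()
    forward_reused_positions([[0.0]], cache, layers, total_length=8)
    print("lookups with reused positions:", cache.num_queries)  # 1

The change is simply hoisting a loop-invariant computation: the position array depends only on the batch's total token count, not on the layer index, so computing it once in the outer model forward avoids a redundant query in every attention layer.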