diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py
index c8868716a4..20f3cc608b 100644
--- a/optimum/exporters/ipex/cache_utils.py
+++ b/optimum/exporters/ipex/cache_utils.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from intel_extension_for_pytorch.llm.modules import PagedAttention
@@ -95,6 +95,114 @@ def __init__(
             for _ in range(self.num_hidden_layers)
         ]
 
+    def update_for_prefill(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+        length_list: Optional[List],
+    ):
+        """
+        Writes the prompt tokens of every sequence of the batch to the paged cache of the layer `layer_idx`.
+
+        A free block is taken from `self.free_blocks` for every block table entry still set to -1 that the
+        sequence needs; each token is then mapped to the slot `block * block_size + offset`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states, passed unchanged to `PagedAttention.reshape_and_cache`.
+            value_states (`torch.Tensor`):
+                The new value states, passed unchanged to `PagedAttention.reshape_and_cache`.
+            layer_idx (`int`):
+                The index of the layer whose cache is written.
+            batch_size (`int`):
+                The number of sequences to process.
+            length_list (`List`):
+                The number of tokens to cache for each sequence.
+
+        Raises:
+            ValueError: If `length_list` is `None`.
+        """
+        if length_list is None:
+            raise ValueError("`length_list` is required to update the cache during prefill.")
+
+        all_block_indices = []
+        all_slot_offsets = []
+        for i in range(batch_size):
+            num_blocks = (length_list[i] + self.block_size - 1) // self.block_size
+            for b_idx in range(num_blocks):
+                if self.block_tables[i][b_idx] == -1:
+                    # need a free block
+                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
+
+            slots_range = torch.arange(length_list[i], device=key_states.device)
+            block_indices = slots_range // self.block_size
+            slot_offsets = slots_range % self.block_size
+            all_block_indices.append(self.block_tables[i][block_indices])
+            all_slot_offsets.append(slot_offsets)
+
+        all_block_indices = torch.cat(all_block_indices)
+        all_slot_offsets = torch.cat(all_slot_offsets)
+        slots_tensor = all_block_indices * self.block_size + all_slot_offsets
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            slots_tensor,
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += length_list[i]
+
+    def update_for_decode(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        batch_size: int,
+    ):
+        """
+        Writes one new token per sequence to the paged cache of the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states, passed unchanged to `PagedAttention.reshape_and_cache`.
+            value_states (`torch.Tensor`):
+                The new value states, passed unchanged to `PagedAttention.reshape_and_cache`.
+            layer_idx (`int`):
+                The index of the layer whose cache is written.
+            batch_size (`int`):
+                The number of sequences to process.
+        """
+        slots = []
+        for i in range(batch_size):
+            # one new token touches exactly one block: the one holding position `_seen_tokens[i]`
+            block_idx = self._seen_tokens[i] // self.block_size
+            slot_offset_in_block = self._seen_tokens[i] % self.block_size
+            if self.block_tables[i][block_idx] == -1:
+                # need a free block
+                self.block_tables[i][block_idx] = self.free_blocks.pop(0)
+            slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
+
+        # Update the cache
+        PagedAttention.reshape_and_cache(
+            key_states,
+            value_states,
+            self.kv_cache[layer_idx][0],
+            self.kv_cache[layer_idx][1],
+            torch.tensor(slots, device=key_states.device),
+        )
+
+        # Update the number of seen tokens
+        if layer_idx == self.num_hidden_layers - 1:
+            for i in range(batch_size):
+                self._seen_tokens[i] += 1
+
     def update(
         self,
         key_states: torch.Tensor,
@@ -102,7 +210,7 @@ def update(
         layer_idx: int,
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
-        input_lens: torch.Tensor,
+        length_list: Optional[List],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -117,39 +225,14 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
+
         batch_size = position_ids.shape[0]
-        slots = []
         if self.get_seq_length() == 0:
             # prefill
-            num_slots = input_lens.tolist()
+            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, length_list)
         else:
             # decode
-            num_slots = [1] * batch_size
-        for i in range(batch_size):
-            start_block_idx = self._seen_tokens[i] // self.block_size
-            num_blocks = (self._seen_tokens[i] + num_slots[i] + self.block_size - 1) // self.block_size
-            for b_idx in range(start_block_idx, num_blocks):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks.pop(0)
-            for slot in range(num_slots[i]):
-                block_idx = (self._seen_tokens[i] + slot) // self.block_size
-                slot_offset_in_block = (self._seen_tokens[i] + slot) % self.block_size
-                slots.append(self.block_tables[i][block_idx].item() * self.block_size + slot_offset_in_block)
-
-        # Update the cache
-        PagedAttention.reshape_and_cache(
-            key_states,
-            value_states,
-            self.kv_cache[layer_idx][0],
-            self.kv_cache[layer_idx][1],
-            torch.tensor(slots, device=key_states.device),
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            for i in range(batch_size):
-                self._seen_tokens[i] += num_slots[i]
+            self.update_for_decode(key_states, value_states, layer_idx, batch_size)
 
         return self.kv_cache[layer_idx][0], self.kv_cache[layer_idx][1]
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index b8cfc5772e..b062528438 100644
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -123,7 +123,7 @@ def _llama_model_forward(
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
     input_lens = attention_mask.cumsum(-1)[:, -1]
-
+    lens_list = input_lens.tolist()
     for idx, decoder_layer in enumerate(self.layers):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -137,6 +137,7 @@ def _llama_model_forward(
             use_cache=use_cache,
             position_embeddings=position_embeddings,
             input_lens=input_lens.int(),
+            lens_list=lens_list,
         )
 
         hidden_states = layer_outputs[0]
@@ -210,6 +211,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         input_lens: Optional[torch.Tensor] = None,
+        lens_list: Optional[List] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if past_key_value is None and kwargs.get("layer_past", None) is not None:
@@ -227,15 +229,13 @@ def forward(
 
         if past_key_value is not None:
             key_cache, value_cache = past_key_value.update(
-                key, value, self.layer_idx, attention_mask, position_ids, input_lens
+                key, value, self.layer_idx, attention_mask, position_ids, lens_list
             )
 
         attn_output = torch.empty_like(query)
         if past_len == 0:
             # prefill, remove padding
-            seq_len_tensor = torch.cat(
-                (torch.tensor([0], device=input_lens.device, dtype=torch.int), input_lens.cumsum(-1).int())
-            )
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,