An error occurred when I used flash_decoding_chunkllama in run_chunkllama_100k.py #17
Comments
Hi! Does it work for Llama2-7B and Llama3-8B in your environment?
When I use replace_with_chunkllama from chunkllama_attn_replace.py with Llama2-7B, Llama3-8B, Llama2-70B, and Llama3-70B, it runs normally. However, when I use replace_with_chunkllama from flash_decoding_chunkllama.py, it doesn't work.
Llama2-7B is OK. The error was caused by the use of GQA, which results in an inconsistent head-dim calculation between flash_decoding_chunkllama.py and modeling_llama.py.
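To make the numbers concrete: for Llama2-70B (hidden_size 8192, 64 attention heads, 8 KV heads), the head dim used by modeling_llama.py is 128, while dividing the hidden size by the number of KV heads gives 1024, which matches the shapes in the traceback quoted below. A minimal sketch of that arithmetic, assuming the cache's head dim was derived from the KV-head count (this derivation is an inference from the error shapes, not confirmed repo code):

```python
# Illustration of the GQA head-dim mismatch for Llama2-70B (values from its config).
# This is a sketch of the arithmetic, not the actual ChunkLlama code.
hidden_size = 8192
num_attention_heads = 64
num_key_value_heads = 8                               # GQA: 8 KV heads shared by 64 query heads

head_dim = hidden_size // num_attention_heads         # 128 -> per-head dim used by modeling_llama.py
wrong_head_dim = hidden_size // num_key_value_heads   # 1024 -> consistent with the cache size in the error

# key_states arrive as (batch, num_key_value_heads, seq_len, head_dim), e.g. (1, 8, 26141, 128);
# a cache slice shaped (1, 26141, 8, 1024) therefore cannot accept them.
print(head_dim, wrong_head_dim)  # 128 1024
```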
Thank you so much for letting me know! I will update the code to support GQA 😊
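A minimal sketch of what a GQA-aware cache allocation could look like (the function name and signature are illustrative, not the actual fix; the (batch, seq_len, kv_heads, head_dim) layout follows the assignment shown in the traceback below):

```python
import torch

def allocate_kv_cache(config, batch_size, max_seq_len, dtype, device):
    """Hypothetical GQA-aware KV-cache allocation; not the actual ChunkLlama fix."""
    # Head dim is always hidden_size / num_attention_heads, even under GQA.
    head_dim = config.hidden_size // config.num_attention_heads  # 128 for Llama2-70B
    # Number of KV heads falls back to the number of attention heads for plain MHA models.
    num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)  # 8 for Llama2-70B
    shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
    key_cache = torch.empty(shape, dtype=dtype, device=device)
    value_cache = torch.empty(shape, dtype=dtype, device=device)
    return key_cache, value_cache
```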
So sorry for the late response. I was too busy in the past two weeks. |
I used flash decoding in run_chunkllama_100k.py like this (a fuller setup sketch follows the environment info below):
from chunkllama_attn_replace import replace_with_chunkllama
from flash_decoding_chunkllama import replace_with_chunkllama
The model is Llama2-70B or Llama3-70B.
transformers==4.40.1
torch==2.2.1
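For context, a typical way to wire up the replacement looks roughly like this (a sketch, assuming replace_with_chunkllama is called before from_pretrained; the pretraining_length argument and the flash-attention flag are illustrative, not confirmed against the script):

```python
# Sketch of the setup, not the exact run_chunkllama_100k.py script.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_decoding_chunkllama import replace_with_chunkllama

# Patch Llama attention before the model is built; the argument shown is illustrative.
replace_with_chunkllama(pretraining_length=4096)

model_name = "meta-llama/Llama-2-70b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```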
The error is:
File "/mnt/ChunkLlama/flash_decoding_chunkllama.py", line 327, in forward
key_cache[:, kv_seq_len - key_states.shape[-2]:kv_seq_len, :, :] = key_states.transpose(1, 2)
RuntimeError: The expanded size of the tensor (1024) must match the existing size (128) at non-singleton dimension 3. Target sizes: [1, 26141, 8, 1024]. Tensor sizes: [26141, 8, 128]
By the way, I added "**kwargs," to the LlamaModel_forward signature and replaced "if self._attn_implementation:" with "if self.config._attn_implementation == "flash_attention_2":".