
An error occurred when I used flash_decoding_chunkllama in run_chunkllama_100k.py #17

Open
smilelite opened this issue May 31, 2024 · 5 comments

@smilelite

I enabled flash decoding in run_chunkllama_100k.py by replacing the import

from chunkllama_attn_replace import replace_with_chunkllama

with

from flash_decoding_chunkllama import replace_with_chunkllama

The model is Llama2-70B or Llama3-70B, with
transformers==4.40.1
torch==2.2.1

The error is:
File "/mnt/ChunkLlama/flash_decoding_chunkllama.py", line 327, in forward
key_cache[:, kv_seq_len - key_states.shape[-2]:kv_seq_len, :, :] = key_states.transpose(1, 2)
RuntimeError: The expanded size of the tensor (1024) must match the existing size (128) at non-singleton dimension 3. Target sizes: [1, 26141, 8, 1024]. Tensor sizes: [26141, 8, 128]

By the way, to get it running I also added "**kwargs" to the signature of LlamaModel_forward and replaced the check "if self._attn_implementation:" with "if self.config._attn_implementation == "flash_attention_2":".
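The changes look roughly like this (just a sketch; the surrounding signature is assumed to match LlamaModel.forward in transformers 4.40, so argument names may differ):

```python
# Sketch of the two workarounds described above, applied to LlamaModel_forward in
# flash_decoding_chunkllama.py. The signature is assumed to mirror transformers 4.40.
def LlamaModel_forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    **kwargs,  # added: accept extra arguments (e.g. cache_position) instead of raising a TypeError
):
    ...
    # before: if self._attn_implementation:
    if self.config._attn_implementation == "flash_attention_2":  # read the flag from the config object
        ...
```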

@ChenxinAn-fdu
Contributor

ChenxinAn-fdu commented Jun 3, 2024

Hi! Does it work for Llama2-7B and Llama3-8B in your environment?

@smilelite
Author

When I use replace_with_chunkllama from chunkllama_attn_replace.py with Llama2-7B, Llama3-8B, Llama2-70B, or Llama3-70B, it runs normally. However, when I use replace_with_chunkllama from flash_decoding_chunkllama, it doesn't work.

@smilelite
Author

Llama2-7B is OK. The error is caused by grouped-query attention (GQA): the head dim computed in flash_decoding_chunkllama.py is inconsistent with the head_dim used in modeling_llama.py, which is why the cache slice expects 1024 in the last dimension while key_states has 128.
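For concreteness, here is the arithmetic behind the mismatch (my own numbers, assuming the standard Llama2-70B config; this is not the repository's code):

```python
# Standard Llama2-70B config values; only an illustration of the error, not repository code.
hidden_size = 8192
num_attention_heads = 64     # query heads
num_key_value_heads = 8      # KV heads (GQA)

head_dim = hidden_size // num_attention_heads        # 128 -> the head_dim used in modeling_llama.py
wrong_head_dim = hidden_size // num_key_value_heads  # 1024 -> matches "expanded size of the tensor (1024)"

# key_states has shape [bsz, num_key_value_heads, seq_len, head_dim] = [1, 8, 26141, 128],
# so a KV cache allocated with wrong_head_dim as its last dimension cannot accept it.
```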

@ChenxinAn-fdu
Contributor

Thank you so much for letting me know! I will update the code to support GQA😊

@ChenxinAn-fdu
Contributor

So sorry for the late response; I was very busy over the past two weeks.
The code now works with Llama3. Please try it!
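For reference, a minimal way to try it (a sketch only; the replace_with_chunkllama call and its pretraining_length argument are assumptions based on the repo's usage examples, and the model path is illustrative):

```python
# Minimal sketch: patch attention with the flash-decoding ChunkLlama replacement before
# loading the model. The pretraining_length value and model name are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer
from flash_decoding_chunkllama import replace_with_chunkllama

replace_with_chunkllama(pretraining_length=8192)  # call before from_pretrained; 8192 = Llama3 context

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
```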
