flash_attention_2 2.7.2.post1 seems to crash when using torch.compile and DataCollatorWithFlattening #35588
Labels: torch.compile
System Info
transformers version: 4.47.1

Who can help?
@ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
1. Update to the latest flash attention version (2.7.2 at the time of writing); this should be torch.compile compatible, as described in https://github.com/Dao-AILab/flash-attention.
2. Load a model with FA2 (tested with OPT and Qwen).
3. Use Trainer with DataCollatorWithFlattening and train.

This causes a crash with the following stacktrace:
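For context on what the collator does: `DataCollatorWithFlattening` concatenates all examples in a batch into a single row and emits `position_ids` that restart at 0 at each example boundary, which is what the FA2 varlen path relies on. A minimal pure-Python sketch (a hypothetical simplified re-implementation for illustration, not the transformers source):

```python
def flatten_batch(examples):
    """Concatenate a batch of tokenized examples into one flattened row.

    Simplified sketch of DataCollatorWithFlattening-style behavior:
    position_ids restart at 0 for each example so sequence boundaries
    can be recovered later; the first label of each example is masked
    (-100) so the loss does not leak across example boundaries.
    """
    input_ids, position_ids, labels = [], [], []
    for ex in examples:
        ids = ex["input_ids"]
        input_ids.extend(ids)
        position_ids.extend(range(len(ids)))  # restart at 0 per example
        labels.extend([-100] + ids[1:])
    return {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}

batch = flatten_batch([{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}])
print(batch["input_ids"])     # [5, 6, 7, 8, 9]
print(batch["position_ids"])  # [0, 1, 2, 0, 1]
```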
The code works fine when not using compile. The code doesn't crash when using compile but not using DataCollatorWithFlattening.
When using compile and not using DataCollatorWithFlattening, I get the following graph break with Qwen2.5:

Expected behavior
The training shouldn't crash.
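One plausible source of the incompatibility (an assumption on my part, not confirmed by a stacktrace): the varlen FA2 path has to recover cumulative sequence lengths (`cu_seqlens`) from the flattened `position_ids`, and locating the positions where `position_ids == 0` is a data-dependent operation, which is exactly the kind of step that tends to cause torch.compile graph breaks. A pure-Python illustration of that boundary recovery (the real code uses torch ops):

```python
def cu_seqlens_from_position_ids(position_ids):
    """Recover cumulative sequence lengths from flattened position_ids.

    A sequence starts wherever the position counter resets to 0; the
    returned boundaries are the cu_seqlens-style offsets a varlen
    attention kernel would consume. Data-dependent: the output shape
    depends on the values, not just the input shape.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    return starts + [len(position_ids)]

print(cu_seqlens_from_position_ids([0, 1, 2, 0, 1]))  # [0, 3, 5]
```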