
finetune error #13

Open
MarsMeng1994 opened this issue Apr 30, 2024 · 2 comments


@MarsMeng1994

```
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.06s/it]
Loading data...
Num of training samples: 5405 5405
Formatting inputs...Skip in lazy mode
/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/accelerate/accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches']). Please pass an `accelerate.DataLoaderConfiguration` instead:
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)
  warnings.warn(
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  0%|          | 0/1600 [00:00<?, ?it/s]
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
Traceback (most recent call last):
  File "train_chunkllama_16k.py", line 292, in <module>
    train()
  File "train_chunkllama_16k.py", line 286, in train
    trainer.train()
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/trainer.py", line 1869, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/trainer.py", line 2772, in training_step
    loss = self.compute_loss(model, inputs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/trainer.py", line 2795, in compute_loss
    outputs = model(**inputs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/accelerate/utils/operations.py", line 825, in forward
    return model_forward(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/accelerate/utils/operations.py", line 813, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1060, in forward
    layer_outputs = self._gradient_checkpointing_func(
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 451, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 230, in forward
    outputs = run_function(*args)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/cpfs01/shared/public/msm/workspace/software/miniconda3/envs/tmp/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 704, in forward
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
ValueError: too many values to unpack (expected 2)
```

But I think my environment is set up correctly, because I have successfully run the needle test.

My environment:

```
torch                          2.1.2
torchmetrics                   1.3.0.post0
transformers                   4.37.2
transformers-stream-generator  0.0.4
flash-attn                     2.5.6
```

@ChenxinAn-fdu
Contributor

ChenxinAn-fdu commented Apr 30, 2024

Please add `attn_implementation="flash_attention_2"` when loading the model (line 265 of the training script).
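
For reference, a minimal sketch of that change (the checkpoint path and dtype below are placeholder assumptions, not this repo's exact values):

```python
import torch
from transformers import AutoModelForCausalLM

# Load the model with the FlashAttention-2 backend so that the
# LlamaFlashAttention2 class (the one ChunkLlama patches) is instantiated.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder checkpoint
    torch_dtype=torch.bfloat16,               # flash-attn 2 requires fp16/bf16
    attn_implementation="flash_attention_2",
)
```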

@ChenxinAn-fdu
Contributor

This error is caused by LlamaFlashAttention2.forward not being correctly replaced by the new forward function.
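
Put differently, without `attn_implementation="flash_attention_2"` the model instantiates the stock attention class, whose unpatched forward still calls `self.rotary_emb(value_states, seq_len=kv_seq_len)` and then, presumably, fails to unpack what the replaced rotary embedding returns. A rough sketch of how the patch is meant to take hold (the replacement function name below is illustrative, not the repo's actual symbol):

```python
import transformers.models.llama.modeling_llama as llama_mod

def chunkllama_flash_attn_forward(self, hidden_states, *args, **kwargs):
    """Stand-in for ChunkLlama's chunked-attention forward (illustrative)."""
    ...

# The patch only affects the class the model actually uses:
# LlamaFlashAttention2 is selected only when the model is loaded with
# attn_implementation="flash_attention_2", so apply the patch before
# from_pretrained() and load the model with that flag.
llama_mod.LlamaFlashAttention2.forward = chunkllama_flash_attn_forward
```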
