Keras newbie: Why doesn't loss decrease when fine-tuning LLaMA 3.2 3B on TPU v3-8 with Keras? #2042
Comments
Thanks for posting the issue @aeltorio!

```python
llama_model.fit(
    multi_turn_dataset['train'].with_format("tf")['text'],
    validation_data=multi_turn_dataset['validation'].with_format("tf")['text'],
    epochs=epochs,
    verbose="auto"
)
```

In general, models don't accept raw text as input; they need proper tokenization and preprocessing to work. So tokenization and properly shifted targets are key for the task you're trying to take up (or for any LM task in general), and that might also explain the constant loss. I'm linking these amazing HF resources for you to play around with; they should help you get the job done!
---
@harshaljanjani In fact, in my code I already tokenized the input:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{source_model}", chat_template=llama31_template)

def formatting_prompts_func(messages):
    convos = messages["conversation"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    tokenized = [tokenizer.apply_chat_template(convo, tokenize=True, add_generation_prompt=False) for convo in convos]
    return {"text": texts, "tokenized": tokenized}

multi_turn_dataset = multi_turn_dataset.map(
    formatting_prompts_func,
    batched=True,
)
```

But when I tried to use it with:

```python
llama_model.fit(
    multi_turn_dataset['train'].with_format("tf")['tokenized'],
    validation_data=multi_turn_dataset['validation'].with_format("tf")['tokenized'],
    epochs=epochs,
    verbose="auto",
)
```

I get an error about the conversion:
---
TPU probably doesn't support RaggedTensor.
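For reference, a minimal sketch of one way to work around that on the TensorFlow side (not tested on a TPU; `max_length` and `pad_id` are placeholders): a RaggedTensor can be densified to a fixed shape before it ever reaches the accelerator:

```python
import tensorflow as tf

max_length = 512   # assumed fixed sequence length
pad_id = 0         # assumed pad token id; use tokenizer.pad_token_id in practice

# variable-length token lists -> RaggedTensor
ragged = tf.ragged.constant([[1, 2, 3], [4, 5], [6, 7, 8, 9]])

# pad/truncate every row to the same length so XLA sees a static shape
dense = ragged.to_tensor(default_value=pad_id, shape=[None, max_length])
print(dense.shape)  # (3, 512)
```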
---

@innat, thank you for your help.

---
Hello @aeltorio, I believe @innat might be on the right track too. Main fixes that you could try in the implementation:

a) Preprocessing changes:

```python
# BEFORE (problematic):
llama_model.fit(multi_turn_dataset['train'].with_format("tf")['tokenized'], ...)

# AFTER (in the direction of fixing the issue):
def process_sequence(tokens):
    # convert ragged to fixed-size by padding/truncating to `max_length`
    input_ids = np.array(tokens[:max_length])
    if len(input_ids) < max_length:
        padding = np.zeros(max_length - len(input_ids), dtype=np.int32)
        input_ids = np.concatenate([input_ids, padding])
    return input_ids
```

b) Add attention masking (the attention mask ensures that the model ignores padding during computation):

```python
# try adding attention mask generation to handle padding properly
attention_mask = (input_ids != tokenizer.pad_token_id).astype(np.int32)
```

c) Create proper shifted labels for causal LM:

```python
# create labels by shifting the input sequence left by one position
labels = np.roll(input_ids, -1)
labels[-1] = tokenizer.pad_token_id
```

d) Modify the loss function to respect padding:

```python
def compute_loss(labels, logits, attention_mask):
    loss = keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )
    mask = tf.cast(attention_mask, loss.dtype)
    loss = loss * mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)
```

I could be missing something obviously, but these are the changes you'd need to make in the direction of fixing the issue; I haven't really tried to reproduce the outputs due to lack of HF access to the models.

Note regarding TPU and RaggedTensor support:
You could obviously improve upon plain padding to a fixed length (for example, by bucketing by sequence length, as sketched below), but this is to give you an idea that the fixes it would take to solve this issue must convert these variable-length sequences (RaggedTensors) into fixed-shape, padded tensors.

Resources for RaggedTensors:
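As a rough sketch of the bucketing idea mentioned in the note above (the toy sequences, `pad_id`, and bucket boundaries are placeholders, not values from the notebook), `tf.data` can group sequences of similar length and pad each batch only as much as needed:

```python
import tensorflow as tf

pad_id = 0  # assumed pad token id; use tokenizer.pad_token_id in practice

# toy dataset of variable-length token sequences
sequences = [[1, 2, 3], [4, 5, 6, 7, 8], [9, 10], [11, 12, 13, 14, 15, 16, 17]]
ds = tf.data.Dataset.from_generator(
    lambda: iter(sequences),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32),
)

# group sequences of similar length so each batch is padded as little as possible
ds = ds.bucket_by_sequence_length(
    element_length_func=lambda seq: tf.shape(seq)[0],
    bucket_boundaries=[4, 8],      # buckets: len < 4, 4 <= len < 8, len >= 8
    bucket_batch_sizes=[2, 2, 2],  # one batch size per bucket (len(boundaries) + 1)
    padding_values=tf.constant(pad_id, dtype=tf.int32),
    pad_to_bucket_boundary=False,
)

for batch in ds:
    print(batch.shape)  # each batch is a dense int32 tensor, padded within its bucket
```

On TPU/XLA you would probably want `pad_to_bucket_boundary=True` (and `drop_remainder=True`) so every batch in a bucket has a static shape, since XLA recompiles for each new shape it sees.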
---
In JAX, I think ragged tensors are not supported either.
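For reference, a small sketch of what that means in practice (`pad_id`, `max_length`, and the toy token lists are placeholders): JAX arrays must be rectangular, so ragged token lists need to be padded to a fixed length before conversion:

```python
import numpy as np
import jax.numpy as jnp

token_lists = [[1, 2, 3], [4, 5]]   # ragged: rows of different lengths
# jnp.asarray(token_lists) would raise an error: JAX arrays must be rectangular

max_length = 4
pad_id = 0                          # assumed pad token id
padded = np.array(
    [row[:max_length] + [pad_id] * (max_length - len(row)) for row in token_lists],
    dtype=np.int32,
)
x = jnp.asarray(padded)             # shape (2, 4): dense, fixed-shape, TPU-friendly
print(x.shape)
```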
---

I'm opening this issue after asking the question in the Discussions; @innat wrote that it is better to open a ticket here, so this is an issue!
I'm trying to fine-tune Meta's LLaMA 3.2 3B Instruct model using Keras on TPU v3-8 (Kaggle). While the code runs without errors, the loss remains constant during training.
Context:
Technical Setup:
Here's the relevant code, copied from the full notebook:
Issue:
The loss stays constant at ~2.9 throughout training with no signs of learning.
Questions:
Thank you for your help!