
Keras newbie: Why doesn't loss decrease when fine-tuning LLaMA 3.2 3B on TPU v3-8 with Keras? #2042

Open
aeltorio opened this issue Jan 14, 2025 · 7 comments

Comments

@aeltorio

I'm opening this issue after asking the question in the Discussions; @innat wrote that it is better to open a ticket here… So here is the issue!

I'm trying to fine-tune Meta's LLaMA 3.2 3B Instruct model using Keras on TPU v3-8 (Kaggle). While the code runs without errors, the loss remains constant during training.

Context:

  • My wife, a history and geography teacher, writes short evaluations for her students every three months. Each evaluation includes three indicators: attitude in class (0-10), personal work (0-10), and participation in class (0-10), along with the trimester (1, 2, or 3) and the mean score (0-20). I wrapped her evaluations in a HF dataset (in French).
  • I successfully fine-tuned the same model with the same dataset using HuggingFace's SFTTrainer and Unsloth GPU optimizations (working model here, along with the notebook used for fine-tuning).
  • The model works well, and a demo runs on HF Spaces at this link. It is very slow on CPU, but if you have access to a ZeroGPU, just duplicate the Space and it is really fast.
  • The current notebook, designed to run on Kaggle with a Google TPU v3-8, is available here.

Technical Setup:

  • TPU v3-8 on Kaggle
  • tensorflow==2.16.2
  • keras==3.0.5
  • Base model: meta-llama/Llama-3.2-3B-Instruct

Here's the relevant code, copied from the full notebook:

# TPU Setup
import keras
import keras_hub

devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=devices)
layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)  # default layout_map
distrib = keras.distribution.ModelParallel(layout_map=layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(distrib)

# Model initialization
llama_model = keras_hub.models.Llama3CausalLM.from_preset("hf://meta-llama/Llama-3.2-3B-Instruct")

llama_model.backbone.enable_lora(rank=8)

llama_model.preprocessor.sequence_length = 256
optimizer = keras.optimizers.AdamW(
    learning_rate=2e-4, 
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])


# dataset adaptation
# sample dataset content
# multi_turn_dataset is an instance of HF datasets.DatasetDict
# multi_turn_dataset['train'][42]['text']
# "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 July 2024\n\nVous êtes une IA assistant les enseignants d'histoire-géographie en rédigeant à leur place une appréciation personnalisée pour leur élève en fonction de ses performances. Votre appréciation doit être en français formel et impersonnel. Votre appréciation doit être bienveillante, constructive, et aider l'élève à comprendre ses points forts et les axes d'amélioration. Votre appréciation doit comporter de 8 à 250 caractères. Votre appréciation ne doit jamais comporter les valeurs des notes. <|eot_id|><|start_header_id|>user<|end_header_id|>\n\nVeuillez rédiger une appréciation en moins de 250 caractères pour le premier trimestre pour cet élève qui a eu 14.0 de moyenne, j'ai évalué son comportement à 4.8/10, sa participation à 3.1/10 et son travail à 3.9/10. Les notes ne doivent pas apparaître dans l'appréciation.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nL'ensemble est correct mais vous pourrez certainement progresser au second trimestre en faisant davantage d'efforts de participation et dans le travail personnel. Attention également à ne pas se déconcentrer en cours.<|eot_id|>"

llama_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    #jit_compile="auto",
    #auto_scale_loss=True,
)

# Training
llama_model.fit(multi_turn_dataset['train'].with_format("tf")['text'], 
                validation_data=multi_turn_dataset['validation'].with_format("tf")['text'], 
                epochs=3,
                verbose="auto",)

Issue:
The loss stays constant at ~2.9 throughout training with no signs of learning.

Questions:

  1. Why isn't the model learning despite using similar parameters to my successful HuggingFace implementation?
  2. Are there specific considerations when fine-tuning LLMs with Keras on TPU?
  3. Is my loss function appropriate for this task?

Thank you for your help!

@harshaljanjani

Thanks for posting the issue @aeltorio!
The issue might be in how the raw text is directly passed into model.fit(). Looking at the training part of the code:

llama_model.fit(
multi_turn_dataset['train'].with_format("tf")['text'], 
validation_data=multi_turn_dataset['validation'].with_format("tf")['text'], 
epochs=epochs, 
verbose="auto"
)

In general, models don't accept raw text as input; they need proper tokenization and preprocessing to work. So tokenization and properly shifted targets are key for the task you're trying to take up (or for any LM task in general), and that might also explain the constant loss. I'm linking these amazing HF resources for you to play around with; they should help you get the job done!
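
As a rough illustration only (not the exact fix, and untested against the notebook), this is roughly what "tokenize, pad, and shift targets" looks like for a causal LM; tokenizer, max_length, and pad_id below are placeholders for whatever the pipeline actually uses:

import numpy as np

def make_lm_example(text, tokenizer, max_length=256, pad_id=0):
    # tokenize and keep one extra token so inputs and targets can be offset by one
    ids = tokenizer.encode(text)[: max_length + 1]
    ids = ids + [pad_id] * (max_length + 1 - len(ids))   # pad to a fixed length
    ids = np.asarray(ids, dtype=np.int32)
    x = ids[:-1]                                          # model inputs
    y = ids[1:]                                           # targets: inputs shifted left by one
    w = (y != pad_id).astype("float32")                   # weights so padding is ignored in the loss
    return x, y, w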

@aeltorio
Author

@harshaljanjani
Thanks for your help,
I tried using the tokenized version with this Kaggle Notebook…

In fact, in my code I had already tokenized the input into the tokenized column with:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{source_model}", chat_template=llama31_template)

def formatting_prompts_func(messages):
    convos = messages["conversation"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    tokenized = [tokenizer.apply_chat_template(convo, tokenize=True, add_generation_prompt=False) for convo in convos]
    return {"text": texts, "tokenized": tokenized}

multi_turn_dataset = multi_turn_dataset.map(
    formatting_prompts_func,
    batched=True,
)

But when I tried to use it with:

llama_model.fit(multi_turn_dataset['train'].with_format("tf")['tokenized'], 
                validation_data=multi_turn_dataset['validation'].with_format("tf")['tokenized'], 
                epochs=epochs,
                verbose="auto",)

I get an error about the conversion:

Failed to convert elements of tf.RaggedTensor(values=Tensor("RaggedFromVariant/RaggedTensorFromVariant:1", shape=(None,), dtype=int64), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(None,), dtype=int32)) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
    
    Arguments received by Llama3Tokenizer.call():
      • inputs=tf.RaggedTensor(values=Tensor("RaggedFromVariant/RaggedTensorFromVariant:1", shape=(None,), dtype=int64), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(None,), dtype=int32))
      • args=<class 'inspect._empty'>
      • training=None
      • kwargs=<class 'inspect._empty'>

@innat

innat commented Jan 17, 2025

TPUs probably don't support ragged tensors.

@aeltorio
Author

@innat, thank you for your help.
According to https://www.tensorflow.org/api_docs/python/tf/dtypes, both int32 and int64 are supported dtypes…
So I don't understand the message Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

@harshaljanjani

Hello @aeltorio, I believe @innat is on the right track.
I'll break down the key changes and explain the TPU/RaggedTensor situation as a few steps you might try. The core issue is the direct use of RaggedTensors from the HuggingFace tokenizer with the TPU. TPUs don't support RaggedTensors; they require fixed-size tensors (tensorflow/tensorflow#27170 (comment)). The error message you quoted is about supported dtypes, which isn't the issue at hand here. RaggedTensor support is also avoided in recent Keras versions (keras-team/keras#20290 (comment)), which is why it's largely left unsupported.
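
As a tiny sketch of that constraint (toy values, not your actual data), a ragged batch can be densified to a static width before it ever reaches the model:

import tensorflow as tf

max_length = 256  # a fixed sequence length, since XLA/TPU compilation needs static shapes

# toy stand-in for the variable-length `tokenized` column
ragged = tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int32)
dense = ragged.to_tensor(default_value=0, shape=[None, max_length])  # pad out to a fixed width
print(dense.shape)  # (2, 256)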

Main fixes that you could try in the implementation:

a) Preprocessing changes:

# BEFORE (problematic):
llama_model.fit(multi_turn_dataset['train'].with_format("tf")['tokenized'], ...)

# AFTER (in the direction of fixing the issue):
import numpy as np

max_length = 256  # fixed sequence length expected on TPU

def process_sequence(tokens):
    # convert ragged to fixed-size by padding/truncating to `max_length`
    input_ids = np.array(tokens[:max_length], dtype=np.int32)
    if len(input_ids) < max_length:
        padding = np.zeros(max_length - len(input_ids), dtype=np.int32)
        input_ids = np.concatenate([input_ids, padding])
    return input_ids

b) Add attention masking (the attention mask makes sure the model ignores padding during computation):

# try adding attention mask generation to handle padding properly
attention_mask = (input_ids != tokenizer.pad_token_id).astype(np.int32)

c) Create proper shifted labels for the causal LM:

# create labels by shifting the input sequence left by one position
labels = np.roll(input_ids, -1)
labels[-1] = tokenizer.pad_token_id

d) Modify the loss function to respect padding:

import tensorflow as tf

def compute_loss(labels, logits, attention_mask):
    # per-token cross-entropy, masked so padded positions don't contribute
    loss = keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True
    )
    mask = tf.cast(attention_mask, loss.dtype)
    loss = loss * mask
    return tf.reduce_sum(loss) / tf.reduce_sum(mask)

I could obviously be missing something, but these are the changes you'd need to make to move toward fixing the issue; I haven't tried to reproduce the outputs because I don't have HF access to the models.

Note regarding TPU and RaggedTensor support:

  • TPUs are designed for batch processing with consistent shapes and hence don't support RaggedTensors.
  • TPU hardware accelerators require predictable memory access patterns.
  • TPUs use the XLA compiler, which requires static shapes during compilation.

You could obviously improve on this with smarter padding to a fixed length or bucketing by sequence length, but the point is that any fix for this issue has to convert the variable-length sequences (RaggedTensors) into fixed-length tensors with padding and attention masks, which is the standard approach for training language models on TPU. A rough end-to-end sketch of what that could look like follows.
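
Putting the pieces together, here is an untested sketch (names like pad_id, max_length, and the feature dict keys are assumptions, not verified against your notebook) of turning the pre-tokenized HF split into a fixed-shape tf.data.Dataset. The (features, labels, sample_weight) tuples also mean the stock SparseCategoricalCrossentropy and weighted metrics would skip the padded positions:

import numpy as np
import tensorflow as tf

max_length = 256
pad_id = 0  # assumption: replace with your tokenizer's actual pad token id

def build_dataset(hf_split, batch_size=8):
    # pad/truncate every pre-tokenized row to max_length + 1 so all shapes are static
    rows = []
    for ids in hf_split["tokenized"]:
        ids = list(ids)[: max_length + 1]
        rows.append(ids + [pad_id] * (max_length + 1 - len(ids)))
    ids = np.asarray(rows, dtype=np.int32)

    token_ids = ids[:, :-1]
    labels = ids[:, 1:]                                    # targets shifted left by one
    sample_weight = (labels != pad_id).astype("float32")   # ignore padding in loss/metrics
    features = {
        "token_ids": token_ids,
        "padding_mask": (token_ids != pad_id),
    }
    return (
        tf.data.Dataset.from_tensor_slices((features, labels, sample_weight))
        .batch(batch_size, drop_remainder=True)            # fixed batch shape for TPU/XLA
        .prefetch(tf.data.AUTOTUNE)
    )

train_ds = build_dataset(multi_turn_dataset["train"])
val_ds = build_dataset(multi_turn_dataset["validation"])
# Assumption: with the built-in preprocessor detached (e.g. preprocessor=None in from_preset),
# the task model should accept this pre-tokenized dataset directly:
# llama_model.fit(train_ds, validation_data=val_ds, epochs=3)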

Resources for RaggedTensors:

@innat

innat commented Jan 18, 2025

In JAX, I think ragged tensors are not supported either.
