update_hf_trainer

boss-chanon committed Feb 15, 2024
1 parent f171b39 commit 3094982
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions src/model/scripts/hf_trainer/train.py
@@ -7,7 +7,6 @@
from datasets import load_from_disk
from torch.utils.data import IterableDataset


import os
import random

@@ -21,10 +20,10 @@ class ModelArguments:

@dataclass
class DataArguments:
-    data_path: List[str] = field(
+    data_path: Optional[List[str]] = field(
        default_factory=list, metadata={"help": "Path to the tokenized data."}
    )
-    data_weights: List[float] = field(default_factory=list)
+    data_weights: Optional[List[float]] = field(default_factory=list)
    train_split: Optional[str] = field(default="train")
    eval_split: Optional[str] = field(default="eval")
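
For context on the type change above: transformers.HfArgumentParser exposes list-typed dataclass fields as multi-value command-line flags (argparse nargs="+"), so data_path and data_weights accept several space-separated values. A minimal sketch, assuming a recent transformers release; it re-declares only the DataArguments fields shown in this diff, and the paths and weights are made-up examples, not values used by this repository:

from dataclasses import dataclass, field
from typing import List, Optional

import transformers


@dataclass
class DataArguments:
    data_path: Optional[List[str]] = field(default_factory=list)
    data_weights: Optional[List[float]] = field(default_factory=list)
    train_split: Optional[str] = field(default="train")
    eval_split: Optional[str] = field(default="eval")


parser = transformers.HfArgumentParser(DataArguments)
# Parse an explicit argument list instead of sys.argv, purely for illustration.
(data_args,) = parser.parse_args_into_dataclasses(
    args=["--data_path", "/data/set_a", "/data/set_b", "--data_weights", "0.7", "0.3"]
)
print(data_args.data_path)     # ['/data/set_a', '/data/set_b']
print(data_args.data_weights)  # [0.7, 0.3]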

@@ -42,18 +41,6 @@ class TrainingArguments(transformers.TrainingArguments):
    )


-class DataCollatorForSupervisedDataset(object):
-    """Collate examples for supervised fine-tuning."""
-
-    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids = [instance["input_ids"] for instance in instances]
-        input_ids = torch.tensor(input_ids)  # type: ignore
-        return {
-            "input_ids": input_ids,  # type: ignore
-            "labels": input_ids,  # type: ignore
-        }
-
-
class CombinedDataset(IterableDataset):
    def __init__(self, datasets, seed, weights=None):
        self._seed = seed
@@ -89,6 +76,18 @@ def __next__(self):
        return next(dataset)


+class DataCollatorForSupervisedDataset(object):
+    """Collate examples for supervised fine-tuning."""
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids = [instance["input_ids"] for instance in instances]
+        input_ids = torch.tensor(input_ids)  # type: ignore
+        return {
+            "input_ids": input_ids,  # type: ignore
+            "labels": input_ids,  # type: ignore
+        }
+
+
def load_dataset(paths, weights, split, seed=42):
    datasets = []
    for path in paths:
@@ -120,28 +119,26 @@ def train():
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

-    model = transformers.LlamaForCausalLM.from_pretrained(
+    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    if model_args.tokenizer_name_or_path is None:
        model_args.tokenizer_name_or_path = model_args.model_name_or_path

-    tokenizer = transformers.LlamaTokenizer.from_pretrained(
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.tokenizer_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )

    # if tokenizer is not None and model.vocab_size != len(tokenizer):
    #     model.resize_token_embeddings(len(tokenizer))

    data_module = make_supervised_data_module(
        data_args=data_args, seed=training_args.data_seed
    )

    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
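
Two notes on the changes above, with hedged illustrations rather than repo-verified behavior.

First, DataCollatorForSupervisedDataset is only moved below CombinedDataset; its behavior is unchanged. It stacks already-tokenized input_ids and returns the same tensor as labels, the usual causal-LM setup in which the model shifts labels internally when computing the loss. Note that torch.tensor() on a list of lists only works if every example has the same length, so the collator assumes fixed-length, pre-padded examples. A self-contained sketch that reproduces the collator (minus the type-ignore comments) and feeds it two made-up, equal-length examples:

from typing import Dict, Sequence

import torch


class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [instance["input_ids"] for instance in instances]
        input_ids = torch.tensor(input_ids)
        return {"input_ids": input_ids, "labels": input_ids}


collator = DataCollatorForSupervisedDataset()
batch = collator([{"input_ids": [1, 15, 42, 2]}, {"input_ids": [1, 99, 7, 2]}])
print(batch["input_ids"].shape)                          # torch.Size([2, 4])
print(torch.equal(batch["input_ids"], batch["labels"]))  # True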

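Second, replacing LlamaForCausalLM / LlamaTokenizer with AutoModelForCausalLM / AutoTokenizer means the script is no longer tied to Llama checkpoints: the Auto* factories read the checkpoint's config and instantiate whatever architecture it declares. A minimal sketch; "gpt2" is only a stand-in for any causal-LM checkpoint, and the fixed model_max_length replaces the training_args value used in the script:

import transformers

# Any local path or Hub ID with a causal-LM head works here, because the
# Auto* classes dispatch on the checkpoint's config rather than a fixed class.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "gpt2",
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
print(type(model).__name__)      # e.g. GPT2LMHeadModel
print(type(tokenizer).__name__)  # e.g. GPT2Tokenizer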