From f171b3938e95a5f99f7514ea099c8486c7edcfe2 Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 16 Feb 2024 00:20:46 +0700
Subject: [PATCH 1/2] update_to_same_lanta

---
 .../llama_thai_tokenizer/constants.py                |  7 +++++++
 .../tokenizers/spm_trainer.py                        | 13 ++++++++-----
 src/model/pyproject.toml                             |  1 +
 .../scripts/llama_thai_tokenizer/merge_tokenizer.py  | 12 ++++++------
 4 files changed, 22 insertions(+), 11 deletions(-)
 create mode 100644 src/model/openthaigpt_pretraining_model/llama_thai_tokenizer/constants.py

diff --git a/src/model/openthaigpt_pretraining_model/llama_thai_tokenizer/constants.py b/src/model/openthaigpt_pretraining_model/llama_thai_tokenizer/constants.py
new file mode 100644
index 0000000..1961444
--- /dev/null
+++ b/src/model/openthaigpt_pretraining_model/llama_thai_tokenizer/constants.py
@@ -0,0 +1,7 @@
+import os
+
+
+FILE_DIR = os.path.dirname(__file__)
+LLAMA_TOKENIZER_DIR = "decapoda-research/llama-7b-hf"
+THAI_SP_MODEL_DIR = f"{FILE_DIR}/thai_tokenizer/trained_bpe.model"
+OUTPUT_HF_DIR = f"{FILE_DIR}/merged_tokenizer_hf"
diff --git a/src/model/openthaigpt_pretraining_model/tokenizers/spm_trainer.py b/src/model/openthaigpt_pretraining_model/tokenizers/spm_trainer.py
index 6c2e89e..4178d26 100644
--- a/src/model/openthaigpt_pretraining_model/tokenizers/spm_trainer.py
+++ b/src/model/openthaigpt_pretraining_model/tokenizers/spm_trainer.py
@@ -12,6 +12,7 @@
 EOS_TOKEN = "</s>"
 BOS_TOKEN = "<s>"
 UNK_TOKEN = "<unk>"
+USER_DEFINED_SYMBOLS = []
 
 SPM_MODE = "spm"
 BPE_MODE = "bpe"
@@ -65,6 +66,8 @@ def load_local_dataset(data_type, local_path):
         streaming=True,
     )
 
+    # text_dataset = Dataset.from_dict(text_dataset[:int(len(text_dataset) * .5)])
+
     return text_dataset
 
 
@@ -73,7 +76,7 @@ def train_tokenizer(
     vocab_size: int,
     num_docs: Optional[Union[str, int]] = None,
     num_proc: Optional[int] = os.cpu_count(),
-    streaming: bool = True,
+    is_slurm: bool = False,
     load_dataset_path: str = "oscar",
     load_dataset_name: str = "unshuffled_deduplicated_th",
     load_dataset_local_path: Optional[str] = None,
@@ -89,7 +92,7 @@ def train_tokenizer(
         vocab_size (int): The size of the vocabulary to use when training the tokenizer.
         num_docs (int, optional): The number of documents to use from the input dataset.
         num_proc (int, optional): The number of CPU cores to use when training the tokenizer. Defaults to the number of available CPU cores.
-        streaming (bool, optional): Whether the code is running on a Slurm cluster. Defaults to False.
+        is_slurm (bool, optional): Whether the code is running on a Slurm cluster. Defaults to False.
         load_dataset_path (str, optional): The name of the Hugging Face dataset to load. Defaults to "oscar".
         load_dataset_name (str, optional): The name of the dataset split to use. Defaults to "unshuffled_deduplicated_th".
        load_dataset_local_path (str, optional): The path to a local directory containing the input data. If specified, the Hugging Face dataset is not used. Defaults to None.
@@ -103,12 +106,12 @@ def train_tokenizer(
             KeyError(f"mode must be {SPM_MODE} or {BPE_MODE}")
 
     if load_dataset_local_path is None:
-        if streaming:
+        if not is_slurm:
             text_dataset = load_dataset(
                 path=load_dataset_path,
                 name=load_dataset_name,
                 split="train",
-                streaming=streaming,
+                streaming=not is_slurm,
             )
 
             num_docs = len(text_dataset) if num_docs is None else num_docs
@@ -149,7 +152,7 @@ def train_tokenizer(
         ),
         model_prefix=output_path + "/spm_tokenizer",
         vocab_size=vocab_size,
-        user_defined_symbols=[],
+        user_defined_symbols=USER_DEFINED_SYMBOLS,
         num_threads=num_proc,
         train_extremely_large_corpus=large_corpus,
         model_type=mode,
diff --git a/src/model/pyproject.toml b/src/model/pyproject.toml
index 72cb775..3052d51 100644
--- a/src/model/pyproject.toml
+++ b/src/model/pyproject.toml
@@ -31,4 +31,5 @@ dependencies = [
     "peft>=0.3.0",
     "scipy<2.0.0",
     "tensorboard==2.*",
+    "nlpo3>=1.3.0",
 ]
\ No newline at end of file
diff --git a/src/model/scripts/llama_thai_tokenizer/merge_tokenizer.py b/src/model/scripts/llama_thai_tokenizer/merge_tokenizer.py
index ab63dce..d5aa8e5 100644
--- a/src/model/scripts/llama_thai_tokenizer/merge_tokenizer.py
+++ b/src/model/scripts/llama_thai_tokenizer/merge_tokenizer.py
@@ -25,10 +25,10 @@
         help="path to llama tokenizer",
     )
     parser.add_argument(
-        "--thai_sp_path",
+        "--sp_path",
         type=str,
         default=THAI_SP_MODEL_DIR,
-        help="path to Thai tokenizer",
+        help="path to tokenizer to merge",
     )
     parser.add_argument(
         "--output_path",
@@ -38,16 +38,16 @@
     )
     args = parser.parse_args()
 
-    # call merge function
-    tokenizer = merge(args.llama_path, args.thai_sp_path, get_spm_tokenizer=True)
+
+    tokenizer = merge(args.llama_path, args.sp_path, get_spm_tokenizer=True)
     os.makedirs(args.output_path, exist_ok=True)
     with open(args.output_path + "/spm_tokenizer.model", "wb") as f:
         f.write(tokenizer.SerializeToString())
     tokenizer = LlamaTokenizer(vocab_file=args.output_path + "/spm_tokenizer.model")
-    # change special tokens
+
    tokenizer.eos_token = EOS_TOKEN
    tokenizer.bos_token = BOS_TOKEN
    tokenizer.unk_token = UNK_TOKEN
-    # save model
+
    tokenizer.save_pretrained(args.output_path)

From 3094982060b4cf53b49f1c9742e510cbcd207f6b Mon Sep 17 00:00:00 2001
From: unknown
Date: Fri, 16 Feb 2024 00:34:30 +0700
Subject: [PATCH 2/2] update_hf_trainer

---
 src/model/scripts/hf_trainer/train.py | 37 ++++++++++++---------------
 1 file changed, 17 insertions(+), 20 deletions(-)

diff --git a/src/model/scripts/hf_trainer/train.py b/src/model/scripts/hf_trainer/train.py
index 288368c..cf46b47 100644
--- a/src/model/scripts/hf_trainer/train.py
+++ b/src/model/scripts/hf_trainer/train.py
@@ -7,7 +7,6 @@
 from datasets import load_from_disk
 from torch.utils.data import IterableDataset
 
-
 import os
 import random
 
@@ -21,10 +20,10 @@ class ModelArguments:
 
 @dataclass
 class DataArguments:
-    data_path: List[str] = field(
+    data_path: Optional[List[str]] = field(
         default_factory=list, metadata={"help": "Path to the tokenized data."}
     )
-    data_weights: List[float] = field(default_factory=list)
+    data_weights: Optional[List[float]] = field(default_factory=list)
     train_split: Optional[str] = field(default="train")
     eval_split: Optional[str] = field(default="eval")
 
@@ -42,18 +41,6 @@ class TrainingArguments(transformers.TrainingArguments):
     )
 
 
-class DataCollatorForSupervisedDataset(object):
-    """Collate examples for supervised fine-tuning."""
-
-    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids = [instance["input_ids"] for instance in instances]
-        input_ids = torch.tensor(input_ids)  # type: ignore
-        return {
-            "input_ids": input_ids,  # type: ignore
-            "labels": input_ids,  # type: ignore
-        }
-
-
 class CombinedDataset(IterableDataset):
     def __init__(self, datasets, seed, weights=None):
         self._seed = seed
@@ -89,6 +76,18 @@ def __next__(self):
         return next(dataset)
 
 
+class DataCollatorForSupervisedDataset(object):
+    """Collate examples for supervised fine-tuning."""
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids = [instance["input_ids"] for instance in instances]
+        input_ids = torch.tensor(input_ids)  # type: ignore
+        return {
+            "input_ids": input_ids,  # type: ignore
+            "labels": input_ids,  # type: ignore
+        }
+
+
 def load_dataset(paths, weights, split, seed=42):
     datasets = []
     for path in paths:
@@ -120,7 +119,7 @@ def train():
     )
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    model = transformers.LlamaForCausalLM.from_pretrained(
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=training_args.cache_dir,
     )
@@ -128,7 +127,7 @@ def train():
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path
 
-    tokenizer = transformers.LlamaTokenizer.from_pretrained(
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
         model_args.tokenizer_name_or_path,
         cache_dir=training_args.cache_dir,
         model_max_length=training_args.model_max_length,
         padding_side="right",
         use_fast=False,
     )
 
-    # if tokenizer is not None and model.vocab_size != len(tokenizer):
-    # model.resize_token_embeddings(len(tokenizer))
-
     data_module = make_supervised_data_module(
         data_args=data_args, seed=training_args.data_seed
     )
+
     trainer = Trainer(
         model=model, tokenizer=tokenizer, args=training_args, **data_module
     )
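
Aside (illustration, not part of either patch): the second patch keeps CombinedDataset as the piece that interleaves several tokenized datasets, picking the source of each example at random in proportion to data_weights. Its body is only partially visible in the hunks above, so the sketch below is a self-contained approximation of that weighted-mixing idea, assuming every source yields {"input_ids": [...]} records; mix_datasets and the toy sources are hypothetical names, not code from train.py.

import random
from typing import Dict, Iterator, List, Optional, Sequence


def mix_datasets(
    sources: Sequence[Iterator[Dict]],
    weights: Optional[List[float]] = None,
    seed: int = 42,
) -> Iterator[Dict]:
    """Yield examples by picking a source at random per step, in proportion
    to `weights` (uniform when None), stopping once a chosen source runs dry."""
    rng = random.Random(seed)
    if weights is None:
        weights = [1.0 / len(sources)] * len(sources)
    iterators = [iter(s) for s in sources]
    while True:
        (it,) = rng.choices(iterators, weights=weights, k=1)
        try:
            yield next(it)
        except StopIteration:
            return


# Toy stand-ins for tokenized datasets that yield {"input_ids": [...]} records.
src_a = ({"input_ids": [1] * 8} for _ in range(1000))
src_b = ({"input_ids": [2] * 8} for _ in range(1000))

mixed = mix_datasets([src_a, src_b], weights=[0.75, 0.25])
print([next(mixed)["input_ids"][0] for _ in range(12)])  # mostly 1s, some 2s

Sampling a source per example, rather than concatenating the datasets up front, is what lets data_weights control the mixture ratio without materializing a merged dataset.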