From 7f984ee21db98c5fdfac91d3bea32f3aec59c750 Mon Sep 17 00:00:00 2001 From: khai-meetkai <117131523+khai-meetkai@users.noreply.github.com> Date: Mon, 9 Dec 2024 13:52:24 +0700 Subject: [PATCH] change padding strategy to reduce time & add the option to ignore_cached in training (#297) --- functionary/train/custom_datasets.py | 45 ++++++++++++++++++++++++---- functionary/train/train.py | 4 +++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/functionary/train/custom_datasets.py b/functionary/train/custom_datasets.py index a820f5c..294cabd 100644 --- a/functionary/train/custom_datasets.py +++ b/functionary/train/custom_datasets.py @@ -147,7 +147,7 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type): pack_length = data_args.pack_length if data_args.pack_length > 0 else None data_class_args = { - "ignore_cached": False, + "ignore_cached": data_args.ignore_cached, "keep_assistant_prefix": False, } if data_args.packing: @@ -197,6 +197,9 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type): torch.distributed.barrier() # allow other ranks to execute # All ranks will read the processed data from cached_path created by rank 0 + data_class_args["ignore_cached"] = ( + False # other processes will read from cached results + ) ds = data_class(None, tokenizer, **data_class_args) if local_rank == 0: if data_args.packing: @@ -824,7 +827,7 @@ def __init__( self.data_points = map_raw_data_to_input_dic( raw_data=raw_data, tokenizer=tokenizer, - padding="max_length", + padding="do_not_pad", batch_size=batch_size, keep_assistant_prefix=keep_assistant_prefix, ) @@ -834,9 +837,41 @@ def __init__( def __getitem__(self, i) -> Dict[str, torch.Tensor]: dp = self.data_points[i] - result = {} - for key in dp: - result[key] = torch.tensor(dp[key]) + input_ids, label_ids, attention_mask = ( + dp["input_ids"], + dp["labels"], + dp["attention_mask"], + ) + # asert all attention_mask == 1 + assert sum(attention_mask) == len(attention_mask) + # padding to max_length + pad_length = self.tokenizer.model_max_length - len(dp["input_ids"]) + if pad_length > 0: + if self.tokenizer.padding_side == "right": + input_ids = input_ids + [ + self.tokenizer.pad_token_id for _ in range(pad_length) + ] + label_ids = label_ids + [-100 for _ in range(pad_length)] + attention_mask = attention_mask + [0 for _ in range(pad_length)] + else: + input_ids = [ + self.tokenizer.pad_token_id for _ in range(pad_length) + ] + input_ids + label_ids = [-100 for _ in range(pad_length)] + label_ids + attention_mask = [0 for _ in range(pad_length)] + attention_mask + + assert ( + len(input_ids) + == len(label_ids) + == len(attention_mask) + == self.tokenizer.model_max_length + ) + + result = { + "input_ids": torch.tensor(input_ids), + "labels": torch.tensor(label_ids), + "attention_mask": torch.tensor(attention_mask), + } return result diff --git a/functionary/train/train.py b/functionary/train/train.py index 0fd319f..ee95e20 100644 --- a/functionary/train/train.py +++ b/functionary/train/train.py @@ -100,6 +100,10 @@ class DataArguments: default=False, metadata={"help": "Whether to use lazy loading for the dataset or not"}, ) + ignore_cached: bool = field( + default=False, + metadata={"help": "Whether to ignore cached tokenized data or not"}, + ) @dataclass