Skip to content

Commit

Permalink
change padding strategy to reduce time & add the option to ignore_cac…
Browse files Browse the repository at this point in the history
…hed in training (#297)
  • Loading branch information
khai-meetkai authored Dec 9, 2024
1 parent dc43ea2 commit 7f984ee
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
45 changes: 40 additions & 5 deletions functionary/train/custom_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
pack_length = data_args.pack_length if data_args.pack_length > 0 else None

data_class_args = {
"ignore_cached": False,
"ignore_cached": data_args.ignore_cached,
"keep_assistant_prefix": False,
}
if data_args.packing:
Expand Down Expand Up @@ -197,6 +197,9 @@ def read_dataset(model_path, data_args, training_args, tokenizer, ds_type):
torch.distributed.barrier() # allow other ranks to execute

# All ranks will read the processed data from cached_path created by rank 0
data_class_args["ignore_cached"] = (
False # other processes will read from cached results
)
ds = data_class(None, tokenizer, **data_class_args)
if local_rank == 0:
if data_args.packing:
Expand Down Expand Up @@ -824,7 +827,7 @@ def __init__(
self.data_points = map_raw_data_to_input_dic(
raw_data=raw_data,
tokenizer=tokenizer,
padding="max_length",
padding="do_not_pad",
batch_size=batch_size,
keep_assistant_prefix=keep_assistant_prefix,
)
Expand All @@ -834,9 +837,41 @@ def __init__(

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
dp = self.data_points[i]
result = {}
for key in dp:
result[key] = torch.tensor(dp[key])
input_ids, label_ids, attention_mask = (
dp["input_ids"],
dp["labels"],
dp["attention_mask"],
)
# asert all attention_mask == 1
assert sum(attention_mask) == len(attention_mask)
# padding to max_length
pad_length = self.tokenizer.model_max_length - len(dp["input_ids"])
if pad_length > 0:
if self.tokenizer.padding_side == "right":
input_ids = input_ids + [
self.tokenizer.pad_token_id for _ in range(pad_length)
]
label_ids = label_ids + [-100 for _ in range(pad_length)]
attention_mask = attention_mask + [0 for _ in range(pad_length)]
else:
input_ids = [
self.tokenizer.pad_token_id for _ in range(pad_length)
] + input_ids
label_ids = [-100 for _ in range(pad_length)] + label_ids
attention_mask = [0 for _ in range(pad_length)] + attention_mask

assert (
len(input_ids)
== len(label_ids)
== len(attention_mask)
== self.tokenizer.model_max_length
)

result = {
"input_ids": torch.tensor(input_ids),
"labels": torch.tensor(label_ids),
"attention_mask": torch.tensor(attention_mask),
}
return result


Expand Down
4 changes: 4 additions & 0 deletions functionary/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ class DataArguments:
default=False,
metadata={"help": "Whether to use lazy loading for the dataset or not"},
)
ignore_cached: bool = field(
default=False,
metadata={"help": "Whether to ignore cached tokenized data or not"},
)


@dataclass
Expand Down

0 comments on commit 7f984ee

Please sign in to comment.