support uint32 in tokenization util
luciaquirke committed Aug 12, 2024
1 parent fe1f6cf commit 37a6fd5
Showing 1 changed file with 33 additions and 26 deletions.
59 changes: 33 additions & 26 deletions tokengrams/utils/tokenize_hf_dataset.py
@@ -1,13 +1,23 @@
import numpy as np
from typing import Generator
from typing import Union, Generator
from argparse import ArgumentParser
from datasets import load_dataset, Dataset, DatasetDict, IterableDataset, IterableDatasetDict, concatenate_datasets
from transformers import AutoTokenizer
from typing import Union
import numpy as np
import os
import multiprocessing as mp

from datasets import load_dataset, Dataset, DatasetDict, IterableDataset, IterableDatasetDict, concatenate_datasets
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm


def get_vocab_size(tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> int:
"""Get the vocab size of the tokenizer."""
if hasattr(tokenizer, 'vocab_size'):
return tokenizer.vocab_size
elif hasattr(tokenizer, 'get_vocab'):
return len(tokenizer.get_vocab())
else:
return len(tokenizer)

def get_dataset_iterator(data: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict], batch_size: int):
"""Get an iterator for the dataset, handling different dataset types."""
if isinstance(data, IterableDataset):
@@ -18,7 +28,7 @@ def get_dataset_iterator(data: Union[Dataset, DatasetDict, IterableDataset, Iter
for i in range(0, len(data), batch_size)
)
elif isinstance(data, DatasetDict) or isinstance(data, IterableDatasetDict):
# Concatenate all available splits in the DatasetDict
# Concatenate all available splits
concatenated_dataset = concatenate_datasets(list(data.values()))
return concatenated_dataset.iter(batch_size=batch_size)
else:
@@ -37,15 +47,16 @@ def tokenize_batch(args):

def tokenize_and_write_mmap(
data: Generator,
tokenizer: AutoTokenizer,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
output_prefix: str,
text_key: str = "text",
batch_size: int = 1000,
buffer_size: int = 10_000_000,
eos_token: int | None = None,
num_workers: int = 4
num_workers: int = 4,
dtype: np.dtype = np.dtype(np.uint16)
):
mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='w+', shape=(buffer_size,))
mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='w+', shape=(buffer_size,))

total_tokens = 0
pool = mp.Pool(num_workers)
@@ -56,7 +67,7 @@ def tokenize_and_write_mmap(
new_tokens = pool.map(tokenize_batch, tokenize_args)[0]

if total_tokens + len(new_tokens) > mmap.shape[0]:
mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='r+', shape=(mmap.shape[0] * 2,))
mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(mmap.shape[0] * 2,))

mmap[total_tokens:total_tokens + len(new_tokens)] = new_tokens
total_tokens += len(new_tokens)
@@ -67,9 +78,9 @@ def tokenize_and_write_mmap(

# Resize mmap to actual size
with open(f'{output_prefix}.bin', 'r+b') as f:
f.truncate(total_tokens * np.uint16().itemsize)
f.truncate(total_tokens * dtype.itemsize)

mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='r+', shape=(total_tokens,))
mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(total_tokens,))
mmap.flush()

pbar.close()
@@ -93,7 +104,7 @@ def get_args(input_args=None):
"--split",
type=str,
default="train",
help="Split of the Hugging Face dataset",
help="Hugging Face dataset split",
)
group.add_argument(
"--stream",
@@ -109,7 +120,7 @@ def get_args(input_args=None):
group.add_argument(
"--append-eod",
action="store_true",
help="Append an <eod> token to the end of a document.",
help="Append an <eod> token to the end of each sequence.",
)
group = parser.add_argument_group(title="output data")
group.add_argument(
@@ -123,12 +134,6 @@ def get_args(input_args=None):
"--workers", type=int, default=1, help="Number of worker processes to launch"
)
args = parser.parse_args(input_args)
args.keep_empty = False

# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1

return args

@@ -137,10 +142,13 @@ def main(input_args=None):
batch_size = 10_000

# Get tokenizer
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) # type: ignore
eos_token = tokenizer.eos_token_id if args.append_eod else None # type: ignore
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
eos_token = tokenizer.eos_token_id if args.append_eod else None
vocab_size = get_vocab_size(tokenizer)
if vocab_size > 2**32:
raise ValueError(f"Tokenizer vocab size {vocab_size} is too large for uint32")

# Get data
# Get dataset iterator
os.makedirs('.cache', exist_ok=True)
dataset = load_dataset(
args.dataset_name,
@@ -149,7 +157,6 @@ def main(input_args=None):
split=args.split,
streaming=args.stream,
)

data = get_dataset_iterator(dataset, batch_size)

# Tokenize and save as memory-mapped array
@@ -159,9 +166,9 @@ def main(input_args=None):
args.output_prefix,
eos_token=eos_token,
batch_size=batch_size,
num_workers=args.workers
num_workers=args.workers,
dtype=np.dtype(np.uint16 if vocab_size < 2**16 else np.uint32)
)

print(f"{total_tokens} tokens saved as memory-mapped array in {args.output_prefix}.bin")

if __name__ == "__main__":
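As a quick illustration (not part of this commit), here is a minimal sketch of how the resulting .bin file might be read back. The dtype has to mirror the uint16/uint32 choice made at write time; the tokenizer name and output path below are placeholder assumptions.

# Illustrative only: read back the token file written by tokenize_hf_dataset.py.
# The dtype must match the one selected at write time: uint16 if the tokenizer
# vocab fits in 16 bits, uint32 otherwise. Tokenizer name and path are placeholders.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
dtype = np.uint16 if tokenizer.vocab_size < 2**16 else np.uint32

# mode="r" with no explicit shape lets numpy infer the length from the file size
tokens = np.memmap("output_prefix.bin", dtype=dtype, mode="r")
print(tokens.shape, tokens.dtype)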
