
Commit

support uint32 in tokenization util
luciaquirke committed Aug 12, 2024
1 parent fe1f6cf commit a6d8b42
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions tokengrams/utils/tokenize_hf_dataset.py
@@ -8,6 +8,15 @@
import multiprocessing as mp
from tqdm import tqdm

+def get_vocab_size(tokenizer: AutoTokenizer) -> int:
+    """Get the vocab size of the tokenizer."""
+    if hasattr(tokenizer, 'vocab_size'):
+        return tokenizer.vocab_size  # type: ignore
+    elif hasattr(tokenizer, 'get_vocab'):
+        return len(tokenizer.get_vocab())  # type: ignore
+    else:
+        return len(tokenizer)  # type: ignore
+
def get_dataset_iterator(data: Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict], batch_size: int):
"""Get an iterator for the dataset, handling different dataset types."""
if isinstance(data, IterableDataset):
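
A note on the fallback chain above: on Hugging Face tokenizers, vocab_size reports only the base vocabulary, while len(tokenizer) also counts added tokens. A minimal sketch of the difference, assuming the stock gpt2 tokenizer (the added token is hypothetical):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.add_tokens(["<my_new_token>"])  # hypothetical added token

print(tok.vocab_size)  # 50257: base vocab only
print(len(tok))        # 50258: includes the added token

Because the helper checks vocab_size first, ids introduced via add_tokens do not affect the returned count.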
@@ -43,9 +52,10 @@ def tokenize_and_write_mmap(
    batch_size: int = 1000,
    buffer_size: int = 10_000_000,
    eos_token: int | None = None,
-    num_workers: int = 4
+    num_workers: int = 4,
+    dtype: np.dtype = np.dtype(np.uint16)
):
-    mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='w+', shape=(buffer_size,))
+    mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='w+', shape=(buffer_size,))

    total_tokens = 0
    pool = mp.Pool(num_workers)
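
One detail of the new signature: the default is an np.dtype instance rather than the bare np.uint16 scalar type, so the resize logic later in the function can read dtype.itemsize directly. A quick check:

import numpy as np

assert np.dtype(np.uint16).itemsize == 2  # two bytes per token id
assert np.dtype(np.uint32).itemsize == 4  # four bytes per token id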
@@ -56,7 +66,7 @@ def tokenize_and_write_mmap(
        new_tokens = pool.map(tokenize_batch, tokenize_args)[0]

        if total_tokens + len(new_tokens) > mmap.shape[0]:
-            mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='r+', shape=(mmap.shape[0] * 2,))
+            mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(mmap.shape[0] * 2,))

        mmap[total_tokens:total_tokens + len(new_tokens)] = new_tokens
        total_tokens += len(new_tokens)
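
The overflow branch grows the backing file geometrically; np.memmap in 'r+' mode extends the file on disk when the requested shape exceeds its current size, so appends stay amortized O(1). The same pattern in isolation, a sketch with hypothetical names:

import numpy as np

def grow_to_fit(path: str, mmap: np.memmap, needed: int, dtype: np.dtype) -> np.memmap:
    # Double the capacity until `needed` elements fit, then remap the file.
    size = mmap.shape[0]
    while size < needed:
        size *= 2
    if size != mmap.shape[0]:
        mmap = np.memmap(path, dtype=dtype, mode='r+', shape=(size,))
    return mmap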
@@ -67,9 +77,9 @@ def tokenize_and_write_mmap(

    # Resize mmap to actual size
    with open(f'{output_prefix}.bin', 'r+b') as f:
-        f.truncate(total_tokens * np.uint16().itemsize)
+        f.truncate(total_tokens * dtype.itemsize)

-    mmap = np.memmap(f'{output_prefix}.bin', dtype=np.uint16, mode='r+', shape=(total_tokens,))
+    mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(total_tokens,))
    mmap.flush()

    pbar.close()
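
Deriving the final byte count from dtype.itemsize matters once uint32 is possible: the old hard-coded np.uint16().itemsize would have truncated a uint32 file to half its bytes. For example:

import numpy as np

total_tokens = 1_000
print(total_tokens * np.dtype(np.uint16).itemsize)  # 2000 bytes
print(total_tokens * np.dtype(np.uint32).itemsize)  # 4000 bytes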
@@ -93,7 +103,7 @@ def get_args(input_args=None):
"--split",
type=str,
default="train",
help="Split of the Hugging Face dataset",
help="Hugging Face dataset split",
)
group.add_argument(
"--stream",
@@ -109,7 +119,7 @@ def get_args(input_args=None):
    group.add_argument(
        "--append-eod",
        action="store_true",
-        help="Append an <eod> token to the end of a document.",
+        help="Append an <eod> token to the end of each sequence.",
    )
    group = parser.add_argument_group(title="output data")
    group.add_argument(
@@ -123,12 +133,6 @@ def get_args(input_args=None):
"--workers", type=int, default=1, help="Number of worker processes to launch"
)
args = parser.parse_args(input_args)
args.keep_empty = False

# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.model_parallel_size = 1

return args

@@ -139,8 +143,11 @@ def main(input_args=None):
    # Get tokenizer
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)  # type: ignore
    eos_token = tokenizer.eos_token_id if args.append_eod else None  # type: ignore
+    vocab_size = get_vocab_size(tokenizer)
+    if vocab_size > 2**32:
+        raise ValueError(f"Tokenizer vocab size {vocab_size} is too large for uint32")

-    # Get data
+    # Get dataset iterator
    os.makedirs('.cache', exist_ok=True)
    dataset = load_dataset(
        args.dataset_name,
@@ -149,7 +156,6 @@ def main(input_args=None):
        split=args.split,
        streaming=args.stream,
    )
-
    data = get_dataset_iterator(dataset, batch_size)

    # Tokenize and save as memory-mapped array
@@ -159,9 +165,9 @@ def main(input_args=None):
        args.output_prefix,
        eos_token=eos_token,
        batch_size=batch_size,
-        num_workers=args.workers
+        num_workers=args.workers,
+        dtype=np.dtype(np.uint16 if vocab_size < 2**16 else np.uint32)
    )

    print(f"{total_tokens} tokens saved as memory-mapped array in {args.output_prefix}.bin")

if __name__ == "__main__":
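
One consequence of the variable dtype worth noting: the .bin file does not record which width was written, so a reader must supply the matching dtype when mapping it back. A minimal read-back sketch, assuming a hypothetical output prefix of 'tokens' and a vocabulary of at least 2**16 entries (so uint32 was chosen at write time):

import numpy as np

# Must mirror the write-time rule: uint16 if vocab_size < 2**16, else uint32.
tokens = np.memmap('tokens.bin', dtype=np.uint32, mode='r')
print(tokens.shape, tokens[:10])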
