
Commit

add tokenization to readme
luciaquirke committed Aug 13, 2024
1 parent b8e6901 commit 33d0939
Showing 3 changed files with 78 additions and 124 deletions.
62 changes: 39 additions & 23 deletions README.md
@@ -12,23 +12,40 @@ pip install tokengrams
```

# Usage
## Preparing data

Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of 2^32. In practice, however, vocabulary size is limited by the largest u16 or u32 vector the machine can allocate in memory.

Corpora with vocabulary sizes smaller than 2^16 must use u16 tokens.
Use a dataset formatted as either u16 or u32 tokens, or prepare one from a HuggingFace dataset.

## Building an index
```python
from tokengrams import MemmapIndex
# Get pre-tokenized dataset
from huggingface_hub import HfApi, hf_hub_download

# Get a dataset formatted as u16 tokens
hf_hub_download(
repo_id="EleutherAI/pile-standard-pythia-preshuffled",
repo_type="dataset",
filename="document-00000-of-00020.bin",
local_dir="."
)
```
```python
# Tokenize HF dataset
from tokengrams import tokenize_hf_dataset
from datasets import load_dataset
from transformers import AutoTokenizer

tokenize_hf_dataset(
dataset=load_dataset("wikitext", "wikitext-103-raw-v1"),
tokenizer=AutoTokenizer.from_pretrained("EleutherAI/pythia-160m"),
output_path="wikitext.bin",
text_key="text",
append_eod=True,
workers=1,
)
```

## Building an index
```python
from tokengrams import MemmapIndex

# Create a new index from an on-disk corpus of u16 tokens and save it to a .idx file.
# Set verbose to true to include a progress bar for the index sort.
@@ -52,7 +69,7 @@ print(index.count(tokenizer.encode("hello world")))
index = MemmapIndex(
"document-00000-of-00020.bin",
"document-00000-of-00020.idx",
vocab=2**17
vocab=2**16
)
```
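
The collapsed hunk above already shows the index being queried with `index.count(...)`. As a minimal sketch of that usage, assuming the Pythia tokenizer from the tokenization example:

```python
from transformers import AutoTokenizer

# Count how many times an n-gram occurs in the indexed corpus.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
print(index.count(tokenizer.encode("hello world")))
```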

@@ -99,28 +116,27 @@ Many systems struggle with memory mapping extremely large tables (e.g. 40 billion tokens).

```python
from tokengrams import ShardedMemmapIndex
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download

# Get sharded corpus of u16 tokens
repo_id = "EleutherAI/pile-standard-pythia-preshuffled"
repo_type = "dataset"

bin_files = [
file for file in HfApi().list_repo_files(repo_id, repo_type=repo_type)
if file.endswith('.bin')
]

for file in bin_files:
hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=file, local_dir=".")

# Build sharded index
files = [
(file, f'{file.rstrip(".bin")}.idx'),
bin_files = []
for file in HfApi().list_repo_files(repo_id, repo_type="dataset"):
if file.endswith('.bin'):
bin_files.append(file)
hf_hub_download(repo_id, repo_type="dataset", filename=file, local_dir=".")

paths = [
(file, f'{file.rstrip(".bin")}.idx')
for file in bin_files
]
index = ShardedMemmapIndex.build(files, vocab=2**17, verbose=True)

index = ShardedMemmapIndex.build(paths, vocab=2**16, verbose=True)
```
### Tokens

Tokengrams builds indices from on-disk corpora of either u16 or u32 tokens, supporting a maximum vocabulary size of 2^32. In practice, however, vocabulary size is limited by the largest u16 or u32 vector the machine can allocate in memory.

Corpora with vocabulary sizes smaller than 2^16 must use u16 tokens.
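
For illustration, a minimal sketch of writing such a corpus by hand with numpy, assuming native little-endian byte order; the tokenizer choice, file name, and one-line corpus are placeholders:

```python
import numpy as np
from transformers import AutoTokenizer

# Choose the narrowest dtype that can hold every token id.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
dtype = np.uint16 if tokenizer.vocab_size < 2**16 else np.uint32

# Write the token ids to disk as a flat array of that width.
token_ids = tokenizer.encode("hello world")  # placeholder corpus
np.asarray(token_ids, dtype=dtype).tofile("my_corpus.bin")
```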

## Performance

2 changes: 1 addition & 1 deletion tokengrams/__init__.py
@@ -4,4 +4,4 @@
ShardedMemmapIndex,
)

from .utils import tokenize_hf_dataset
from .utils.tokenize_hf_dataset import tokenize_hf_dataset
138 changes: 38 additions & 100 deletions tokengrams/utils/tokenize_hf_dataset.py
@@ -1,14 +1,41 @@
import os
from argparse import ArgumentParser
import multiprocessing as mp
from typing import Union, Generator
from pathlib import Path

import numpy as np
from datasets import load_dataset, Dataset, DatasetDict, IterableDataset, IterableDatasetDict, concatenate_datasets
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, concatenate_datasets
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm


def tokenize_hf_dataset(
dataset: Dataset | DatasetDict | IterableDataset | IterableDatasetDict,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
output_path: Path,
text_key="text",
append_eod: bool = False,
workers: int = 1,
):
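"""Tokenize a HuggingFace dataset and write its token ids to output_path as a flat array of u16 or u32 values, chosen from the tokenizer's vocab size."""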
batch_size = 10_000
eos_token = tokenizer.eos_token_id if append_eod else None
vocab_size = get_vocab_size(tokenizer)
if vocab_size > 2**32:
raise ValueError(f"Tokenizer vocab size {vocab_size} is too large for uint32")

data = get_dataset_iterator(dataset, batch_size)

# Tokenize and save as memory-mapped array
total_tokens = tokenize_and_write_mmap(
data,
tokenizer,
output_path,
eos_token=eos_token,
text_key=text_key,
num_workers=workers,
dtype=np.dtype(np.uint16 if vocab_size < 2**16 else np.uint32)
)
print(f"{total_tokens} tokens saved to {output_path}")

def get_vocab_size(tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast) -> int:
"""Get the vocab size of the tokenizer."""
if hasattr(tokenizer, 'vocab_size'):
@@ -48,15 +75,14 @@ def tokenize_batch(args):
def tokenize_and_write_mmap(
data: Generator,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
output_prefix: str,
output_path: Path,
text_key: str = "text",
batch_size: int = 1000,
buffer_size: int = 10_000_000,
eos_token: int | None = None,
num_workers: int = 4,
dtype: np.dtype = np.dtype(np.uint16)
):
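"""Tokenize batches from data and append the token ids to a memory-mapped file at output_path, growing it as needed; returns the total number of tokens written."""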
mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='w+', shape=(buffer_size,))
mmap = np.memmap(output_path, dtype=dtype, mode='w+', shape=(buffer_size,))

total_tokens = 0
pool = mp.Pool(num_workers)
@@ -67,109 +93,21 @@ def tokenize_and_write_mmap(
new_tokens = pool.map(tokenize_batch, tokenize_args)[0]

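# The next batch would overflow the current buffer, so double the backing file before writing.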
if total_tokens + len(new_tokens) > mmap.shape[0]:
mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(mmap.shape[0] * 2,))
mmap = np.memmap(output_path, dtype=dtype, mode='r+', shape=(mmap.shape[0] * 2,))

mmap[total_tokens:total_tokens + len(new_tokens)] = new_tokens
total_tokens += len(new_tokens)
pbar.update(batch_size)
pbar.update(len(batch))

pool.close()
pool.join()

# Resize mmap to actual size
with open(f'{output_prefix}.bin', 'r+b') as f:
with open(output_path, 'r+b') as f:
f.truncate(total_tokens * dtype.itemsize)

mmap = np.memmap(f'{output_prefix}.bin', dtype=dtype, mode='r+', shape=(total_tokens,))
mmap = np.memmap(output_path, dtype=dtype, mode='r+', shape=(total_tokens,))
mmap.flush()

pbar.close()
return total_tokens

def get_args(input_args=None):
parser = ArgumentParser()
group = parser.add_argument_group(title="input data")
group.add_argument(
"--dataset_name",
type=str,
required=True,
help="Name of the Hugging Face dataset to use",
)
group.add_argument(
"--config_name",
type=str,
default=None
)
group.add_argument(
"--split",
type=str,
default="train",
help="Hugging Face dataset split",
)
group.add_argument(
"--stream",
action="store_true",
)
group = parser.add_argument_group(title="tokenizer")
group.add_argument(
"--tokenizer_name",
type=str,
required=True,
help="Name or path of the pre-trained tokenizer to use",
)
group.add_argument(
"--append-eod",
action="store_true",
help="Append an <eod> token to the end of each sequence.",
)
group = parser.add_argument_group(title="output data")
group.add_argument(
"--output-prefix",
type=str,
required=True,
help="Path to binary output file without suffix",
)
group = parser.add_argument_group(title="runtime")
group.add_argument(
"--workers", type=int, default=1, help="Number of worker processes to launch"
)
args = parser.parse_args(input_args)

return args

def main(input_args=None):
args = get_args(input_args)
batch_size = 10_000

# Get tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
eos_token = tokenizer.eos_token_id if args.append_eod else None
vocab_size = get_vocab_size(tokenizer)
if vocab_size > 2**32:
raise ValueError(f"Tokenizer vocab size {vocab_size} is too large for uint32")

# Get dataset iterator
os.makedirs('.cache', exist_ok=True)
dataset = load_dataset(
args.dataset_name,
args.config_name,
cache_dir=os.path.join(os.getcwd(), '.cache'),
split=args.split,
streaming=args.stream,
)
data = get_dataset_iterator(dataset, batch_size)

# Tokenize and save as memory-mapped array
total_tokens = tokenize_and_write_mmap(
data,
tokenizer,
args.output_prefix,
eos_token=eos_token,
batch_size=batch_size,
num_workers=args.workers,
dtype=np.dtype(np.uint16 if vocab_size < 2**16 else np.uint32)
)
print(f"{total_tokens} tokens saved as memory-mapped array in {args.output_prefix}.bin")

if __name__ == "__main__":
main()
