forked from cgarciae/nanoGPT-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample.py
82 lines (74 loc) · 2.99 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Sample from a trained model
"""
import os
import pickle
import tiktoken
from model import GPTConfig, GPT
from pathlib import Path
from flax import serialization
import jax.numpy as jnp
import jax
import orbax.checkpoint as orbax
from utils import print_compiling
# -----------------------------------------------------------------------------
# Default sampling configuration. These are plain module-level names; the
# configurator code exec'd at the end of this section runs in this module's
# namespace and may rebind any of them.
out_dir = 'out'  # directory containing the 'checkpoint' folder to sample from
start = "\n"  # prompt text; or "<|endoftext|>" or whatever you like
num_samples = 10  # number of samples to draw
max_new_tokens = 500  # number of tokens generated in each sample
temperature = 0.8  # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
top_k = 200  # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337  # seed for the base PRNG key
dtype = 'bfloat16'  # 'float32' or 'bfloat16' or 'float16'; not referenced elsewhere in this script
# NOTE(review): this flag used to be described as "use PyTorch 2.0 to compile
# the model", but this script does not import PyTorch and never reads the flag.
# The name (which shadows the builtin `compile`) is kept so existing config
# files / command-line overrides that set it keep working.
compile = False
# Overrides from command line or config file. Path.read_text() opens and closes
# the file itself, so no file handle is left dangling as with open(...).read().
exec(Path('configurator.py').read_text())
# -----------------------------------------------------------------------------
# Restore the most recent checkpoint stored under <out_dir>/checkpoint.
checkpoint_path = Path(out_dir, 'checkpoint')
checkpoint_manager = orbax.CheckpointManager(
    checkpoint_path,
    checkpointers=orbax.Checkpointer(orbax.PyTreeCheckpointHandler()),
)
latest_step = checkpoint_manager.latest_step()
# Explicit raise instead of `assert`: assert statements are removed when Python
# runs with -O, which would let a missing checkpoint slip through to restore().
if latest_step is None:
    raise FileNotFoundError("No checkpoint found in out_dir!")
# model
checkpoint = checkpoint_manager.restore(latest_step, items=None)
# The checkpoint carries the constructor arguments for the model config and the
# training state whose 'params' entry holds the weights used for sampling.
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
params = checkpoint['state']['params']
# Choose the tokenizer. If the checkpoint records which dataset it was trained
# on and data/<dataset>/meta.pkl exists, use the stoi/itos tables stored there;
# otherwise fall back to the GPT-2 encoding from tiktoken.
load_meta = False
if 'config' in checkpoint and 'dataset' in checkpoint['config']:  # older checkpoints might not have these...
    meta_path = Path('data', checkpoint['config']['dataset'], 'meta.pkl')
    load_meta = meta_path.exists()

if not load_meta:
    # ok let's assume gpt-2 encodings by default
    print("No meta.pkl found, assuming GPT-2 encodings...")
    enc = tiktoken.get_encoding("gpt2")

    def encode(s):
        """Tokenize ``s`` with the GPT-2 encoding, allowing "<|endoftext|>"."""
        return enc.encode(s, allowed_special={"<|endoftext|>"})

    def decode(l):
        """Turn a sequence of GPT-2 token ids back into text."""
        return enc.decode(l)
else:
    print(f"Loading meta from {meta_path}...")
    # NOTE: unpickling can execute arbitrary code; only load a meta.pkl that
    # comes from a trusted source.
    with open(meta_path, 'rb') as meta_file:
        meta = pickle.load(meta_file)
    # TODO want to make this more general to arbitrary encoder/decoder schemes
    stoi, itos = meta['stoi'], meta['itos']

    def encode(s):
        """Map every element of ``s`` to its id through the ``stoi`` table."""
        return [stoi[c] for c in s]

    def decode(l):
        """Map every id in ``l`` through ``itos`` and concatenate the pieces."""
        return ''.join(itos[i] for i in l)
# Tokenize the prompt and give it a leading axis of length 1 (sample() later
# reads row 0 of the generated output).
start_ids = encode(start)
x = jnp.asarray(start_ids, dtype=jnp.int32)[None, :]
# Base PRNG key; the generation loop derives one key per sample from it.
key = jax.random.PRNGKey(seed)
@jax.jit
@print_compiling
def _sample(params, key, tokens) -> jax.Array:
    """Call ``model.generate`` on ``tokens`` under ``jax.jit``.

    ``max_new_tokens``, ``top_k`` and ``temperature`` are taken from the
    module-level configuration rather than passed as arguments.

    Args:
        params: Model parameters, forwarded to ``model.generate``.
        key: PRNG key, forwarded to ``model.generate``.
        tokens: Prompt token ids, forwarded to ``model.generate``.

    Returns:
        The value returned by ``model.generate``.
    """
    # Note the argument order: generate() takes the key before the params.
    generated = model.generate(
        key,
        params,
        tokens,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        temperature=temperature,
    )
    return generated
def sample(params, key, tokens) -> str:
    """Generate a continuation of ``tokens`` and decode its first row to text.

    Args:
        params: Model parameters, forwarded to ``_sample``.
        key: PRNG key, forwarded to ``_sample``.
        tokens: Prompt token ids, forwarded to ``_sample``.

    Returns:
        The string produced by ``decode`` for row 0 of the generated tokens.
    """
    generated = _sample(params, key, tokens)
    # Hand decode() plain Python ints. Iterating a jax.Array yields array
    # scalars, which are unhashable and therefore fail as lookup keys in a
    # table-based decoder; tolist() also fetches the row in a single transfer
    # instead of one element at a time.
    return decode(generated[0].tolist())
# run generation: draw num_samples completions of the same prompt, each with
# its own key obtained by folding the sample index into the base key
for sample_index in range(num_samples):
    print(sample(params, jax.random.fold_in(key, sample_index), x))
    print('---------------')