-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_eval_prm_rlhflow.py
85 lines (74 loc) · 2.92 KB
/
run_eval_prm_rlhflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# example code to evaluate the `RLHFlow/Llama3.1-8B-PRM-Mistral-Data` PRM
"""
Suppose you have launched a vLLM server, e.g., via:
```
vllm serve \
RLHFlow/Llama3.1-8B-PRM-Mistral-Data \
--served-model-name Llama3.1-8B-PRM-Mistral-Data \
--port 8000 \
--tensor-parallel-size 8 \
--dtype auto \
--api-key token-abc123 \
--enable-prefix-caching
```
Then you can run the following script to evaluate the model.
"""
import os
import numpy as np
import json
from tqdm import tqdm
from multiprocessing import Pool
from openai import OpenAI
from datasets import load_dataset
def _write_jsonl(path, records):
    """Write each record in *records* to *path* as one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(rec) + '\n' for rec in records)


def _accuracy(records):
    """Return the percentage of *records* whose 'match' flag is true.

    An empty list yields ``nan`` explicitly instead of letting ``np.mean([])``
    emit a RuntimeWarning; the printed value is the same.
    """
    if not records:
        return float('nan')
    return float(np.mean([rec['match'] for rec in records])) * 100


def _f1(acc1, acc2):
    """Return the harmonic mean of two accuracies (0.0 when both are 0)."""
    total = acc1 + acc2
    if total == 0:
        # Avoid 0/0, which would otherwise be reported as nan.
        return 0.0
    return 2 * acc1 * acc2 / total


def main():
    """Evaluate the PRM served at localhost:8000 on every ProcessBench split.

    For each split, every example is sent step by step to the chat endpoint;
    the prediction is the index of the first step whose one-token judgment
    does not start with '+', or -1 if every step is accepted.  Results are
    written to ``outputs/Llama3.1-8B-PRM-Mistral-Data/{config}_error.jsonl``
    (examples with ``label != -1``) and ``{config}_correct.jsonl`` (examples
    with ``label == -1``), and the per-subset accuracies plus their harmonic
    mean are printed.
    """
    # Local import so this function stays self-contained.  A thread pool is
    # used because the work is I/O-bound HTTP requests and because
    # `single_process` is a closure: a process pool would have to pickle it,
    # which fails for functions defined inside another function.
    from multiprocessing.pool import ThreadPool

    model_name = 'Llama3.1-8B-PRM-Mistral-Data'
    output_dir = f'outputs/{model_name}'

    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",
    )
    os.makedirs(output_dir, exist_ok=True)

    def single_process(d):
        """Return the index of the first step not judged '+', or -1 if none."""
        messages = []
        for sdx, step in enumerate(d['steps']):
            # The problem statement is prepended to the first step only.
            content = d['problem'] + '\n\n' + step if sdx == 0 else step
            messages.append({'role': 'user', 'content': content})
            completion = client.chat.completions.create(
                model=model_name,
                messages=messages,
                n=1,
                temperature=0.,
                max_tokens=1,
            )
            # `content` may be None/empty; treat that as a non-'+' judgment
            # rather than crashing on None.strip().
            reply = completion.choices[0].message.content or ''
            if not reply.strip().startswith('+'):
                return sdx
            # Record the accepted judgment so the next step sees the history.
            messages.append({'role': 'assistant', 'content': '+'})
        return -1

    configs = ['gsm8k', 'math', 'olympiadbench', 'omnimath']
    for config in configs:
        input_data = load_dataset('/cpfs01/user/zhengchujie.zcj/hf_datasets/Qwen/ProcessBench', split=config)
        with ThreadPool(32) as p:
            # imap preserves input order, so predictions align with input_data.
            predictions = list(tqdm(p.imap(single_process, input_data), total=len(input_data),
                                    desc=f'Processing {config}', dynamic_ncols=True))

        res_data = []
        for d, prediction in zip(input_data, predictions):
            new_d = d.copy()
            new_d['prediction'] = prediction
            new_d['match'] = prediction == d['label']
            res_data.append(new_d)

        # Split by label: -1 goes to the "correct" file, anything else to the
        # "error" file (presumably label is the first erroneous step index,
        # with -1 meaning no error -- confirm against the dataset card).
        data1 = [e for e in res_data if e['label'] != -1]
        data2 = [e for e in res_data if e['label'] == -1]

        _write_jsonl(f'{output_dir}/{config}_error.jsonl', data1)
        _write_jsonl(f'{output_dir}/{config}_correct.jsonl', data2)

        acc1 = _accuracy(data1)
        acc2 = _accuracy(data2)
        f1 = _f1(acc1, acc2)
        print(f'{config} error acc: {acc1:.1f}, correct acc: {acc2:.1f}, f1: {f1:.1f}')
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()