run_llava.py
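"""Evaluate LLaVA on the MMMU benchmark: load every subject's split from the
Hugging Face hub, build prompts, run the model, and save the parsed
predictions to a JSON file."""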
import torch
import os
import random
import numpy as np
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from argparse import ArgumentParser
from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
from utils.model_utils import call_llava_engine_df, llava_image_processor
from utils.eval_utils import parse_multi_choice_response, parse_open_response


def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
    """Run inference over the prepared samples and return a dict mapping sample id to the parsed prediction."""
    out_samples = dict()
    with torch.no_grad():
        for sample in tqdm(samples):
            response = call_model_engine_fn(args, sample, model, tokenizer, processor)
            if sample['question_type'] == 'multiple-choice':
                pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
            else:  # open question
                pred_ans = response
            out_samples[sample['id']] = pred_ans
    return out_samples


def set_seed(seed_value):
    """
    Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
    :param seed_value: An integer value to be used as the seed.
    """
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # For multi-GPU setups
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    parser = ArgumentParser()
    parser.add_argument('--output_path', type=str, default='llava1.5_13b_val.json',
                        help='name of saved json')
    parser.add_argument('--config_path', type=str, default="configs/llava1.5.yaml")
    parser.add_argument('--data_path', type=str, default="MMMU/MMMU")  # hf dataset path.
    parser.add_argument('--model_path', type=str, default="liuhaotian/llava-v1.5-13b")
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)

    print('llava_initializing...')
    processor = None
    call_model_engine = call_llava_engine_df
    vis_process_func = llava_image_processor

    # load the config and collapse single-element lists to scalars
    args.config = load_yaml(args.config_path)
    for key, value in args.config.items():
        if key != 'eval_params' and type(value) == list:
            assert len(value) == 1, 'key {} has more than one value'.format(key)
            args.config[key] = value[0]

    # load each subject's split, then merge them into one dataset
    sub_dataset_list = []
    for subject in CAT_SHORT2LONG.values():
        sub_dataset = load_dataset(args.data_path, subject, split=args.split)
        sub_dataset_list.append(sub_dataset)
    dataset = concatenate_datasets(sub_dataset_list)

    # load the model
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, vis_processors, _ = load_pretrained_model(args.model_path, None,
                                                                model_name)

    # build prompts and preprocess images for every sample
    samples = []
    for sample in dataset:
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, args.config)
        if sample['image']:
            sample['image'] = vis_process_func(sample['image'], vis_processors).to(device)
        samples.append(sample)

    # run inference and save the parsed predictions
    out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
    save_json(args.output_path, out_samples)
    # metric_dict.update({"num_example": len(out_samples)})
    # save_json(save_result_path, metric_dict)


if __name__ == '__main__':
    main()
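

# Example invocation: the flag values below are just the script's argparse
# defaults shown explicitly (a sketch that assumes the repo's utils/ helpers
# and a LLaVA install are already on PYTHONPATH); adjust paths and split to
# your own setup:
#
#   python run_llava.py \
#       --model_path liuhaotian/llava-v1.5-13b \
#       --config_path configs/llava1.5.yaml \
#       --data_path MMMU/MMMU \
#       --split validation \
#       --output_path llava1.5_13b_val.json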