eval_scrolls.py
import os
import argparse
import json
import shutil
import evaluate
from datasets import load_dataset
from huggingface_hub import hf_hub_download

DATASETS = [
    "narrative_qa",
    "qasper",
    "summ_screen_fd",
    "gov_report",
    "qmsum",
    "contract_nli",
    "quality",
    "quality_hard",
]
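
# Note: "quality_hard" is not a separate hub config; it is loaded as the "quality"
# config with hard_only=True (see load_dataset_kwargs in main() below).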
# usage: python eval_scrolls.py --split validation --dataset_name DATASET_NAME --predictions PREDICTIONS_JSON --metrics_output_dir METRICS_OUTPUT_DIR
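#
# The predictions JSON must map example IDs to prediction strings, e.g.
#   {"example_id_1": "predicted answer 1", "example_id_2": "predicted answer 2"}
# (illustrative IDs only; real IDs come from the dataset's "id" column).
#
# main() can also be called programmatically with an in-memory predictions dict.
# A minimal sketch (illustrative values; the attribute names mirror the CLI flags
# defined at the bottom of this file):
#
#   from types import SimpleNamespace
#   metrics = main(
#       SimpleNamespace(
#           predictions={"some_id": "some prediction"},  # hypothetical ID/answer
#           dataset_name="qasper",
#           metrics_output_dir="metrics",
#           split="validation",
#           internal_call=True,
#           test_data_file=None,
#           cache_dir=None,
#           verify_only=False,
#       ),
#       raise_on_errors=True,
#   )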


def main(args, raise_on_errors=False):
    """
    If raise_on_errors is True, raises ValueError on verification errors (after dumping the error descriptions).
    Otherwise, exits with an error code.
    """
    predictions = args.predictions
    dataset_name = args.dataset_name
    verify_only = args.verify_only

    if not verify_only:
        # Load the SCROLLS metric for this dataset
        scrolls_metric = evaluate.load(download_metric(), dataset_name)

    # Download and load the dataset from the hub
    load_dataset_kwargs = {
        "path": "tau/scrolls",
        "name": dataset_name if dataset_name != "quality_hard" else "quality",
        "hard_only": None if dataset_name != "quality_hard" else True,
        "data_files": {"test": args.test_data_file} if args.test_data_file is not None else None,
    }
    if args.cache_dir is not None:
        load_dataset_kwargs["cache_dir"] = args.cache_dir
    load_dataset_kwargs["split"] = args.split
    seq2seq_dataset = load_dataset(**load_dataset_kwargs)

    if not verify_only:
        assert all(
            example["output"] is not None for example in seq2seq_dataset
        ), "Make sure to load data with gold outputs"

    # Prepare references: collapse duplicate inputs so each ID maps to all of its gold outputs
    untokenized_dataset = drop_duplicates_in_input(seq2seq_dataset)
    id_to_labels = {instance["id"]: instance["outputs"] for instance in untokenized_dataset}

    # Prepare predictions: either a path to a JSON file or an in-memory dict
    if isinstance(predictions, str):
        with open(predictions) as f:
            id_to_pred = json.load(f)
    else:
        id_to_pred = predictions

    # Special handling for quality_hard to prevent a redundant-keys error:
    # drop predictions for "quality" examples that are not in the hard subset
    if dataset_name == "quality_hard":
        quality_dataset = load_dataset(path="tau/scrolls", name="quality", split=load_dataset_kwargs["split"])
        for id_ in quality_dataset["id"]:
            if id_ not in id_to_labels:
                id_to_pred.pop(id_, None)

    # Check for format errors
    errors, details = verify(id_to_pred, id_to_labels)

    out_file_path = get_metrics_filename(args.metrics_output_dir, dataset_name)
    os.makedirs(args.metrics_output_dir, exist_ok=True)

    if len(errors) == 0 and not verify_only:
        # Compute metrics
        metrics = scrolls_metric.compute(**scrolls_metric.convert_from_map_format(id_to_pred, id_to_labels))
        with open(out_file_path, mode="w") as f:
            json.dump(metrics, f, indent=4)
        if args.internal_call:
            return metrics
        else:
            print(json.dumps(metrics, indent=4))
    elif len(errors) > 0:
        # Output errors
        errors_msg = errors[0] if len(errors) == 1 else " ".join(f"{i}: {err}" for i, err in enumerate(errors))
        print(json.dumps(errors, indent=4))
        print(f"See details in: {out_file_path}")
        with open(out_file_path, mode="w") as f:
            json.dump({"errors": errors, "details": details}, f, indent=4)
        if raise_on_errors:
            raise ValueError(f"Failed to evaluate due to: {errors_msg}")
        exit(os.EX_DATAERR)


def download_metric():
    # Fetch the SCROLLS metric script from the dataset repo, copy it to a local
    # path under a modified name, and return that path for evaluate.load().
    scrolls_metric_path = hf_hub_download(repo_id="tau/scrolls", filename="metrics/scrolls.py", repo_type="dataset")
    updated_scrolls_metric_path = (
        os.path.dirname(scrolls_metric_path) + os.path.basename(scrolls_metric_path).replace(".", "_") + ".py"
    )
    shutil.copy(scrolls_metric_path, updated_scrolls_metric_path)
    return updated_scrolls_metric_path


# Copied from baselines/src/utils/duplicates.py
def drop_duplicates_in_input(untokenized_dataset):
    indices_to_keep = []
    id_to_idx = {}
    outputs = []
    for i, (id_, output) in enumerate(zip(untokenized_dataset["id"], untokenized_dataset["output"])):
        if id_ in id_to_idx:
            outputs[id_to_idx[id_]].append(output)
            continue
        indices_to_keep.append(i)
        id_to_idx[id_] = len(outputs)
        outputs.append([output])
    untokenized_dataset = untokenized_dataset.select(indices_to_keep).flatten_indices()
    untokenized_dataset = untokenized_dataset.remove_columns("output")
    untokenized_dataset = untokenized_dataset.add_column("outputs", outputs)
    return untokenized_dataset
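

# Illustrative example for drop_duplicates_in_input (hypothetical IDs): rows
#   {"id": "a", "output": "x"}, {"id": "a", "output": "y"}, {"id": "b", "output": "z"}
# collapse into
#   {"id": "a", "outputs": ["x", "y"]}, {"id": "b", "outputs": ["z"]}
# so each reference entry carries the full list of gold outputs for its ID.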


def get_metrics_filename(outdir, dataset_name):
    return os.path.join(outdir, f"{dataset_name}_metrics.json")


def verify(id_to_pred, id_to_labels):
    errors = []
    details = {"missing_keys": [], "redundant_keys": []}
    if not isinstance(id_to_pred, dict):
        errors.append(
            'The predictions must be saved as a JSON object: {"id1": "prediction1", "id2": "prediction2", ...}'
        )
    else:
        if not all(isinstance(key, str) for key in id_to_pred.keys()):
            errors.append("All keys of the predictions dictionary must be strings")
        if not all(isinstance(value, str) for value in id_to_pred.values()):
            errors.append("All values of the predictions dictionary must be strings")
        if len(errors) == 0:
            predictions_keys, reference_keys = set(id_to_pred.keys()), set(id_to_labels.keys())
            missing_keys = reference_keys - predictions_keys
            redundant_keys = predictions_keys - reference_keys
            if len(missing_keys) > 0:
                details["missing_keys"] = list(missing_keys)
                errors.append("There are missing example IDs.")
            else:
                del details["missing_keys"]
            if len(redundant_keys) > 0:
                details["redundant_keys"] = list(redundant_keys)
                errors.append("There are redundant example IDs.")
            else:
                del details["redundant_keys"]
    return errors, details
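

# Illustrative example for verify (hypothetical IDs): if the references contain IDs
# {"a", "b"} and the predictions contain {"a", "c"}, it returns
#   (["There are missing example IDs.", "There are redundant example IDs."],
#    {"missing_keys": ["b"], "redundant_keys": ["c"]})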


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate SCROLLS predictions per dataset")
    parser.add_argument(
        "--predictions", type=str, help="Path to the predictions file or the actual predictions", required=True
    )
    parser.add_argument("--dataset_name", type=str, help="Name of the dataset", choices=DATASETS, required=True)
    parser.add_argument("--metrics_output_dir", type=str, help="Directory of the output metrics file", required=True)
    parser.add_argument("--split", type=str, help="The split to evaluate on", default="test")
    parser.add_argument("--internal_call", type=str, help="For internal use", default=False)
    parser.add_argument(
        "--test_data_file", type=str, help="Path to the test file containing the answers", default=None
    )
    parser.add_argument(
        "--cache_dir", type=str, help="Cache dir for the dataset download", default=None, required=False
    )
    parser.add_argument(
        "--verify_only",
        action="store_true",
        help="Don't evaluate, just verify that the format and IDs are correct.",
        default=False,
    )

    args = parser.parse_args()
    main(args)