run_metasupervised.py
# Copyright 2023 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
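"""Train and evaluate a surrogate model on a meta-supervised task.

Configuration is handled by hydra: defaults are read from
``config/metasupervised.yaml`` and any field can be overridden on the
command line, e.g. (illustrative overrides; the field names are taken from
the config keys used below)::

    python run_metasupervised.py seed=0 evaluate_first=true

The instantiated surrogate is expected to expose ``set_dir``,
``set_metadata``, ``fit``, ``get_training_summary_metrics`` and
``cleanup``, all of which are called in this script.
"""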
import gc
import logging
import os
import time

import hydra
import torch
from neptune.integrations.python_logger import NeptuneHandler
from omegaconf import DictConfig, open_dict

from meta.evaluation import run_metasupervised_evaluation
from meta.tools.logger import get_logger_from_config
from meta.utils import (
    display_info,
    get_current_git_commit_hash,
    set_seeds,
)
logging.basicConfig(level="NOTSET", format="%(message)s", datefmt="[%X]")
log = logging.getLogger("rich")


@hydra.main(
    config_path="config", config_name="metasupervised", version_base="1.2"
)  # n.b. config_name can be overridden by the --config-name command line flag
def run_metasupervised(cfg: DictConfig) -> None:  # noqa: CCR001
display_info(cfg)
set_seeds(cfg.seed)
task = hydra.utils.instantiate(cfg.task)
log.info("Setting up task (splitting datasets etc.)")
    # load 0-shot data if using auxiliary zero-shot predictions
load_zero_shot = cfg.surrogate.model_config.aux_pred
task.setup_datasets(load_zero_shot=load_zero_shot)
commit_hash = get_current_git_commit_hash()
log.info(f"Commit hash: {commit_hash}")
with open_dict(cfg):
cfg.commit_hash = commit_hash
neptune_tags = [
str(cfg.task.task_name),
str(cfg.surrogate.name),
str(type(task)), # task.task_type,
commit_hash,
]
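    # tags are attached to the Neptune run so experiments can be filtered
    # by task, surrogate and commit in the Neptune UI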
logger = get_logger_from_config(cfg, file_system=None, neptune_tags=neptune_tags)
if cfg.logging.type == "neptune":
log.addHandler(NeptuneHandler(run=logger.run)) # type: ignore
log.info("Instantiating surrogate")
surrogate = hydra.utils.instantiate(cfg.surrogate, _recursive_=True)
run_name = "".join(cfg.logging.tags)
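    # outputs are written under results/<run_name>/<seed> next to this file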
surrogate.set_dir(
os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"results/",
run_name,
str(cfg.seed),
)
)
if task.has_metadata:
log.info("Providing surrogate with task metadata")
surrogate.set_metadata(task.metadata)
    # val_data should be determined automatically in the evaluation function,
# but to get it directly:
# val_data = task.data_splits.get("validation", None)
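    # an initial evaluation gives an untrained-baseline reference point for
    # the metrics logged after training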
if cfg.evaluate_first:
log.info("Evaluating surrogate at initialisation")
init_metrics, _ = run_metasupervised_evaluation(task, surrogate)
log.info(
"Init metrics: {}".format(
"\t".join(f"{k}: {v:.3f}" for k, v in init_metrics.items())
)
)
logger.write(init_metrics, label="init_metrics", timestep=0)
if cfg.exit_after_first_eval:
assert cfg.evaluate_first, "Must evaluate first to exit after first eval."
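        # release GPU memory explicitly so a following run (e.g. in a hydra
        # multirun sweep) starts from a clean slate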
del surrogate
del task
torch.cuda.empty_cache()
gc.collect()
log.info("Run metasupervised exiting after first evaluation.")
        return
    log.info(
        "Dataset sizes: "
        + " ".join(
            f"{split}: {len(task.data_splits[split])}"
            for split in ["train", "validation"]
            if split in task.data_splits
        )
    )
log.info("Training surrogate")
t0 = time.time()
    def eval_func(surrogate_in):
        # closure over `task` so the surrogate can evaluate itself mid-training
        return run_metasupervised_evaluation(task, surrogate_in)
surrogate.fit(
task.data_splits["train"], cfg.seed, logger=logger, eval_func=eval_func
)
if cfg.evaluate_end:
t1 = time.time()
log.info("Evaluating surrogate at end")
        train_end_metrics, _ = run_metasupervised_evaluation(task, surrogate)
train_end_metrics.update(surrogate.get_training_summary_metrics())
train_end_metrics.update(task.data_summary())
train_end_metrics.update({"train_time": t1 - t0})
log.info(
"Metrics: {}".format(
"\t".join(f"{k}: {v:.4f}" for k, v in train_end_metrics.items())
)
)
logger.write(train_end_metrics, label="end_metrics")
logger.close()
surrogate.cleanup()
del surrogate
del task
torch.cuda.empty_cache()
gc.collect()
log.info("Run metasupervised completed successfully")
if __name__ == "__main__":
run_metasupervised()