-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfinalize_scib.py
65 lines (55 loc) · 2.69 KB
/
finalize_scib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import shutil
from evaluator import recalculate_results
import numpy as np
from main import reset_random_seeds, load_data
import argparse
import os
import yaml
def run_scib(tmp, seed, adata):
    """Compute and save scib metrics for one seed run unless they already exist.

    Args:
        tmp: Directory containing the per-seed run directory.
        seed: Name of the run directory inside ``tmp``. It is also passed
            through ``int`` to re-seed the RNGs, so it must be an integer string
            whenever the metrics CSV is missing.
        adata: Dataset object forwarded unchanged to ``recalculate_results``.

    Side effects:
        Writes ``<tmp>/<seed>/evaluation_metrics.csv`` when
        ``<tmp>/<seed>/embedding.npz`` exists and the CSV does not yet.

    Raises:
        ValueError: if the CSV is missing and ``seed`` is not an integer string.
    """
    run_dir = os.path.join(tmp, seed)
    metrics_path = os.path.join(run_dir, "evaluation_metrics.csv")
    if os.path.exists(metrics_path):
        return  # already evaluated; nothing to do

    print(run_dir)
    reset_random_seeds(int(seed))

    # recalculate scib-metrics.
    embedding_path = os.path.join(run_dir, "embedding.npz")
    if not os.path.exists(embedding_path):
        return

    print(f"Calculate @ {run_dir}")
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # indexing it materialises the array, so it is safe to close right after.
    with np.load(embedding_path) as archive:
        emb = archive['arr_0']
    # NOTE(review): the meaning of the constant 12 is defined by
    # recalculate_results (not visible here) — confirm before changing it.
    results = recalculate_results(adata, emb, 12)
    results.to_csv(metrics_path, index=False)
def _iter_seed_runs(root, min_depth=3, max_depth=5):
    """Yield ``(parent_dir, seed)`` for every evaluable run directory under ``root``.

    A run directory:
    - sits ``min_depth``..``max_depth`` levels below ``root``, i.e. one of
      ``<model>/<param>/<seed>``, ``<model>/<param>/<p2>/<seed>`` or
      ``<model>/<param>/<p2>/<p3>/<seed>``;
    - has a name that parses as an integer;
    - contains ``embedding.npz``.

    Symlinked directories are followed, as ``os.listdir`` traversal would.
    """
    for dirpath, dirnames, filenames in os.walk(root, followlinks=True):
        rel = os.path.relpath(dirpath, root)
        depth = 0 if rel == os.curdir else len(rel.split(os.sep))
        dirnames.sort()  # deterministic traversal order
        if depth >= max_depth:
            dirnames.clear()  # nothing deeper can be a run directory
        if depth < min_depth or "embedding.npz" not in filenames:
            continue
        parent, seed = os.path.split(dirpath)
        try:
            int(seed)
        except ValueError:
            continue  # not a seed directory
        yield parent, seed


def main():
    """Re-evaluate every run under ``--dname_root`` with scib metrics.

    Loads the dataset once (with 4000 HVGs and no holdout batch). It then calls
    ``run_scib`` for each run directory found by ``_iter_seed_runs``.
    """
    parser = argparse.ArgumentParser(description='Re-Evaluate with SCIB')
    parser.add_argument('--dname_root', default='./', type=str,
                        help='Root of the results tree to re-evaluate.')
    parser.add_argument('--dataset', default='immune', type=str,
                        help='Dataset config name (reads conf/data/<dataset>.yaml).')
    parser.add_argument('--project_directory', default='immune', type=str,
                        help='Project directory containing the conf/ folder.')
    args = parser.parse_args()

    conf_dir = os.path.join(args.project_directory, "conf")
    with open(os.path.join(conf_dir, "data", f"{args.dataset}.yaml"), encoding="utf-8") as stream:
        cfg_data = yaml.safe_load(stream)
    with open(os.path.join(conf_dir, "augmentation", "base.yaml"), encoding="utf-8") as stream:
        cfg_aug = yaml.safe_load(stream)

    cfg = {'data': cfg_data, 'augmentation': cfg_aug}
    cfg['data']['n_hvgs'] = 4000
    cfg['data']['holdout_batch'] = None
    _, _, adata = load_data(cfg)

    for parent, seed in _iter_seed_runs(args.dname_root):
        run_scib(parent, seed, adata)


if __name__ == "__main__":
    main()