-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_preprocessing.py
49 lines (40 loc) · 1.36 KB
/
train_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
Preprocessing
"""
from glob import glob
import os
import pickle
import sys
import yaml
from tqdm import tqdm
import numpy as np
from lib.dataset.io import read_index_binary_file_64bits
from lib.dataset.io import write_binary_encoded_smiles
from lib.dataset.utils import build_vocabulary, save_vocabulary
if __name__ == "__main__":
    # Usage: python train_preprocessing.py <config.yaml>
    # Builds the SMILES vocabulary, a dense id->file-location index, and a
    # binary cache of encoded SMILES, all at paths taken from the config.

    # Load hyperparameters; `with` ensures the config handle is closed
    # (the original leaked it). FullLoader kept to preserve behavior for
    # configs using non-safe YAML tags.
    with open(sys.argv[1]) as config_file:
        hparams = yaml.load(config_file, Loader=yaml.FullLoader)

    # Pickled SMILES: an iterable of 3-tuples whose last element is the
    # standardized SMILES string (first two fields are unused here).
    smiles_path = hparams["smiles_path"]
    with open(smiles_path, "rb") as smiles_file:
        smiles = pickle.load(smiles_file)
    std_smiles = [std_smi for _, _, std_smi in smiles]

    # Build the token vocabulary from standardized SMILES and persist it.
    vocabulary, tokenizer = build_vocabulary(std_smiles)
    save_vocabulary(hparams["vocabulary"], vocabulary)

    # Merge every per-shard index file into one mapping:
    # smiles_id -> (file name, position).
    index = {}
    for fname in tqdm(glob(os.path.join(hparams["pairs_path"], "*.index.dat")), ascii=True):
        index.update(read_index_binary_file_64bits(fname))

    # Densify into an int64 array indexed by smiles_id. Each row holds the
    # three underscore-separated integer fields of the file name plus the
    # position; ids absent from `index` keep the -1 sentinel fill.
    index_np = -np.ones((max(index.keys()) + 1, 4), dtype=np.int64)
    for k, (fname, pos) in index.items():
        index_np[int(k)] = [int(x) for x in fname.split("_")] + [pos]
    np.save(hparams["smiles_id_to_data"], index_np)

    # Create cache for encoded SMILES.
    # NOTE: takes up to ~1.5 hours on the full dataset.
    write_binary_encoded_smiles(
        smiles,
        hparams["encoded_smiles"],
        hparams["smiles_id_to_encoded_smiles"],
        tokenizer,
        vocabulary,
        hparams["max_sequence_length"],
        use_pbar=True,
    )