-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpipeline.py
143 lines (128 loc) · 5.15 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import json
from data_processing import read_in, shuffle_split_scale
import pandautils as pup
import cPickle
from plotting import plot_inputs
import utils
import logging
#from plotting import plot_inputs, plot_performance
#from nn_model import train, test
# Keys stored in (and required from) the processed-data pickle cache.
_CACHE_KEYS = (
    'X_jets_train', 'X_jets_test',
    'X_photons_train', 'X_photons_test',
    'X_muons_train', 'X_muons_test',
    'y_train', 'y_test',
    'w_train', 'w_test',
    'varlist',
)


def _cache_id(obj):
    '''Return a short identifier: the first 5 hex digits of the MD5 of repr(obj).'''
    from hashlib import md5
    return md5(repr(obj).encode('utf-8')).hexdigest()[:5]


def _load_cache(path, logger):
    '''
    Return the processed-data dictionary pickled at `path`, or None if it cannot be used.

    A missing file, an unreadable (e.g. truncated) pickle, or a pickle lacking any of
    _CACHE_KEYS all count as "not usable", so the caller falls back to reprocessing
    instead of crashing on every subsequent run.
    '''
    try:
        with open(path, 'rb') as cache_file:
            data = cPickle.load(cache_file)
    except IOError:
        logger.info('Preprocessed data not found')
        return None
    except (EOFError, cPickle.UnpicklingError) as err:
        # e.g. a previous run was interrupted while writing the pickle
        logger.warning('Preprocessed data in %s could not be read (%s)', path, err)
        return None
    if not isinstance(data, dict) or any(key not in data for key in _CACHE_KEYS):
        logger.warning('Preprocessed data in %s is incomplete', path)
        return None
    logger.info('Preprocessed data found in pickle')
    return data


def _process(class_files_dict, exclude_vars):
    '''Read the ROOT files and shuffle/split/scale them into a dict keyed by _CACHE_KEYS.'''
    # -- transform ROOT files into standard ML format (ndarrays)
    X_jets, X_photons, X_muons, y, w, varlist = read_in(class_files_dict, exclude_vars)
    # -- shuffle, split samples into train and test set, scale features
    (X_jets_train, X_jets_test,
     X_photons_train, X_photons_test,
     X_muons_train, X_muons_test,
     y_train, y_test,
     w_train, w_test) = shuffle_split_scale(X_jets, X_photons, X_muons, y, w)
    return {
        'X_jets_train': X_jets_train,
        'X_jets_test': X_jets_test,
        'X_photons_train': X_photons_train,
        'X_photons_test': X_photons_test,
        'X_muons_train': X_muons_train,
        'X_muons_test': X_muons_test,
        'y_train': y_train,
        'y_test': y_test,
        'w_train': w_train,
        'w_test': w_test,
        'varlist': varlist
    }


def main(json_config, exclude_vars):
    '''
    Load (or build and cache) the train/test datasets, then plot the input distributions.

    Args:
    -----
        json_config: path to a JSON file containing a dictionary that links the names of the
                     different classes in the classification problem to the paths of the ROOT
                     files associated with each class; for example:
                     {
                         "ttbar" :
                         [
                             "/path/to/file1.root",
                             "/path/to/file2.root"
                         ],
                         "qcd" :
                         [
                             "/path/to/file3.root",
                             "/path/to/file4.root"
                         ],
                         ...
                     }
        exclude_vars: list of strings of names of branches not to be used for training

    Saves:
    ------
        'processed_data_<id>.pkl': pickled dictionary with the processed ndarrays
            (X_jets, X_photons, X_muons, y, w -- each split into train and test) and the
            variable list. <id> is derived from both the class config and exclude_vars,
            so changing either one triggers a fresh processing run instead of reusing
            data built with different settings.
    '''
    logger = logging.getLogger('Main')

    # -- load in the JSON file
    logger.info('Loading JSON config')
    with open(json_config) as config_file:
        class_files_dict = json.load(config_file)

    # -- the cache name must depend on the excluded branches as well as on the config:
    #    the same config with a different exclusion list yields a different varlist/data
    cache_path = 'processed_data_' + _cache_id((class_files_dict, exclude_vars)) + '.pkl'

    # -- if a usable pickle exists, use it; otherwise process the new data and cache it
    data = _load_cache(cache_path, logger)
    if data is None:
        logger.info('Processing data')
        data = _process(class_files_dict, exclude_vars)
        logger.info('Saving processed data to pickle')
        with open(cache_path, 'wb') as cache_file:
            cPickle.dump(data, cache_file, protocol=cPickle.HIGHEST_PROTOCOL)

    # -- plot distributions.
    #    Intended behaviour of plot_inputs (not verifiable from this file): normed, weighted
    #    histograms of the input distributions for all variables, with train and test shown
    #    for every class, saved out as pdf files with informative names.
    logger.info('Plotting input distributions')
    plot_inputs(
        data['X_jets_train'], data['X_jets_test'],
        data['X_photons_train'], data['X_photons_test'],
        data['X_muons_train'], data['X_muons_test'],
        data['y_train'], data['y_test'],
        data['w_train'], data['w_test'],
        data['varlist']
    )

    # TODO -- train:
    #   design a Keras NN with three RNN streams (jets, photons, muons),
    #   combine the outputs and process them through a bunch of FF layers,
    #   use a validation split of 20%,
    #   save out the weights to hdf5 and the model to yaml
    # net = train(data['X_jets_train'], data['X_photons_train'], data['X_muons_train'],
    #             data['y_train'], data['w_train'])
    # TODO -- test: evaluate performance on the test set
    # yhat = test(net, data['X_jets_test'], data['X_photons_test'], data['X_muons_test'],
    #             data['y_test'], data['w_test'])
    # TODO -- plot performance: produce ROC curves and save them out to pdf
    # plot_performance(yhat, data['y_test'], data['w_test'])
if __name__ == '__main__':
    import argparse
    import sys

    utils.configure_logging()

    # -- command-line interface: one positional JSON config plus an optional,
    #    possibly empty, list of branch names to leave out of training
    cli = argparse.ArgumentParser()
    cli.add_argument(
        'config',
        help="JSON file that specifies classes and corresponding ROOT files' paths",
    )
    cli.add_argument(
        '--exclude',
        nargs="*",
        default=[],
        help="names of branches to exclude from training",
    )
    options = cli.parse_args()

    # -- run the pipeline; main's return value becomes the process exit status
    status = main(options.config, options.exclude)
    sys.exit(status)