import logging
import cPickle
from data_processing import read_in, shuffle_split_scale, padding
# plot_inputs is called below; it is assumed to live in plotting, alongside plot_performance
from plotting import plot_inputs, plot_performance
from nets import nn_with_modes
import utils


def main(json_config, tree_name, model_name, mode):
    '''
    Args:
    -----
        json_config: path to a JSON file containing a dictionary that maps the name of each
                     class in the classification problem to the paths of the ROOT files
                     associated with that class; for example:
                     {
                         "ttbar" :
                         [
                             "/path/to/file1.root",
                             "/path/to/file2.root",
                         ],
                         "qcd" :
                         [
                             "/path/to/file3.root",
                             "/path/to/file4.root",
                         ],
                         ...
                     }
        tree_name: string, name of the tree that contains the branches to read
        model_name: string, identifier used to name the trained network and its saved outputs
        mode: string, either 'classification' or 'regression'
    Saves:
    ------
        'processed_data_<hash>_<mode>.pkl': dictionary with processed ndarrays (X, y, w) for all particles, split into training and test sets
    '''
logger = logging.getLogger('Main')
# -- load in the JSON file
logger.info('Loading information from ' + json_config)
config = utils.load_config(json_config) # check config has expected structure
class_files_dict = config['classes']
particles_dict = config['particles']
    # -- hash the config dictionary to check if the pickled data exists
    from hashlib import md5

    def sha(s):
        '''Get a short, deterministic identifier for an object'''
        m = md5()
        m.update(repr(s))
        return m.hexdigest()[:5]
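    # NOTE: the digest is truncated to 5 hex chars, enough to tell configs
    # apart in practice, though not strictly collision-proof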
    # -- if a pickle produced from the same config and mode exists, use it
    pickle_name = 'processed_data_' + sha(config) + '_' + mode + '.pkl'
    try:
        logger.info('Attempting to read from {}'.format(pickle_name))
        with open(pickle_name, 'rb') as f:
            data = cPickle.load(f)
        logger.info('Pre-processed data found and loaded from pickle')
    # -- otherwise, process the data from scratch
    except IOError:
        logger.info('Pre-processed data not found in {}'.format(pickle_name))
        logger.info('Processing data')
# -- transform ROOT files into standard ML format (ndarrays)
X, y, w, le = read_in(class_files_dict, tree_name, particles_dict, mode)
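        # X holds the per-particle feature arrays, y the targets, w the event
        # weights; le is the fitted label encoder stored below as 'LabelEncoder'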
# -- shuffle, split samples into train and test set, scale features
data = shuffle_split_scale(X, y, w)
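        # data maps keys such as X_<particle>_train / X_<particle>_test
        # (see the padding step below) to the shuffled, scaled ndarrays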
data.update({
'varlist' : [
branch
for particle_info in particles_dict.values()
for branch in particle_info['branches']
],
'LabelEncoder' : le
})
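        # 'varlist' flattens the branch names of every particle type into one
        # list; 'LabelEncoder' is kept so integer labels can be mapped back to
        # class names downstream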
        # -- plot distributions:
        # this should produce normed, weighted histograms of the input
        # distributions for all variables, with the train and test
        # distributions shown for every class, saved out to pdfs
        # with informative names
logger.info('Saving input distributions in ./plots/')
plot_inputs(data, particles_dict)
        logger.info('Padding')
        for key in data:
            # no padding for the `event` matrix; assumes the naming
            # convention X_<particle>_train, X_<particle>_test
            if key.startswith('X_') and ('event' not in key):
                data[key] = padding(
                    data[key], particles_dict[key.split('_')[1]]['max_length'])
        # -- save out to pickle
        logger.info('Saving processed data to {}'.format(pickle_name))
        with open(pickle_name, 'wb') as f:
            cPickle.dump(data, f, protocol=cPickle.HIGHEST_PROTOCOL)
    # -- train:
    # design a Keras NN with three RNN streams (jets, photons, muons),
    # combine the outputs and process them through a stack of
    # feed-forward layers, with a validation split of 20%;
    # save out the weights to hdf5 and the model to json
net = nn_with_modes.train(data, model_name, mode)
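    # -- evaluate the trained network on the held-out test set;
    # yhat holds its predictions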
yhat = nn_with_modes.test(net, data, model_name)
# -- plot performance by mode
plot_performance(yhat, data, model_name, mode)
# --------------------------------------------------------------
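# example invocation (file and model names are illustrative):
#     python pipeline.py config.json my_model classification --tree CollectionTree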
if __name__ == '__main__':
import sys
import argparse
utils.configure_logging()
# -- read in arguments
parser = argparse.ArgumentParser()
parser.add_argument('config', help="path to JSON file that specifies classes and corresponding ROOT files' paths")
    parser.add_argument('model_name', help="name used to identify this particular network and its saved outputs")
parser.add_argument('mode', help="classification or regression")
parser.add_argument('--tree', help="name of the tree to open in the ntuples", default='CollectionTree')
args = parser.parse_args()
    if args.mode not in ('classification', 'regression'):
        raise ValueError("Mode must be 'classification' or 'regression'")
# -- pass arguments to main
sys.exit(main(args.config, args.tree, args.model_name, args.mode))