Skip to content

Commit

Permalink
Update varlist creation and passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mickypaganini committed Jun 24, 2016
1 parent 06e6101 commit 66b4958
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
16 changes: 10 additions & 6 deletions data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def build_X(events, phrase):
Returns:
output_array: a numpy array containing data only pertaining to the related branches
'''
sliced_events = events[[key for key in events.keys() if key.startswith(phrase)]].as_matrix()
return sliced_events
branch_names = [key for key in events.keys() if key.startswith(phrase)]
sliced_events = events[branch_names].as_matrix()
return sliced_events, branch_names

def read_in(class_files_dict):
'''
Expand Down Expand Up @@ -70,6 +71,9 @@ def read_in(class_files_dict):
X_muons: ndarray [n_ev, n_muon_feat] containing muon related branches
y: ndarray [n_ev, 1] containing the truth labels
w: ndarray [n_ev, 1] containing EventWeights
jet_branches + photon_branches + muon_branches: list of strings, the concatenation of the
branch-name lists selected for each particle type (jets first, then photons, then muons), e.g.:
['Jet_Px', 'Jet_E', 'Photon_Px', 'Muon_ID']
'''

#convert files to pd data frames, assign key to y, concat all files
Expand All @@ -83,14 +87,14 @@ def read_in(class_files_dict):
all_events = pd.concat([all_events, df], ignore_index=True)

#slice related branches
X_jets = build_X(all_events, 'Jet')
X_photons = build_X(all_events, 'Photon')
X_muons = build_X(all_events, 'Muon')
X_jets, jet_branches = build_X(all_events, 'Jet')
X_photons, photon_branches = build_X(all_events, 'Photon')
X_muons, muon_branches = build_X(all_events, 'Muon')

#transform string labels to integer classes
le = LabelEncoder()
y = le.fit_transform(all_events['y'].values)

w = all_events['EventWeight'].values

return X_jets, X_photons, X_muons, y, w
return X_jets, X_photons, X_muons, y, w, jet_branches + photon_branches + muon_branches
7 changes: 3 additions & 4 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,14 @@ def main(json_config):
class_files_dict = json.load(open(json_config))

# -- transform ROOT files into standard ML format (ndarrays)
X_jets, X_photons, X_muons, y, w = read_in(class_files_dict)
X_jets, X_photons, X_muons, y, w, varlist = read_in(class_files_dict)

# -- shuffle, split samples into train and test set, scale features
X_jets_train, X_jets_test, \
X_photons_train, X_photons_test, \
X_muons_train, X_muons_test, \
y_train, y_test, \
w_train, w_test, \
variables = shuffle_split_scale(X_jets, X_photons, X_muons, y, w)
w_train, w_test = shuffle_split_scale(X_jets, X_photons, X_muons, y, w)

# -- plot distributions:
# this should produce weighted histograms of the input distributions for all variables
Expand All @@ -50,7 +49,7 @@ def main(json_config):
X_muons_train, X_muons_test,
y_train, y_test,
w_train, w_test,
variables
varlist
)

# -- train
Expand Down

0 comments on commit 66b4958

Please sign in to comment.