Skip to content

Commit

Permalink
remove debugging stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Jul 19, 2016
1 parent a1515a8 commit bfada02
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
12 changes: 11 additions & 1 deletion nn_with_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import os

def train(data, mode):
'''
Expand Down Expand Up @@ -57,6 +58,12 @@ def train(data, mode):
combined_rnn.add(Dense(1))
combined_rnn.compile('adam', 'mae')

try:
weights_path = os.path.join('weights', 'combinedrnn-progress.h5')
combined_rnn.load_weights(weights_path)
except IOError:
print 'Pre-trained weights not found'

print 'Training:'
try:
combined_rnn.fit([X_jet_train, X_photon_train],
Expand All @@ -66,14 +73,17 @@ def train(data, mode):
},
callbacks = [
EarlyStopping(verbose=True, patience=10, monitor='val_loss'),
ModelCheckpoint('./models/combinedrnn-progress',
ModelCheckpoint(weights_path,
monitor='val_loss', verbose=True, save_best_only=True)
],
nb_epoch=30, validation_split = 0.2)

except KeyboardInterrupt:
print 'Training ended early.'

# -- load best weights back into the net
combined_rnn.load_weights(weights_path)

return combined_rnn

def test(net, data):
Expand Down
1 change: 0 additions & 1 deletion pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def sha(s):
# # evaluate performance on the test set
yhat = test(net, data)

print yhat
# # -- plot performance by mode
if mode == 'regression':
plot_regression(yhat, data)
Expand Down
2 changes: 1 addition & 1 deletion plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
_plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
plt.savefig('confusion2.pdf')
plt.savefig('confusion.pdf')


def plot_regression(yhat, data):
Expand Down

0 comments on commit bfada02

Please sign in to comment.