From bfada0214f8bc4cf23689f761cda9479de79bf71 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 19 Jul 2016 16:27:46 +0000 Subject: [PATCH] remove debugging stuff --- nn_with_modes.py | 12 +++++++++++- pipeline.py | 1 - plotting.py | 2 +- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/nn_with_modes.py b/nn_with_modes.py index 667098b..c963ff3 100644 --- a/nn_with_modes.py +++ b/nn_with_modes.py @@ -5,6 +5,7 @@ import numpy as np import matplotlib import matplotlib.pyplot as plt +import os def train(data, mode): ''' @@ -57,6 +58,12 @@ def train(data, mode): combined_rnn.add(Dense(1)) combined_rnn.compile('adam', 'mae') + try: + weights_path = os.path.join('weights', 'combinedrnn-progress.h5') + combined_rnn.load_weights(weights_path) + except IOError: + print 'Pre-trained weights not found' + print 'Training:' try: combined_rnn.fit([X_jet_train, X_photon_train], @@ -66,7 +73,7 @@ def train(data, mode): }, callbacks = [ EarlyStopping(verbose=True, patience=10, monitor='val_loss'), - ModelCheckpoint('./models/combinedrnn-progress', + ModelCheckpoint(weights_path, monitor='val_loss', verbose=True, save_best_only=True) ], nb_epoch=30, validation_split = 0.2) @@ -74,6 +81,9 @@ def train(data, mode): except KeyboardInterrupt: print 'Training ended early.' + # -- load best weights back into the net + combined_rnn.load_weights(weights_path) + return combined_rnn def test(net, data): diff --git a/pipeline.py b/pipeline.py index 93bad46..ed191d2 100644 --- a/pipeline.py +++ b/pipeline.py @@ -107,7 +107,6 @@ def sha(s): # # evaluate performance on the test set yhat = test(net, data) - print yhat # # -- plot performance by mode if mode == 'regression': plot_regression(yhat, data) diff --git a/plotting.py b/plotting.py index a5ce87d..5e7cad0 100644 --- a/plotting.py +++ b/plotting.py @@ -158,7 +158,7 @@ def _plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues): # in each class) cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] _plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix') - plt.savefig('confusion2.pdf') + plt.savefig('confusion.pdf') def plot_regression(yhat, data):