Add saving and loading models
jaredleekatzman committed Mar 29, 2017
1 parent cf1aeda commit a0e4ede
Showing 2 changed files with 113 additions and 20 deletions.
125 changes: 108 additions & 17 deletions deepsurv/deep_surv.py
@@ -3,6 +3,8 @@
import lasagne
import numpy
import time
+import json
+import h5py

import theano
import theano.tensor as T
@@ -16,7 +18,7 @@ def __init__(self, n_in,
learning_rate, hidden_layers_sizes = None,
lr_decay = 0.0, momentum = 0.9,
L2_reg = 0.0, L1_reg = 0.0,
-                 activation = lasagne.nonlinearities.rectify,
+                 activation = "rectify",
dropout = None,
batch_norm = False,
standardize = False,
@@ -59,9 +61,14 @@ def __init__(self, n_in,
shared_axes = 0)
self.standardize = standardize

+        if activation == 'rectify':
+            activation_fn = lasagne.nonlinearities.rectify
+        else:
+            raise ValueError("Unknown activation function: %s" % activation)

# Construct Neural Network
for n_layer in (hidden_layers_sizes or []):
-            if activation == lasagne.nonlinearities.rectify:
+            if activation_fn == lasagne.nonlinearities.rectify:
W_init = lasagne.init.GlorotUniform()
else:
# TODO: implement other initializations
@@ -70,7 +77,7 @@

network = lasagne.layers.DenseLayer(
network, num_units = n_layer,
-                nonlinearity = activation,
+                nonlinearity = activation_fn,
W = W_init
)
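A side note on this change: passing the name "rectify" instead of the Lasagne function object is what makes the hyper-parameters JSON-serializable later in this commit. A minimal sketch of how further activations could be supported, assuming a lookup table in place of the if/else chain added above; only 'rectify' is handled by this commit, the other entries are illustrative:

    # Hypothetical extension (not part of this commit): resolve activation
    # names through a lookup table instead of an if/else chain.
    ACTIVATIONS = {
        'rectify': lasagne.nonlinearities.rectify,
        'tanh': lasagne.nonlinearities.tanh,        # illustrative
        'sigmoid': lasagne.nonlinearities.sigmoid,  # illustrative
    }
    try:
        activation_fn = ACTIVATIONS[activation]
    except KeyError:
        raise ValueError("Unknown activation function: %s" % activation)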

Expand All @@ -95,7 +102,21 @@ def __init__(self, n_in,
# Relevant Functions
self.partial_hazard = T.exp(self.risk(deterministic = True)) # e^h(x)

-        # Set Hyper-parameters:
+        # Store and set needed Hyper-parameters:
+        self.hyperparams = {
+            'n_in': n_in,
+            'learning_rate': learning_rate,
+            'hidden_layers_sizes': hidden_layers_sizes,
+            'lr_decay': lr_decay,
+            'momentum': momentum,
+            'L2_reg': L2_reg,
+            'L1_reg': L1_reg,
+            'activation': activation,
+            'dropout': dropout,
+            'batch_norm': batch_norm,
+            'standardize': standardize
+        }

self.n_in = n_in
self.learning_rate = learning_rate
self.lr_decay = lr_decay
@@ -183,6 +204,9 @@ def _get_loss_updates(self,
loss, self.params, **kwargs
)

+        # Store the last updates dictionary so that save_weights and
+        # load_weights can snapshot and restore the optimizer's shared variables
+        self.updates = updates

return loss, updates

def _get_train_valid_fn(self,
@@ -373,9 +397,6 @@ def train(self,

start = time.time()
for epoch in range(n_epochs):
-            if logger and (epoch % validation_frequency == 0):
-                logger.print_progress_bar(epoch, n_epochs)

# Power-Learning Rate Decay
lr = self.learning_rate / (1 + epoch * self.lr_decay)

@@ -415,6 +436,9 @@
# best_params_idx = epoch
best_validation_loss = validation_loss

+            if logger and (epoch % validation_frequency == 0):
+                logger.print_progress_bar(epoch, n_epochs, loss)

if patience <= epoch:
break

@@ -440,16 +464,71 @@ def train(self,

return logger.history

-    # @TODO need to reimplement with it working
-    # @TODO need to add save_model
-    def load_model(self, params):
-        """
-        Loads the network's parameters from a previously saved state.
-        Parameters:
-            params: a list of parameters in same order as network.params
-        """
-        lasagne.layers.set_all_param_values(self.network, params, trainable=True)
+    def to_json(self):
+        return json.dumps(self.hyperparams)
+
+    def save_model(self, filename, weights_file = None):
+        with open(filename, 'w') as fp:
+            fp.write(self.to_json())
+
+        if weights_file:
+            self.save_weights(weights_file)

+    # # @TODO need to reimplement with it working
+    # # @TODO need to add save_model
+    # def load_model(self, params):
+    #     """
+    #     Loads the network's parameters from a previously saved state.
+
+    #     Parameters:
+    #         params: a list of parameters in same order as network.params
+    #     """
+    #     lasagne.layers.set_all_param_values(self.network, params, trainable=True)

+    def save_weights(self, filename):
+        def save_list_by_idx(group, lst):
+            for (idx, param) in enumerate(lst):
+                group.create_dataset(str(idx), data=param)
+
+        weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)
+        updates_out = [p.get_value() for p in self.updates.keys()]
+
+        # Store all of the parameters in an HDF5 file.
+        # Each parameter is stored under its index in the list so that,
+        # when it is read back later, the list of parameters can be
+        # reconstructed in the same order in which it was saved
+        with h5py.File(filename, 'w') as f_out:
+            weights_grp = f_out.create_group('weights')
+            save_list_by_idx(weights_grp, weights_out)
+
+            updates_grp = f_out.create_group('updates')
+            save_list_by_idx(updates_grp, updates_out)

+    def load_weights(self, filename):
+        def load_all_keys(fp):
+            results = []
+            for key in fp:
+                dataset = fp[key][:]
+                results.append((int(key), dataset))
+            return results
+
+        def sort_params_by_idx(params):
+            return [param for (idx, param) in sorted(params,
+                                                     key=lambda param: param[0])]
+
+        # Load all of the parameters
+        with h5py.File(filename, 'r') as f_in:
+            weights_in = load_all_keys(f_in['weights'])
+            updates_in = load_all_keys(f_in['updates'])
+
+        # Sort them according to the idx to ensure they are set correctly
+        sorted_weights_in = sort_params_by_idx(weights_in)
+        lasagne.layers.set_all_param_values(self.network, sorted_weights_in,
+                                            trainable=False)
+
+        sorted_updates_in = sort_params_by_idx(updates_in)
+        for p, value in zip(self.updates.keys(), sorted_updates_in):
+            p.set_value(value)

def risk(self, deterministic = False):
"""
@@ -554,3 +633,15 @@ def plot_risk_surface(self, data, i = 0, j = 1,
plt.ylabel('$x_{%d}$' % j, fontsize=18)

return fig

+def load_model_from_json(model_fp, weights_fp = None):
+    with open(model_fp, 'r') as fp:
+        json_model = fp.read()
+    hyperparams = json.loads(json_model)
+
+    model = DeepSurv(**hyperparams)
+
+    if weights_fp:
+        model.load_weights(weights_fp)
+
+    return model
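Taken together, the new persistence API has three entry points: save_model writes the hyper-parameters as JSON (optionally delegating the weights to save_weights), save_weights snapshots the network parameters plus the optimizer's shared variables (the keys of the Lasagne updates dictionary, which for momentum-style updates include the trainable parameters and their velocity accumulators), and the module-level load_model_from_json rebuilds the model. A minimal round-trip sketch, assuming a trained model (save_weights and load_weights both read self.updates, which only exists once training has run) and illustrative file names:

    # Minimal save/load round trip (file names are illustrative).
    model = DeepSurv(n_in = 10, learning_rate = 1e-5, hidden_layers_sizes = [25, 25])
    # ... train the model here so that model.updates exists ...

    model.save_model('deepsurv.json', weights_file = 'deepsurv.h5')

    # Later: rebuild the architecture from the saved hyper-parameters,
    # then restore the weights and the optimizer state.
    model = load_model_from_json('deepsurv.json', weights_fp = 'deepsurv.h5')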
8 changes: 5 additions & 3 deletions deepsurv/deepsurv_logger.py
@@ -9,11 +9,13 @@ def __init__(self):
def logMessage(self, message):
self.logger.info(message)

-    def print_progress_bar(self, step, max_steps, bar_length = 50, char = '*'):
+    def print_progress_bar(self, step, max_steps, loss = None, bar_length = 25, char = '*'):
progress_length = int(bar_length * step / max_steps)
progress_bar = [char] * (progress_length) + [' '] * (bar_length - progress_length)
self.logger.info("Training step %d/%d |" % (step, max_steps)
+ ''.join(progress_bar) + "|")
message = "Training step %d/%d |" % (step, max_steps) + ''.join(progress_bar) + "|"
if loss:
message += " - loss: %.4f" % loss
self.logger.info(message)


class TensorboardLogger(DeepSurvLogger):

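For reference, a sketch of how the updated progress bar is driven, mirroring the call added to train above; the logger class name, epoch counts, and loss values are illustrative:

    # Illustrative driver for the updated progress bar.
    logger = DeepSurvLogger()
    n_epochs, validation_frequency = 500, 50
    for epoch in range(n_epochs):
        loss = 1.0 / (epoch + 1)   # stand-in for the training loss
        if epoch % validation_frequency == 0:
            logger.print_progress_bar(epoch, n_epochs, loss)

    # Example output line:
    # Training step 50/500 |**                       | - loss: 0.0196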