Fixed bug with setting updates
jaredleekatzman committed Mar 29, 2017
1 parent a0e4ede commit 69591f3
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions deepsurv/deep_surv.py
@@ -204,7 +204,13 @@ def _get_loss_updates(self,
            loss, self.params, **kwargs
        )

        # Store last update function
        # If the model was loaded from file, reload params
        if self.restored_update_params:
            for p, value in zip(updates.keys(), self.restored_update_params):
                p.set_value(value)
            self.restored_update_params = None

        # Store last update function to be later saved
        self.updates = updates

        return loss, updates
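
For readers skimming the diff, here is a minimal standalone sketch (not part of the commit) of the pattern this hunk introduces: the optimizer's shared variables only exist once the update rule has been called, so parameters restored from disk have to be copied in at that point rather than inside load_weights. The sketch assumes lasagne.updates.adam as the update rule and a plain function; the real code takes an update_fn argument and runs as a method on the model.

import lasagne

def get_loss_updates(network, loss, restored_update_params=None):
    # Build the update dictionary; its keys are the shared variables
    # (weights plus optimizer accumulators) that training will modify.
    params = lasagne.layers.get_all_params(network, trainable=True)
    updates = lasagne.updates.adam(loss, params, learning_rate=1e-4)

    # Only now do the optimizer's shared variables exist, so this is the
    # earliest point at which previously saved values can be restored.
    if restored_update_params:
        for p, value in zip(updates.keys(), restored_update_params):
            p.set_value(value)

    return loss, updates
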
@@ -363,9 +369,6 @@ def train(self,
"""

# @TODO? Should these be managed by the logger => then you can do logger.getMetrics
# train_loss = []
# train_ci = []

x_train, e_train, t_train = self.prepare_data(train_data)

# Set Standardization layer offset and scale to training data mean and std
@@ -374,8 +377,6 @@
        self.scale = x_train.std(axis = 0)

        if valid_data:
            # valid_loss = []
            # valid_ci = []
            x_valid, e_valid, t_valid = self.prepare_data(valid_data)

        # Initialize Metrics
@@ -474,24 +475,16 @@ def save_model(self, filename, weights_file = None):
        if weights_file:
            self.save_weights(weights_file)

    # # @TODO need to reimplement with it working
    # # @TODO need to add save_model
    # def load_model(self, params):
    #     """
    #     Loads the network's parameters from a previously saved state.

    #     Parameters:
    #         params: a list of parameters in same order as network.params
    #     """
    #     lasagne.layers.set_all_param_values(self.network, params, trainable=True)

    def save_weights(self,filename):
        def save_list_by_idx(group, lst):
            for (idx, param) in enumerate(lst):
                group.create_dataset(str(idx), data=param)

        weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)
        updates_out = [p.get_value() for p in self.updates.keys()]
        if self.updates:
            updates_out = [p.get_value() for p in self.updates.keys()]
        else:
            raise Exception("Model has not been trained: no params to save!")

        # Store all of the parameters in an hd5f file
        # We store the parameter under the index in the list
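
As a hedged illustration of the save path described above: each parameter list is written into an HDF5 group under its list index, which is what save_list_by_idx does. The group names ('weights', 'updates') and the standalone function form are assumptions for the example; the real method lives on the model class and its file layout may differ.

import h5py
import lasagne

def save_weights_sketch(network, updates, filename):
    # Write each array under its position in the list so the load side
    # can recover the original ordering from the dataset names.
    def save_list_by_idx(group, lst):
        for idx, param in enumerate(lst):
            group.create_dataset(str(idx), data=param)

    weights_out = lasagne.layers.get_all_param_values(network, trainable=False)
    if updates:
        updates_out = [p.get_value() for p in updates.keys()]
    else:
        raise Exception("Model has not been trained: no params to save!")

    with h5py.File(filename, 'w') as f:
        save_list_by_idx(f.create_group('weights'), weights_out)
        save_list_by_idx(f.create_group('updates'), updates_out)
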
@@ -527,8 +520,7 @@ def sort_params_by_idx(params):
            trainable=False)

        sorted_updates_in = sort_params_by_idx(updates_in)
        for p, value in zip(self.updates.keys(), sorted_updates_in):
            p.set_value(value)
        self.restored_update_params = sorted_updates_in
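
And the matching load side, as a hedged sketch: datasets are read back, re-ordered by their integer keys (the sort_params_by_idx shown here is a hypothetical reimplementation; the real helper is the one named in the hunk header above), and the optimizer values are only parked on restored_update_params because the update dictionary does not exist yet. Group names again assume the layout from the save sketch.

import h5py
import lasagne

def load_weights_sketch(model, filename):
    # Restore list order by sorting dataset names as integers.
    def sort_params_by_idx(params):
        return [params[key] for key in sorted(params.keys(), key=int)]

    with h5py.File(filename, 'r') as f:
        weights_in = {name: dset[...] for name, dset in f['weights'].items()}
        updates_in = {name: dset[...] for name, dset in f['updates'].items()}

    lasagne.layers.set_all_param_values(model.network,
                                        sort_params_by_idx(weights_in),
                                        trainable=False)

    # The update rule has not been built yet, so the optimizer state is
    # stashed here and applied later inside _get_loss_updates.
    model.restored_update_params = sort_params_by_idx(updates_in)
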

    def risk(self,deterministic = False):
        """