diff --git a/deepsurv/deep_surv.py b/deepsurv/deep_surv.py
index a57c23e..3801d5b 100644
--- a/deepsurv/deep_surv.py
+++ b/deepsurv/deep_surv.py
@@ -204,7 +204,13 @@ def _get_loss_updates(self,
             loss, self.params, **kwargs
         )
 
-        # Store last update function
+        # If the model was loaded from file, reload params
+        if self.restored_update_params:
+            for p, value in zip(updates.keys(), self.restored_update_params):
+                p.set_value(value)
+            self.restored_update_params = None
+
+        # Store last update function to be later saved
         self.updates = updates
 
         return loss, updates
@@ -363,9 +369,6 @@ def train(self,
         """
        # @TODO? Should these be managed by the logger => then you can do logger.getMetrics
 
-        # train_loss = []
-        # train_ci = []
-
        x_train, e_train, t_train = self.prepare_data(train_data)
 
        # Set Standardization layer offset and scale to training data mean and std
@@ -374,8 +377,6 @@
             self.scale = x_train.std(axis = 0)
 
         if valid_data:
-            # valid_loss = []
-            # valid_ci = []
             x_valid, e_valid, t_valid = self.prepare_data(valid_data)
 
         # Initialize Metrics
@@ -474,24 +475,16 @@ def save_model(self, filename, weights_file = None):
         if weights_file:
             self.save_weights(weights_file)
 
-    # # @TODO need to reimplement with it working
-    # # @TODO need to add save_model
-    # def load_model(self, params):
-    #     """
-    #     Loads the network's parameters from a previously saved state.
-
-    #     Parameters:
-    #         params: a list of parameters in same order as network.params
-    #     """
-    #     lasagne.layers.set_all_param_values(self.network, params, trainable=True)
-
     def save_weights(self,filename):
         def save_list_by_idx(group, lst):
             for (idx, param) in enumerate(lst):
                 group.create_dataset(str(idx), data=param)
 
         weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)
-        updates_out = [p.get_value() for p in self.updates.keys()]
+        if self.updates:
+            updates_out = [p.get_value() for p in self.updates.keys()]
+        else:
+            raise Exception("Model has not been trained: no params to save!")
 
         # Store all of the parameters in an hd5f file
         # We store the parameter under the index in the list
@@ -527,8 +520,7 @@ def sort_params_by_idx(params):
             trainable=False)
 
         sorted_updates_in = sort_params_by_idx(updates_in)
-        for p, value in zip(self.updates.keys(), sorted_updates_in):
-            p.set_value(value)
+        self.restored_update_params = sorted_updates_in
 
     def risk(self,deterministic = False):
         """
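A note on the indexed save/load scheme this patch relies on: HDF5 dataset names are strings, so a group holding parameters under keys '0', '1', ..., '10', ... iterates in lexicographic order ('10' before '2'), which is why load_weights must run the values through sort_params_by_idx before zipping them back onto the update variables. The snippet below is a minimal, self-contained sketch of that round-trip using only numpy and h5py; the file name and the test data are illustrative, not part of the patch.

    import h5py
    import numpy as np

    def save_list_by_idx(group, lst):
        # Mirror of the helper in save_weights: each array is stored
        # under its stringified list index as the dataset name.
        for idx, param in enumerate(lst):
            group.create_dataset(str(idx), data=param)

    def sort_params_by_idx(group):
        # Keys come back as strings in lexicographic order ('10' < '2'),
        # so sort on their integer value to recover the original order.
        return [group[str(idx)][...] for idx in sorted(int(k) for k in group.keys())]

    # Hypothetical stand-ins for the optimizer's update parameters.
    params = [np.full(i + 1, float(i), dtype='float32') for i in range(12)]

    with h5py.File('demo_weights.h5', 'w') as f:
        save_list_by_idx(f.create_group('updates'), params)

    with h5py.File('demo_weights.h5', 'r') as f:
        restored = sort_params_by_idx(f['updates'])

    assert all(np.array_equal(a, b) for a, b in zip(params, restored))

The deferred restore (self.restored_update_params) follows from the same lifecycle: the optimizer's shared variables do not exist until _get_loss_updates builds the update dictionary, so load-time code cannot call set_value on them directly; it can only stash the loaded values and let _get_loss_updates apply them on the next call.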