Skip to content

Commit

Permalink
Merge pull request #35 from N3PDF/fix_undo_flatten_bug
Browse files Browse the repository at this point in the history
fix undo_flatten inerting weights bug
  • Loading branch information
RoyStegeman authored Jan 4, 2021
2 parents 6195544 + b95e985 commit 3572879
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/evolutionary_keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tensorflow.keras.callbacks import History
from tensorflow.keras.models import Model
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.keras import callbacks as callbacks_module

import evolutionary_keras.optimizers as Evolutionary_Optimizers

Expand Down
10 changes: 4 additions & 6 deletions src/evolutionary_keras/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,7 @@ def flatten(self):
# The first values of 'self.length_flat_layer' is set to 0 which is helpful in determining
# the range of weights in the function 'undo_flatten'.
flattened_weights = []
self.length_flat_layer = []
self.length_flat_layer.append(0)
self.length_flat_layer = [0]
for weight in self.model.trainable_weights:
a = np.reshape(compatibility_numpy(weight), [-1])
flattened_weights.append(a)
Expand All @@ -353,10 +352,9 @@ def undo_flatten(self, flattened_weights):
"""
new_weights = []
for i, layer_shape in enumerate(self.shape):
flat_layer = flattened_weights[
self.length_flat_layer[i] : self.length_flat_layer[i]
+ self.length_flat_layer[i + 1]
]
start_index = sum(self.length_flat_layer[: i + 1])
end_index = start_index + self.length_flat_layer[i + 1]
flat_layer = flattened_weights[start_index:end_index]
new_weights.append(np.reshape(flat_layer, layer_shape))

ordered_names = [weight.name for layer in self.model.layers for weight in layer.weights]
Expand Down

0 comments on commit 3572879

Please sign in to comment.