diff --git a/bulbea/learn/models/ann.py b/bulbea/learn/models/ann.py index cc881c3a..939f7ceb 100644 --- a/bulbea/learn/models/ann.py +++ b/bulbea/learn/models/ann.py @@ -4,7 +4,6 @@ from keras.models import Sequential from keras.layers import recurrent from keras.layers import core - from bulbea.learn.models import Supervised class ANN(Supervised): @@ -24,16 +23,14 @@ def __init__(self, sizes, optimizer = 'rmsprop'): self.model = Sequential() self.model.add(cell( - input_dim = sizes[0], - output_dim = sizes[1], - return_sequences = True + units=sizes[1],return_sequences=True )) for i in range(2, len(sizes) - 1): self.model.add(cell(sizes[i], return_sequences = False)) self.model.add(core.Dropout(dropout)) - self.model.add(core.Dense(output_dim = sizes[-1])) + self.model.add(core.Dense(sizes[-1])) self.model.add(core.Activation(activation)) self.model.compile(loss = loss, optimizer = optimizer)