Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model.py: Adding DecoderState class to enable simultaneously call decoder.inference #176

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 72 additions & 47 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,41 @@ def inference(self, x):

return outputs

class DecoderState():
""" Initializes attention rnn states, decoder rnn states, attention
weights, attention cumulative weights, attention context, stores memory
and stores processed memory
PARAMS
------
memory: Encoder outputs
mask: Mask for padded data if training, expects None for inference
"""
def __init__(self, memory, mask,
attention_rnn_dim, decoder_rnn_dim,
encoder_embedding_dim, attention_layer):
B = memory.size(0)
MAX_TIME = memory.size(1)

self.attention_hidden = Variable(memory.data.new(
B, attention_rnn_dim).zero_())
self.attention_cell = Variable(memory.data.new(
B, attention_rnn_dim).zero_())

self.decoder_hidden = Variable(memory.data.new(
B, decoder_rnn_dim).zero_())
self.decoder_cell = Variable(memory.data.new(
B, decoder_rnn_dim).zero_())

self.attention_weights = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(
B, encoder_embedding_dim).zero_())

self.memory = memory
self.processed_memory = attention_layer.memory_layer(memory)
self.mask = mask

class Decoder(nn.Module):
def __init__(self, hparams):
Expand Down Expand Up @@ -263,30 +298,15 @@ def initialize_decoder_states(self, memory, mask):
------
memory: Encoder outputs
mask: Mask for padded data if training, expects None for inference
"""
B = memory.size(0)
MAX_TIME = memory.size(1)

self.attention_hidden = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())
self.attention_cell = Variable(memory.data.new(
B, self.attention_rnn_dim).zero_())

self.decoder_hidden = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())
self.decoder_cell = Variable(memory.data.new(
B, self.decoder_rnn_dim).zero_())

self.attention_weights = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_weights_cum = Variable(memory.data.new(
B, MAX_TIME).zero_())
self.attention_context = Variable(memory.data.new(
B, self.encoder_embedding_dim).zero_())

self.memory = memory
self.processed_memory = self.attention_layer.memory_layer(memory)
self.mask = mask
RETURNS
-------
decoder_state
"""
decoder_state = DecoderState(memory, mask,
self.attention_rnn_dim, self.decoder_rnn_dim,
self.encoder_embedding_dim, self.attention_layer)
return decoder_state

def parse_decoder_inputs(self, decoder_inputs):
""" Prepares decoder inputs, i.e. mel outputs
Expand Down Expand Up @@ -337,46 +357,51 @@ def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):

return mel_outputs, gate_outputs, alignments

def decode(self, decoder_input):
def decode(self, decoder_input, decoder_state):
""" Decoder step using stored states, attention and memory
PARAMS
------
decoder_input: previous mel output

decoder_state: decoder states

RETURNS
-------
mel_output:
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training)
cell_input = torch.cat((decoder_input, decoder_state.attention_context), -1)
decoder_state.attention_hidden, decoder_state.attention_cell = self.attention_rnn(
cell_input, (decoder_state.attention_hidden, decoder_state.attention_cell))
decoder_state.attention_hidden = F.dropout(
decoder_state.attention_hidden, self.p_attention_dropout, self.training)
decoder_state.attention_cell = F.dropout(
decoder_state.attention_cell, self.p_attention_dropout, self.training)

attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.memory, self.processed_memory,
attention_weights_cat, self.mask)
(decoder_state.attention_weights.unsqueeze(1),
decoder_state.attention_weights_cum.unsqueeze(1)), dim=1)
decoder_state.attention_context, decoder_state.attention_weights = self.attention_layer(
decoder_state.attention_hidden, decoder_state.memory, decoder_state.processed_memory,
attention_weights_cat, decoder_state.mask)

self.attention_weights_cum += self.attention_weights
decoder_state.attention_weights_cum += decoder_state.attention_weights
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, self.p_decoder_dropout, self.training)
(decoder_state.attention_hidden, decoder_state.attention_context), -1)
decoder_state.decoder_hidden, decoder_state.decoder_cell = self.decoder_rnn(
decoder_input, (decoder_state.decoder_hidden, decoder_state.decoder_cell))
decoder_state.decoder_hidden = F.dropout(
decoder_state.decoder_hidden, self.p_decoder_dropout, self.training)
decoder_state.decoder_cell = F.dropout(
decoder_state.decoder_cell, self.p_decoder_dropout, self.training)

decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
(decoder_state.decoder_hidden, decoder_state.attention_context), dim=1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)

gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
return decoder_output, gate_prediction, decoder_state.attention_weights

def forward(self, memory, decoder_inputs, memory_lengths):
""" Decoder forward pass for training
Expand All @@ -398,14 +423,14 @@ def forward(self, memory, decoder_inputs, memory_lengths):
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
decoder_inputs = self.prenet(decoder_inputs)

self.initialize_decoder_states(
decoder_state = self.initialize_decoder_states(
memory, mask=~get_mask_from_lengths(memory_lengths))

mel_outputs, gate_outputs, alignments = [], [], []
while len(mel_outputs) < decoder_inputs.size(0) - 1:
decoder_input = decoder_inputs[len(mel_outputs)]
mel_output, gate_output, attention_weights = self.decode(
decoder_input)
decoder_input, decoder_state)
mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze()]
alignments += [attention_weights]
Expand All @@ -429,12 +454,12 @@ def inference(self, memory):
"""
decoder_input = self.get_go_frame(memory)

self.initialize_decoder_states(memory, mask=None)
decoder_state = self.initialize_decoder_states(memory, mask=None)

mel_outputs, gate_outputs, alignments = [], [], []
while True:
decoder_input = self.prenet(decoder_input)
mel_output, gate_output, alignment = self.decode(decoder_input)
mel_output, gate_output, alignment = self.decode(decoder_input, decoder_state)

mel_outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
Expand Down