diff --git a/models/codec/ns3_codec/facodec.py b/models/codec/ns3_codec/facodec.py index 87f661bd..715456dc 100644 --- a/models/codec/ns3_codec/facodec.py +++ b/models/codec/ns3_codec/facodec.py @@ -412,10 +412,18 @@ def quantize(self, x, n_quantizers=None): f0_input = x # (B, d, T) f0_quantizer = self.quantizer[0] out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers) - outs += out - qs.append(q) - quantized_buf.append(quantized.sum(0)) - commit_loss.append(commit) + if outs == 0: + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + else: + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) # phone phone_input = x @@ -423,6 +431,9 @@ def quantize(self, x, n_quantizers=None): out, q, commit, quantized = phone_quantizer( phone_input, n_quantizers=n_quantizers ) + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] outs += out qs.append(q) quantized_buf.append(quantized.sum(0)) @@ -431,10 +442,15 @@ def quantize(self, x, n_quantizers=None): # residual if self.vq_num_q_r > 0: residual_quantizer = self.quantizer[2] + if x.shape != quantized_buf[0].shape: + x = x[:, :, : quantized_buf[0].size(2)] residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach() out, q, commit, quantized = residual_quantizer( residual_input, n_quantizers=n_quantizers ) + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] outs += out qs.append(q) quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T] @@ -694,7 +710,6 @@ def forward( speaker_embedding, use_residual_code=False, ): - x = 0 x_p = 0 @@ -713,7 +728,6 @@ def forward( x = x + x_c if use_residual_code: - x_r = 0 for i in range(self.vq_num_q_r): x_r = x_r + self.residual_embs[i]( @@ -732,7 +746,6 @@ def forward( return x def vq2emb(self, vq, speaker_embedding, use_residual=True): - out = 0 x_t = 0 @@ -1034,10 +1047,18 @@ def quantize(self, x, prosody_feature, n_quantizers=None): f0_input = f0_input.transpose(1, 2) f0_quantizer = self.quantizer[0] out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers) - outs += out - qs.append(q) - quantized_buf.append(quantized.sum(0)) - commit_loss.append(commit) + if outs == 0: + outs += out + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) + else: + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] + qs.append(q) + quantized_buf.append(quantized.sum(0)) + commit_loss.append(commit) # phone phone_input = x @@ -1045,6 +1066,9 @@ def quantize(self, x, prosody_feature, n_quantizers=None): out, q, commit, quantized = phone_quantizer( phone_input, n_quantizers=n_quantizers ) + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] outs += out qs.append(q) quantized_buf.append(quantized.sum(0)) @@ -1053,10 +1077,15 @@ def quantize(self, x, prosody_feature, n_quantizers=None): # residual if self.vq_num_q_r > 0: residual_quantizer = self.quantizer[2] + if x.shape != quantized_buf[0].shape: + x = x[:, :, : quantized_buf[0].size(2)] residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach() out, q, commit, quantized = residual_quantizer( residual_input, n_quantizers=n_quantizers ) + out = out[:, :, : outs.size(2)] + q = q[:, :, : outs.size(2)] + quantized = quantized[:, :, :, : outs.size(2)] outs += out qs.append(q) quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]