diff --git a/src/fairseq2/models/conformer/convolution.py b/src/fairseq2/models/conformer/convolution.py index 9b59ce6fd..5dbcde11e 100644 --- a/src/fairseq2/models/conformer/convolution.py +++ b/src/fairseq2/models/conformer/convolution.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Literal, Optional from torch import Tensor from torch.nn import GLU, BatchNorm1d, Conv1d, Module, SiLU @@ -23,7 +23,6 @@ class ConformerConvolution(Module): pointwise_conv1: Conv1d pointwise_conv1_activation: GLU depthwise_conv: Conv1d - depthwise_kernel_size: int causal_depthwise_conv: bool batch_norm: Optional[BatchNorm1d] layer_norm: Optional[LayerNorm] @@ -36,7 +35,7 @@ def __init__( depthwise_kernel_size: int, *, causal_depthwise_conv: bool = False, - norm_type: str = "batch_norm", + norm_type: Literal["batch_norm", "layer_norm"] = "batch_norm", depthwise_activation: Optional[Module] = None, device: Optional[Device] = None, dtype: Optional[DataType] = None, @@ -47,10 +46,10 @@ def __init__( :param depthwise_kernel_size: The kernel size of the depthwise convolution. :param causal_depthwise_conv: - If True, uses a causal depthwise convolution similar to that described in - Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`. + If ``True``, uses a causal depthwise convolution similar to that + described in Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`. :param norm_type: - The type of norm layer applied after the depthwise convolution. + The type of normalization to apply after the depthwise convolution. :param depthwise_activation: The activation to apply to outputs of the depthwise convolution. If ``None``, :func:`~torch.nn.SiLU` (a.k.a. swish) will be used. @@ -77,7 +76,6 @@ def __init__( model_dim, model_dim, depthwise_kernel_size, - # We preserve the sequence length regardless of the kernel size. padding="same" if not causal_depthwise_conv else 0, # We want to perform depthwise convolution. groups=model_dim, @@ -85,7 +83,7 @@ def __init__( device=device, dtype=dtype, ) - self.depthwise_kernel_size = depthwise_kernel_size + self.causal_depthwise_conv = causal_depthwise_conv if norm_type not in ("batch_norm", "layer_norm"): @@ -111,12 +109,7 @@ def __init__( self.depthwise_activation = depthwise_activation self.pointwise_conv2 = Conv1d( - model_dim, - model_dim, - kernel_size=1, - bias=False, - device=device, - dtype=dtype, + model_dim, model_dim, kernel_size=1, bias=False, device=device, dtype=dtype ) def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor: @@ -145,9 +138,9 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor: # (N, 2 * M, S) -> (N, M, S) seqs = self.pointwise_conv1_activation(seqs) - # Pad the sequence entirely on the left in the case of a causal convolution. + # Pad the sequence entirely on the left in case of a causal convolution. if self.causal_depthwise_conv: - seqs = pad(seqs, (self.depthwise_kernel_size - 1, 0)) + seqs = pad(seqs, (self.depthwise_conv.kernel_size[0] - 1, 0)) # (N, M, S) -> (N, M, S) seqs = self.depthwise_conv(seqs) @@ -156,6 +149,7 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor: seqs = self.batch_norm(seqs) else: assert self.layer_norm is not None + # (N, M, S) -> (N, S, M) seqs = seqs.transpose(1, 2)