Nit updates to Conformer (#100)
cbalioglu authored Oct 11, 2023
1 parent 8a210c5 commit c07448a
Showing 1 changed file with 10 additions and 16 deletions.
26 changes: 10 additions & 16 deletions src/fairseq2/models/conformer/convolution.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- from typing import Optional
+ from typing import Literal, Optional

from torch import Tensor
from torch.nn import GLU, BatchNorm1d, Conv1d, Module, SiLU
@@ -23,7 +23,6 @@ class ConformerConvolution(Module):
pointwise_conv1: Conv1d
pointwise_conv1_activation: GLU
depthwise_conv: Conv1d
- depthwise_kernel_size: int
causal_depthwise_conv: bool
batch_norm: Optional[BatchNorm1d]
layer_norm: Optional[LayerNorm]
@@ -36,7 +35,7 @@ def __init__(
depthwise_kernel_size: int,
*,
causal_depthwise_conv: bool = False,
- norm_type: str = "batch_norm",
+ norm_type: Literal["batch_norm", "layer_norm"] = "batch_norm",
depthwise_activation: Optional[Module] = None,
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
@@ -47,10 +46,10 @@ def __init__(
:param depthwise_kernel_size:
The kernel size of the depthwise convolution.
:param causal_depthwise_conv:
- If True, uses a causal depthwise convolution similar to that described in
- Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`.
+ If ``True``, uses a causal depthwise convolution similar to that
+ described in Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`.
:param norm_type:
- The type of norm layer applied after the depthwise convolution.
+ The type of normalization to apply after the depthwise convolution.
:param depthwise_activation:
The activation to apply to outputs of the depthwise convolution. If
``None``, :func:`~torch.nn.SiLU` (a.k.a. swish) will be used.
@@ -77,15 +76,14 @@ def __init__(
model_dim,
model_dim,
depthwise_kernel_size,
# We preserve the sequence length regardless of the kernel size.
padding="same" if not causal_depthwise_conv else 0,
# We want to perform depthwise convolution.
groups=model_dim,
bias=False,
device=device,
dtype=dtype,
)
- self.depthwise_kernel_size = depthwise_kernel_size

self.causal_depthwise_conv = causal_depthwise_conv

if norm_type not in ("batch_norm", "layer_norm"):
@@ -111,12 +109,7 @@ def __init__(
self.depthwise_activation = depthwise_activation

self.pointwise_conv2 = Conv1d(
- model_dim,
- model_dim,
- kernel_size=1,
- bias=False,
- device=device,
- dtype=dtype,
+ model_dim, model_dim, kernel_size=1, bias=False, device=device, dtype=dtype
)

def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
@@ -145,9 +138,9 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
# (N, 2 * M, S) -> (N, M, S)
seqs = self.pointwise_conv1_activation(seqs)

- # Pad the sequence entirely on the left in the case of a causal convolution.
+ # Pad the sequence entirely on the left in case of a causal convolution.
if self.causal_depthwise_conv:
- seqs = pad(seqs, (self.depthwise_kernel_size - 1, 0))
+ seqs = pad(seqs, (self.depthwise_conv.kernel_size[0] - 1, 0))

# (N, M, S) -> (N, M, S)
seqs = self.depthwise_conv(seqs)
@@ -156,6 +149,7 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
seqs = self.batch_norm(seqs)
else:
assert self.layer_norm is not None

# (N, M, S) -> (N, S, M)
seqs = seqs.transpose(1, 2)

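For reference, a minimal usage sketch of the updated module follows. The import path, the (N, S, M) input layout, and all shapes and values are assumptions for illustration, not part of this commit.

# A hypothetical, minimal sketch; import path and (N, S, M) input layout are
# assumed, not taken from this commit.
import torch

from fairseq2.models.conformer.convolution import ConformerConvolution

conv = ConformerConvolution(
    model_dim=256,
    depthwise_kernel_size=31,
    causal_depthwise_conv=True,
    # norm_type is now typed as Literal["batch_norm", "layer_norm"].
    norm_type="layer_norm",
)

# Assumed (batch, seq_len, model_dim) layout for the input sequences.
seqs = torch.randn(8, 100, 256)

out = conv(seqs, padding_mask=None)

# The causal depthwise convolution pads kernel_size - 1 steps on the left,
# so the sequence length is preserved.
assert out.shape == seqs.shape

Reading the kernel size back from depthwise_conv.kernel_size[0] instead of storing a separate depthwise_kernel_size attribute removes a redundant copy of the same value.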
