Skip to content

Commit

Permalink
Merge branch 'main' into concat
Browse files Browse the repository at this point in the history
  • Loading branch information
am831 authored Oct 9, 2023
2 parents 62ffb53 + 4812198 commit a5d0185
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
9 changes: 9 additions & 0 deletions bibliography.bib
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ @misc{https://doi.org/10.48550/arxiv.1607.06450
primaryClass={stat.ML}
}

@misc{https://doi.org/10.48550/arxiv.1609.03499,
title={WaveNet: A Generative Model for Raw Audio},
author={Aaron van den Oord and Sander Dieleman and Heiga Zen and Karen Simonyan and Oriol Vinyals and Alex Graves and Nal Kalchbrenner and Andrew Senior and Koray Kavukcuoglu},
year={2016},
eprint={1609.03499},
archivePrefix={arXiv},
primaryClass={cs.SD}
}

@misc{https://doi.org/10.48550/arxiv.1706.03762,
doi = {10.48550/ARXIV.1706.03762},
url = {https://arxiv.org/abs/1706.03762},
Expand Down
15 changes: 14 additions & 1 deletion src/fairseq2/models/conformer/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from torch import Tensor
from torch.nn import GLU, BatchNorm1d, Conv1d, Module, SiLU
from torch.nn.functional import pad

from fairseq2.nn.normalization import LayerNorm, StandardLayerNorm
from fairseq2.nn.utils.mask import apply_padding_mask
Expand All @@ -22,6 +23,8 @@ class ConformerConvolution(Module):
pointwise_conv1: Conv1d
pointwise_conv1_activation: GLU
depthwise_conv: Conv1d
depthwise_kernel_size: int
causal_depthwise_conv: bool
batch_norm: Optional[BatchNorm1d]
layer_norm: Optional[LayerNorm]
depthwise_activation: Module
Expand All @@ -32,6 +35,7 @@ def __init__(
model_dim: int,
depthwise_kernel_size: int,
*,
causal_depthwise_conv: bool = False,
norm_type: str = "batch_norm",
depthwise_activation: Optional[Module] = None,
device: Optional[Device] = None,
Expand All @@ -42,6 +46,9 @@ def __init__(
The dimensionality of the model.
:param depthwise_kernel_size:
The kernel size of the depthwise convolution.
:param causal_depthwise_conv:
If True, uses a causal depthwise convolution similar to that described in
Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`.
:param norm_type:
The type of norm layer applied after the depthwise convolution.
:param depthwise_activation:
Expand Down Expand Up @@ -71,13 +78,15 @@ def __init__(
model_dim,
depthwise_kernel_size,
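# We preserve the sequence length regardless of the kernel size. In the
# non-causal case this is done with "same" padding; in the causal case no
# padding is applied here and the input is instead left-padded by
# `depthwise_kernel_size - 1` in `forward()`.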
padding="same",
padding="same" if not causal_depthwise_conv else 0,
# We want to perform depthwise convolution.
groups=model_dim,
bias=False,
device=device,
dtype=dtype,
)
self.depthwise_kernel_size = depthwise_kernel_size
self.causal_depthwise_conv = causal_depthwise_conv

if norm_type not in ("batch_norm", "layer_norm"):
raise ValueError(
Expand Down Expand Up @@ -136,6 +145,10 @@ def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
# (N, 2 * M, S) -> (N, M, S)
seqs = self.pointwise_conv1_activation(seqs)

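# In the case of a causal convolution, pad the sequence entirely on the left
# by `depthwise_kernel_size - 1` steps, so that no output step depends on
# future steps and the sequence length is preserved by the unpadded convolution.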
if self.causal_depthwise_conv:
seqs = pad(seqs, (self.depthwise_kernel_size - 1, 0))

# (N, M, S) -> (N, M, S)
seqs = self.depthwise_conv(seqs)

Expand Down
2 changes: 2 additions & 0 deletions src/fairseq2/models/w2vbert/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def _encoder_600m() -> Wav2Vec2EncoderConfig:
layer_drop_p=0.0,
norm_order=TransformerNormOrder.POST,
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
)

Expand Down Expand Up @@ -75,6 +76,7 @@ def _encoder_300m() -> Wav2Vec2EncoderConfig:
layer_drop_p=0.0,
norm_order=TransformerNormOrder.POST,
depthwise_conv_kernel_size=31,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
)

Expand Down
6 changes: 6 additions & 0 deletions src/fairseq2/models/wav2vec2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class Wav2Vec2EncoderConfig:
depthwise_conv_kernel_size: int
"""The kernel size of depthwise convolutions in Conformer blocks."""

causal_depthwise_conv: bool
"""If True, uses a causal depthwise convolution similar to that described in
Section 2.1 of :cite:t:`https://doi.org/10.48550/arxiv.1609.03499`."""

conv_norm_type: Literal["batch_norm", "layer_norm"]
"""The type of normalization to use in the Conformer convolution module."""

Expand Down Expand Up @@ -173,6 +177,7 @@ def _encoder_base() -> Wav2Vec2EncoderConfig:
layer_drop_p=0.05,
norm_order=TransformerNormOrder.POST,
depthwise_conv_kernel_size=0,
causal_depthwise_conv=False,
conv_norm_type="batch_norm",
)

Expand Down Expand Up @@ -315,6 +320,7 @@ def build_conformer_block(self) -> TransformerEncoderLayer:
conv = ConformerConvolution(
self.config.model_dim,
self.config.depthwise_conv_kernel_size,
causal_depthwise_conv=self.config.causal_depthwise_conv,
norm_type=self.config.conv_norm_type,
device=self.device,
dtype=self.dtype,
Expand Down

0 comments on commit a5d0185

Please sign in to comment.