Update code to reproduce conformer and conformertcm
hungdinhxuan committed Oct 25, 2024
1 parent 130901d commit c37c147
Showing 18 changed files with 1,023 additions and 3 deletions.
24 changes: 24 additions & 0 deletions configs/callbacks/xlsr_conformertcm_reproduce.yaml
@@ -0,0 +1,24 @@
defaults:
- model_checkpoint
- early_stopping
- rich_progress_bar
- _self_

model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
monitor: "val/loss"
mode: "min"
save_last: True
auto_insert_metric_name: False
save_top_k: 5 # save k best models (determined by above metric)
save_weights_only: True
verbose: True

early_stopping:
monitor: "val/loss"
patience: 7
mode: "min"

model_summary:
max_depth: -1
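
For orientation, this callbacks config resolves to the standard Lightning checkpoint and early-stopping callbacks. A rough Python equivalent (a sketch only; output_dir stands in for the resolved ${paths.output_dir}, and the exact import path depends on the installed Lightning version):

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

output_dir = "logs/train/runs/example"  # placeholder for ${paths.output_dir}

checkpoint_cb = ModelCheckpoint(
    dirpath=f"{output_dir}/checkpoints",
    filename="epoch_{epoch:03d}",
    monitor="val/loss",
    mode="min",
    save_last=True,
    auto_insert_metric_name=False,
    save_top_k=5,            # keep the 5 best checkpoints ranked by val/loss
    save_weights_only=True,
    verbose=True,
)

early_stop_cb = EarlyStopping(monitor="val/loss", patience=7, mode="min")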
@@ -17,6 +17,7 @@ tags: ["asvspoof", "xlsr_conformertcm_baseline"]
seed: 1234

trainer:
min_epochs: 50
max_epochs: 70
gradient_clip_val: 0.0
accelerator: cuda
@@ -0,0 +1,45 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=xlsr_conformertcm_reproduce

defaults:
- override /data: asvspoof
- override /model: xlsr_conformertcm_reproduce
- override /callbacks: xlsr_conformertcm_reproduce
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof", "xlsr_conformertcm_reproduce"]

seed: 1234

trainer:
  max_epochs: -1 # train indefinitely until early stopping triggers
gradient_clip_val: 0.0
accelerator: cuda

model:
optimizer:
lr: 0.000001
weight_decay: 0.0001
net: null
scheduler: null

data:
batch_size: 20
num_workers: 8
pin_memory: true
args:
padding_type: repeat
random_start: False
    cut: 66800 # samples per utterance (~4.2 s at 16 kHz); AASIST-SSL uses 64600

logger:
wandb:
tags: ${tags}
group: "asvspoof"
aim:
experiment: "asvspoof"
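
The data settings fix every utterance to cut: 66800 samples (about 4.2 s at 16 kHz) using repeat padding. A minimal sketch of that kind of repeat-pad-or-truncate step (the pad_or_trim helper below is hypothetical, not the repo's actual function):

import numpy as np

def pad_or_trim(waveform: np.ndarray, cut: int = 66800) -> np.ndarray:
    """Repeat-pad short clips and truncate long ones to exactly `cut` samples."""
    length = waveform.shape[0]
    if length >= cut:
        return waveform[:cut]
    num_repeats = int(np.ceil(cut / length))
    return np.tile(waveform, num_repeats)[:cut]

clip = np.random.randn(16000)      # a 1 s clip at 16 kHz
fixed = pad_or_trim(clip)          # tiled up to 66800 samples (~4.175 s)
assert fixed.shape == (66800,)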
21 changes: 21 additions & 0 deletions configs/model/xlsr_conformertcm_reproduce.yaml
@@ -0,0 +1,21 @@
_target_: src.models.xlsr_conformertcm_reproduce_module.XLSRConformerTCMLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.000001
weight_decay: 0.0001

scheduler: null

args:
conformer:
emb_size: 144
heads: 4
kernel_size: 31
n_encoders: 4

ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH}
cross_entropy_weight: [0.1, 0.9] # class weights for the cross-entropy loss: 0.1 for spoof, 0.9 for bonafide
# compile model for faster training with pytorch 2.0
compile: false
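
Because the optimizer block sets _partial_: true, Hydra instantiates it as a functools.partial of torch.optim.Adam with the configured learning rate and weight decay; the LightningModule then binds the model parameters later, typically in configure_optimizers. A sketch of that flow (assuming a Hydra version that supports _partial_):

import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/xlsr_conformertcm_reproduce.yaml")

# returns functools.partial(torch.optim.Adam, lr=1e-06, weight_decay=0.0001)
optimizer_factory = hydra.utils.instantiate(cfg.optimizer)

# inside the LightningModule, configure_optimizers would then do something like:
# optimizer = optimizer_factory(params=self.parameters())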
3 changes: 3 additions & 0 deletions configs/train.yaml
@@ -47,3 +47,6 @@ ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null

# config.yaml
model_averaging: null
2 changes: 2 additions & 0 deletions src/data/asvspoof_datamodule.py
@@ -30,6 +30,7 @@ def __init__(self,args,list_IDs, labels, base_dir,algo):

# Sampling rate and cut-off
print('args:',args)
print("Algo:",algo)
self.fs = args.get('sampling_rate', 16000) if args is not None else 16000
self.cut = args.get('cut', 64600) if args is not None else 64600
self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero'
@@ -207,6 +208,7 @@ def train_dataloader(self) -> DataLoader[Any]:
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
drop_last=True
)

def val_dataloader(self) -> DataLoader[Any]:
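
The behavioural change to the training loader is drop_last=True: the final incomplete batch is discarded so every optimisation step sees exactly batch_size utterances. A small, self-contained illustration with toy tensors (not the ASVspoof data):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(50, 4))                  # 50 samples
loader = DataLoader(dataset, batch_size=20, drop_last=True)

print([batch[0].shape[0] for batch in loader])               # [20, 20] -- the trailing 10 samples are dropped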
Empty file.
245 changes: 245 additions & 0 deletions src/models/components/conformer_reproduce/conformer.py
@@ -0,0 +1,245 @@
import math
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# helper functions

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def calc_same_padding(kernel_size):
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
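# e.g. calc_same_padding(31) == (15, 15); for an even kernel size the right pad
# is one sample smaller, so the depthwise conv keeps the sequence length unchanged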

# helper classes

class Swish(nn.Module):
def forward(self, x):
return x * x.sigmoid()

class GLU(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, x):
out, gate = x.chunk(2, dim=self.dim)
return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

def forward(self, x):
x = F.pad(x, self.padding)
return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
self.scale = scale

def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)

def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)

class Attention(nn.Module):
# Head Token attention: https://arxiv.org/pdf/2210.05958.pdf
def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
super().__init__()
self.num_heads = heads
inner_dim = dim_head * heads
self.scale = dim_head ** -0.5

self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

self.attn_drop = nn.Dropout(dropout)
self.proj = nn.Linear(inner_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

self.act = nn.GELU()
self.ht_proj = nn.Linear(dim_head, dim,bias=True)
self.ht_norm = nn.LayerNorm(dim_head)
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))

def forward(self, x, mask=None):
B, N, C = x.shape

# head token
head_pos = self.pos_embed.expand(x.shape[0], -1, -1)
x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2) # average over the tokens of each head: shape is now [B, h, d//h]
x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
x_ = self.act(self.ht_norm(x_)).flatten(2)
x_ = x_ + head_pos
x = torch.cat([x, x_], dim=1)
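        # the sequence now holds the original N tokens plus one learnable
        # head token per attention head (N + num_heads tokens in total)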

# normal mhsa
qkv = self.qkv(x).reshape(B, N+self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N+self.num_heads, C)
x = self.proj(x)

# merge head tokens into cls token
cls, patch, ht = torch.split(x, [1, N-1, self.num_heads], dim=1)
cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
x = torch.cat([cls, patch], dim=1)

x = self.proj_drop(x)

return x, attn


class FeedForward(nn.Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)

class ConformerConvModule(nn.Module):
def __init__(
self,
dim,
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.
):
super().__init__()

inner_dim = dim * expansion_factor
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b n c -> b c n'),
nn.Conv1d(dim, inner_dim * 2, 1),
GLU(dim=1),
DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
Swish(),
nn.Conv1d(inner_dim, dim, 1),
Rearrange('b c n -> b n c'),
nn.Dropout(dropout)
)

def forward(self, x):
return self.net(x)

# Conformer Block

class ConformerBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

self.attn = PreNorm(dim, self.attn)
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
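        # Macaron-style: both feed-forward modules are scaled by 0.5 so their
        # residual contributions together amount to one full FFN, as in the Conformer paper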

self.post_norm = nn.LayerNorm(dim)

def forward(self, x, mask = None):
x = self.ff1(x) + x
attn_x, attn_weight = self.attn(x, mask = mask)
x = attn_x + x
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x, attn_weight

# Conformer

class Conformer(nn.Module):
def __init__(
self,
dim,
*,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
self.dim = dim
self.layers = nn.ModuleList([])

for _ in range(depth):
self.layers.append(ConformerBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
conv_expansion_factor = conv_expansion_factor,
conv_kernel_size = conv_kernel_size,
conv_causal = conv_causal

))

    def forward(self, x):

        for block in self.layers:
            # each ConformerBlock returns (features, attention weights);
            # keep only the features for the next block
            x, _ = block(x)

        return x
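
A quick smoke test of this module using the hyper-parameters from configs/model/xlsr_conformertcm_reproduce.yaml (emb_size 144, 4 heads, kernel size 31, 4 encoders). The import path is assumed from this file's location in the commit, and dim_head is set to dim // heads because the head-token reshapes in Attention only line up when dim_head * heads == dim:

import torch
from src.models.components.conformer_reproduce.conformer import Conformer

emb_size, heads = 144, 4
model = Conformer(
    dim=emb_size,
    depth=4,                      # n_encoders in the config
    heads=heads,
    dim_head=emb_size // heads,   # required: the attention reshapes assume dim_head == dim // heads
    conv_kernel_size=31,
)

x = torch.randn(2, 100, emb_size)   # (batch, frames, embedding)
out = model(x)
print(out.shape)                    # torch.Size([2, 100, 144])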