Add variable_multi_view_collate_fn
hungdinhxuan committed Nov 1, 2024
1 parent b585d10 commit 775d3ef
Showing 8 changed files with 896 additions and 24 deletions.
23 changes: 23 additions & 0 deletions configs/callbacks/default_loss_w_only.yaml
@@ -0,0 +1,23 @@
defaults:
- model_checkpoint
#- early_stopping
- rich_progress_bar
- _self_

model_checkpoint:
dirpath: ${paths.output_dir}/checkpoints
filename: "epoch_{epoch:03d}"
monitor: "val/loss"
mode: "min"
save_last: True
auto_insert_metric_name: False
save_top_k: 5 # save k best models (determined by above metric)
save_weights_only: True

early_stopping:
monitor: "val/loss"
patience: 20
mode: "min"

model_summary:
max_depth: -1
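
For reference, a minimal sketch of the callback the model_checkpoint block above instantiates, assuming the template's standard Hydra `_target_` wiring (hypothetical direct construction, not part of this commit; on older Lightning versions the import path is pytorch_lightning.callbacks):

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    dirpath="<output_dir>/checkpoints",  # ${paths.output_dir}/checkpoints
    filename="epoch_{epoch:03d}",
    monitor="val/loss",
    mode="min",
    save_last=True,
    auto_insert_metric_name=False,
    save_top_k=5,             # keep the 5 best checkpoints by val/loss
    save_weights_only=True,   # the distinguishing bit of this config: no optimizer state
)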
50 changes: 50 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-2-var.yaml
@@ -0,0 +1,50 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=aasistssl_multiview_conf-2-var

defaults:
- override /data: asvspoof_multiview
- override /model: xlsr_aasist_var_multiview
- override /callbacks: default_loss
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_var_multiview"]

seed: 1234

trainer:
max_epochs: 100
gradient_clip_val: 0.0
accelerator: cuda

model:
optimizer:
lr: 0.000001
weight_decay: 0.0001
net: null
scheduler: null
compile: true

data:
batch_size: 14
num_workers: 8
pin_memory: true
args:
padding_type: repeat
random_start: False
is_variable_multi_view: True
top_k: 4
min_duration: 16000
max_duration: 64600


logger:
wandb:
tags: ${tags}
group: "asvspoof_multiview"
aim:
experiment: "asvspoof_multiview"
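
Usage note: assuming the template's standard entry point (as in the header comment of this file), the experiment is selected by file name, and any value above can be overridden from the command line, e.g.:

python train.py experiment=aasistssl_multiview_conf-2-var trainer.max_epochs=50 data.batch_size=8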
16 changes: 16 additions & 0 deletions configs/model/xlsr_aasist_var_multiview.yaml
@@ -0,0 +1,16 @@
_target_: src.models.aasistssl_var_multiview_module.AASISTSSLLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.0001
weight_decay: 0.0001

scheduler: null

net: null

ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH}
cross_entropy_weight: [0.1, 0.9] # cross-entropy class weights: 0.1 for spoof, 0.9 for bonafide
# compile model for faster training with pytorch 2.0
compile: true
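
For context, a minimal sketch of how a two-element class weight like this is typically consumed. The LitModule itself is not part of this diff, so the exact call site is an assumption:

import torch

# Assumed usage inside AASISTSSLLitModule (not shown in this commit);
# index 0 weights the spoof class and index 1 the bonafide class.
weight = torch.tensor([0.1, 0.9])
criterion = torch.nn.CrossEntropyLoss(weight=weight)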
8 changes: 4 additions & 4 deletions notebooks/eval.ipynb
@@ -9,16 +9,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"eer: 9.326796296235376\tthreshold: -2.986328125\n",
"eer: 8.955265192535776\tthreshold: -2.849609375\n",
"\n",
"0.09326796296235375\n"
"0.08955265192535777\n"
]
}
],
@@ -51,7 +51,7 @@
" print(out_data)\n",
" return eer_cm\n",
"\n",
"print(eval_to_score_file(\"/data/hungdx/Lightning-hydra/logs/eval/itw_xlsr_aasist_multiview_conf-2_epoch15_3s.txt\", \"/dataa/Datasets/in_the_wild.txt\"))"
"print(eval_to_score_file(\"/data/hungdx/Lightning-hydra/logs/eval/in_the_wild_xlsr_aasist_multiview_conf-2_4s_var.txt\", \"/dataa/Datasets/in_the_wild.txt\"))"
]
}
],
427 changes: 411 additions & 16 deletions notebooks/test.ipynb

Large diffs are not rendered by default.

31 changes: 27 additions & 4 deletions src/data/asvspoof_multiview_datamodule.py
@@ -11,7 +11,7 @@
import copy
from src.data.components.RawBoost import process_Rawboost_feature
from src.data.components.dataio import load_audio, pad
from src.data.components.collate_fn import multi_view_collate_fn
from src.data.components.collate_fn import multi_view_collate_fn, variable_multi_view_collate_fn
'''
Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans.
RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing.
@@ -174,6 +174,29 @@ def __init__(
self.batch_size_per_device = batch_size
self.data_dir = data_dir
self.args = args
self.is_variable_multi_view = args.get('is_variable_multi_view', False) if args is not None else False
if self.is_variable_multi_view:
print('Using variable multi-view collate function')
self.top_k = self.args.get('top_k', 4)
self.min_duration = self.args.get('min_duration', 16000)
self.max_duration = self.args.get('max_duration', 64000)
self.collate_fn = lambda x: variable_multi_view_collate_fn(
x,
self.top_k,
self.min_duration,
self.max_duration,
self.args.sample_rate,
self.args.padding_type,
self.args.random_start
)
else:
self.collate_fn = lambda x: multi_view_collate_fn(
x,
self.args.views,
self.args.sample_rate,
self.args.padding_type,
self.args.random_start
)
self.protocols_path = self.args.get('protocols_path', '/data/hungdx/Datasets/protocols/database/') if self.args is not None else '/data/hungdx/Datasets/protocols/database/'

@property
@@ -243,10 +266,10 @@ def train_dataloader(self) -> DataLoader[Any]:
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
collate_fn=lambda x: multi_view_collate_fn(x, self.args.views, self.args.sample_rate, self.args.padding_type, self.args.random_start),
collate_fn=self.collate_fn,
drop_last=True,
persistent_workers=True
)
)

def val_dataloader(self) -> DataLoader[Any]:
"""Create and return the validation dataloader.
@@ -259,7 +282,7 @@ def val_dataloader(self) -> DataLoader[Any]:
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
collate_fn=lambda x: multi_view_collate_fn(x, self.args.views, self.args.sample_rate, self.args.padding_type, self.args.random_start),
collate_fn=self.collate_fn,
)

def test_dataloader(self) -> DataLoader[Any]:
66 changes: 66 additions & 0 deletions src/data/components/collate_fn.py
@@ -1,7 +1,26 @@
import torch
from src.data.components.dataio import pad
import numpy as np

def multi_view_collate_fn(batch, views=[1, 2, 3, 4], sample_rate=16000, padding_type='zero', random_start=True):
'''
Collate function that pads every sample in a batch to several fixed-length views.
:param batch: list of tuples (x, label)
:param views: view durations in seconds; each sample is padded or cropped to view * sample_rate samples
:param sample_rate: sample rate of the audio
:param padding_type: padding type to use
:param random_start: whether to crop from a random start position
:return: dictionary mapping each view to a tuple (padded sequences, labels)
Example (schematic, using sample_rate=1 so views equal sample counts):
batch = [([1, 2, 3], 0), ([1, 2, 3, 4], 1)]
multi_view_collate_fn(batch, views=[3, 4], sample_rate=1, random_start=False)
Output:
{
3: (tensor([[1, 2, 3], [1, 2, 3]]), tensor([0, 1])),
4: (tensor([[1, 2, 3, 0], [1, 2, 3, 4]]), tensor([0, 1]))
}
'''
view_batches = {view: [] for view in views}

# Process each sample in the batch
Expand All @@ -22,4 +41,51 @@ def multi_view_collate_fn(batch, views=[1, 2, 3, 4], sample_rate=16000, padding_
labels = torch.tensor(labels, dtype=torch.long)
view_batches[view] = (padded_sequences, labels)

return view_batches

def variable_multi_view_collate_fn(batch, top_k=4, min_duration=16000, max_duration=64000, sample_rate=16000, padding_type='zero', random_start=True):
'''
Collate function that pads every sample in a batch to multiple views whose durations are drawn at random.
:param batch: list of tuples (x, label)
:param top_k: number of view durations to sample
:param min_duration: minimum view length, in samples
:param max_duration: maximum view length, in samples
:param sample_rate: sample rate of the audio (kept for interface parity; durations are already in samples)
:param padding_type: padding type to use
:param random_start: whether to crop from a random start position
:return: dictionary mapping each sampled view length (in samples) to a tuple (padded sequences, labels)
Example (schematic; keys are the randomly drawn view lengths, so one possible draw):
batch = [(x1, 0), (x2, 1)]
variable_multi_view_collate_fn(batch, top_k=2, min_duration=16000, max_duration=32000)
Output:
{
18133: (tensor of shape [2, 18133], tensor([0, 1])),
27421: (tensor of shape [2, 27421], tensor([0, 1]))
}
'''
# View lengths are drawn uniformly from [min_duration, max_duration), in samples (not seconds)
durations = np.random.uniform(min_duration, max_duration, top_k).astype(int)
# Ensure unique durations to avoid key collisions
views = np.unique(durations)
view_batches = {view: [] for view in views}
# Process each sample in the batch
for x, label in batch:
# Pad each sample for each view
for view in views:
view_length = view
x_view = pad(x, padding_type=padding_type, max_len=view_length, random_start=random_start)
# Check if x_view is Tensor or numpy array and convert to Tensor if necessary
if not torch.is_tensor(x_view):
x_view = torch.from_numpy(x_view)
view_batches[view].append((x_view, label))

# Convert lists to tensors
for view in views:
sequences, labels = zip(*view_batches[view])
padded_sequences = torch.stack(sequences)
labels = torch.tensor(labels, dtype=torch.long)
view_batches[view] = (padded_sequences, labels)

return view_batches
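
A minimal usage sketch for the new collate function (not part of the commit): wiring it into a PyTorch DataLoader over a toy list of (waveform, label) pairs. The dummy data and parameter values are illustrative only; waveforms are given as NumPy arrays, matching what pad appears to accept in the existing multi-view path.

import numpy as np
import torch
from torch.utils.data import DataLoader
from src.data.components.collate_fn import variable_multi_view_collate_fn

# Toy dataset: two mono waveforms of different lengths with binary labels
samples = [
    (np.random.randn(32000).astype(np.float32), 0),
    (np.random.randn(48000).astype(np.float32), 1),
]

loader = DataLoader(
    samples,
    batch_size=2,
    collate_fn=lambda batch: variable_multi_view_collate_fn(
        batch,
        top_k=4,
        min_duration=16000,
        max_duration=64000,
        padding_type='repeat',
        random_start=False,
    ),
)

for view_batches in loader:
    # Keys are the sampled view lengths in samples; values are (padded batch, labels)
    for view_len, (x, y) in view_batches.items():
        print(view_len, tuple(x.shape), tuple(y.shape))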
