Commit

Update reproduce conformertcm code

Hung Dinh Xuan committed Oct 25, 2024
1 parent c37c147 commit 7b96ffa
Showing 3 changed files with 330 additions and 14 deletions.
5 changes: 5 additions & 0 deletions configs/data/asvspoof_conformertcm_reproduce.yaml
@@ -0,0 +1,5 @@
_target_: src.data.asvspoof_conformertcm_reproduce_datamodule.ASVSpoofDataModule
data_dir: ${oc.env:ASVSPOOF_PATH}
batch_size: 20 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
num_workers: 10
pin_memory: True
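
For reference, a minimal sketch (not part of the commit) of how Hydra would resolve this config into a datamodule instance, assuming the `ASVSPOOF_PATH` environment variable is set:

```python
# Illustrative only: instantiate the datamodule from this YAML via Hydra.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/data/asvspoof_conformertcm_reproduce.yaml")
dm = instantiate(cfg)  # -> ASVSpoofDataModule(data_dir=$ASVSPOOF_PATH, batch_size=20, ...)
```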
@@ -4,15 +4,15 @@
# python train.py experiment=example

defaults:
-  - override /data: asvspoof
+  - override /data: asvspoof_conformertcm_reproduce
- override /model: xlsr_conformertcm_reproduce
- override /callbacks: xlsr_conformertcm_reproduce
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof", "xlsr_conformertcm_reproduce"]
tags: ["asvspoof_conformertcm_reproduce", "xlsr_conformertcm_reproduce"]

seed: 1234

@@ -27,19 +27,10 @@ model:
weight_decay: 0.0001
net: null
scheduler: null

-data:
-  batch_size: 20
-  num_workers: 8
-  pin_memory: true
-  args:
-    padding_type: repeat
-    random_start: False
-    cut: 66800 # 66800, aasist-ssl is 64600


logger:
wandb:
tags: ${tags}
group: "asvspoof"
group: "asvspoof_conformertcm_reproduce"
aim:
experiment: "asvspoof"
experiment: "asvspoof_conformertcm_reproduce"
320 changes: 320 additions & 0 deletions src/data/asvspoof_conformertcm_reproduce_datamodule.py
@@ -0,0 +1,320 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Dict, Optional
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from torch import Tensor
import librosa
import numpy as np
import os
from src.data.components.RawBoost import process_Rawboost_feature

'''
Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans.
RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing.
In Proc. ICASSP 2022, pp. 6382-6386.
'''

def pad(x, max_len):
    """Repeat-pad (or truncate) a 1-D waveform to exactly `max_len` samples."""
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
    return padded_x
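# Example (illustrative): pad(np.array([1., 2., 3.]), 8)
# -> array([1., 2., 3., 1., 2., 3., 1., 2.])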

class Dataset_ASVspoof2019_train(Dataset):
    def __init__(self, args, list_IDs, labels, base_dir, algo):
        """self.list_IDs : list of strings (each string: utt key),
        self.labels : dictionary (key: utt key, value: label integer)"""
        self.list_IDs = list_IDs
        self.labels = labels
        self.base_dir = base_dir
        self.algo = algo
        self.args = args
        self.cut = 66800  # samples per utterance (aasist-ssl uses 64600)

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_id = self.list_IDs[index]
        X, fs = librosa.load(self.base_dir + utt_id + '.flac', sr=16000)
        # RawBoost augmentation on the raw waveform
        Y = process_Rawboost_feature(X, fs, self.args, self.algo)
        X_pad = pad(Y, self.cut)
        x_inp = Tensor(X_pad)
        target = self.labels[utt_id]
        return x_inp, target
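# Illustrative usage (paths are hypothetical):
#   ds = Dataset_ASVspoof2019_train(Args(), list_IDs=['LA_T_1138215'],
#                                   labels={'LA_T_1138215': 1},
#                                   base_dir='.../ASVspoof2019_LA_train/', algo=3)
#   x, y = ds[0]  # x: Tensor of shape (66800,), y: 1 for bonafide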

class Dataset_ASVspoof2021_eval(Dataset):
    def __init__(self, list_IDs, base_dir):
        self.list_IDs = list_IDs
        self.base_dir = base_dir
        self.cut = 66800

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_id = self.list_IDs[index]
        X, fs = librosa.load(self.base_dir + utt_id + '.flac', sr=16000)
        X_pad = pad(X, self.cut)
        x_inp = Tensor(X_pad)
        return x_inp, utt_id

class Args:
    """Default training and RawBoost hyperparameters, standing in for the
    command-line argument namespace of the original SSL anti-spoofing recipe."""

    def __init__(self):
self.database_path = '/your/path/to/data/ASVspoof_database/DF/'
self.protocols_path = '/data/hungdx/Datasets/protocols/database/'
self.batch_size = 20
self.num_epochs = 100
self.lr = 0.000001
self.weight_decay = 0.0001
self.loss = 'weighted_CCE'
self.seed = 1234
self.model_path = None
self.comment = None
self.track = 'DF'
self.eval_output = None
self.eval = False
self.is_eval = False
self.eval_part = 0
self.cudnn_deterministic_toggle = True
self.cudnn_benchmark_toggle = False
self.algo = 3
self.nBands = 5
self.minF = 20
self.maxF = 8000
self.minBW = 100
self.maxBW = 1000
self.minCoeff = 10
self.maxCoeff = 100
self.minG = 0
self.maxG = 0
self.minBiasLinNonLin = 5
self.maxBiasLinNonLin = 20
self.N_f = 5
self.P = 10
self.g_sd = 2
self.SNRmin = 10
self.SNRmax = 40

class ASVSpoofDataModule(LightningDataModule):
"""`LightningDataModule` for the ASVSpoof dataset.
    The ASVspoof 2019 database for logical access is based upon a standard multi-speaker speech synthesis database called VCTK.
Genuine speech is collected from 107 speakers (46 male, 61 female) and with no significant channel or background noise effects.
Spoofed speech is generated from the genuine data using a number of different spoofing algorithms.
The full dataset is partitioned into three subsets, the first for training, the second for development and the third for evaluation.
There is no speaker overlap across the three subsets regarding target speakers used in voice conversion or TTS adaptation.
A `LightningDataModule` implements 7 key methods:
```python
def prepare_data(self):
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
# Download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# Things to do on every process in DDP.
# Load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def predict_dataloader(self):
# return predict dataloader
def teardown(self, stage):
# Called on every process in DDP.
# Clean up after fit or test.
```
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://lightning.ai/docs/pytorch/latest/data/datamodule.html
"""

def __init__(
self,
data_dir: str = "data/",
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
) -> None:
"""Initialize a `ASVSpoofDataModule`.
:param data_dir: The data directory. Defaults to `"data/"`.
:param batch_size: The batch size. Defaults to `64`.
:param num_workers: The number of workers. Defaults to `0`.
:param pin_memory: Whether to pin memory. Defaults to `False`.
"""
super().__init__()

# this line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)

self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None

self.batch_size_per_device = batch_size
self.data_dir = data_dir

@property
def num_classes(self) -> int:
"""Get the number of classes.
:return: The number of ASVSpoof classes (2).
"""
return 2

    def prepare_data(self) -> None:
        """Nothing to download: the data is expected to already exist at `data_dir`."""
        pass

def setup(self, stage: Optional[str] = None) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
`trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
`self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
`self.setup()` once the data is prepared and available for use.
:param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
"""
# Divide batch size by the number of devices.
if self.trainer is not None:
if self.hparams.batch_size % self.trainer.world_size != 0:
raise RuntimeError(
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
)
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
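            # Worked example: a global batch size of 20 on 4 devices gives
            # 20 // 4 = 5 samples per device per step.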

        # Build protocol lists and datasets only if not loaded already.
        if not self.data_train and not self.data_val and not self.data_test:
            args = Args()
            args.database_path = self.data_dir
            track = args.track
            prefix_2021 = 'ASVspoof2021.{}'.format(track)

            d_label_trn, file_train = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),
                is_train=True, is_eval=False)
            d_label_dev, file_dev = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),
                is_train=False, is_eval=False)
            file_eval = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track, prefix_2021)),
                is_train=False, is_eval=True)

            self.data_train = Dataset_ASVspoof2019_train(
                args, list_IDs=file_train, labels=d_label_trn,
                base_dir=os.path.join(args.database_path, 'ASVspoof2019_LA_train/'), algo=args.algo)
            self.data_val = Dataset_ASVspoof2019_train(
                args, list_IDs=file_dev, labels=d_label_dev,
                base_dir=os.path.join(args.database_path, 'ASVspoof2019_LA_dev/'), algo=args.algo)
            self.data_test = Dataset_ASVspoof2021_eval(
                list_IDs=file_eval,
                base_dir=os.path.join(args.database_path, 'ASVspoof2021_{}_eval/'.format(args.track)))
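            # Note: validation reuses Dataset_ASVspoof2019_train, so RawBoost
            # augmentation is applied to the dev partition as well.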


def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
:return: The train dataloader.
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
drop_last=True,
)

def val_dataloader(self) -> DataLoader[Any]:
"""Create and return the validation dataloader.
:return: The validation dataloader.
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Any]:
"""Create and return the test dataloader.
:return: The test dataloader.
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)

def teardown(self, stage: Optional[str] = None) -> None:
"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
`trainer.test()`, and `trainer.predict()`.
:param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
Defaults to ``None``.
"""
pass

def state_dict(self) -> Dict[Any, Any]:
"""Called when saving a checkpoint. Implement to generate and save the datamodule state.
:return: A dictionary containing the datamodule state that you want to save.
"""
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint. Implement to reload datamodule state given datamodule
`state_dict()`.
:param state_dict: The datamodule state returned by `self.state_dict()`.
"""
pass

    def genSpoof_list(self, dir_meta, is_train=False, is_eval=False):
        """Parse an ASVspoof CM protocol file into a label dict and a file list.
        Adapted from https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17
        (Tak et al., "Automatic speaker verification spoofing and deepfake detection
        using wav2vec 2.0 and data augmentation", https://arxiv.org/abs/2202.12233).
        """
        d_meta = {}
        file_list = []
        with open(dir_meta, 'r') as f:
            l_meta = f.readlines()

        if is_train:
            for line in l_meta:
                _, key, _, _, label = line.strip().split()
                file_list.append(key)
                d_meta[key] = 1 if label == 'bonafide' else 0
            return d_meta, file_list
        elif is_eval:
            for line in l_meta:
                key = line.strip()
                file_list.append(key)
            return file_list
        else:
            for line in l_meta:
                _, key, _, _, label = line.strip().split()
                file_list.append(key)
                d_meta[key] = 1 if label == 'bonafide' else 0
            return d_meta, file_list

if __name__ == "__main__":
_ = ASVSpoofDataModule()
