generated from ashleve/lightning-hydra-template
Commit 7b96ffa
Hung Dinh Xuan committed on Oct 25, 2024
1 parent: c37c147
Showing 3 changed files with 330 additions and 14 deletions.
@@ -0,0 +1,5 @@
_target_: src.data.asvspoof_conformertcm_reproduce_datamodule.ASVSpoofDataModule
data_dir: ${oc.env:ASVSPOOF_PATH}
batch_size: 20 # Must be divisible by the number of devices (e.g., in a distributed setup)
num_workers: 10
pin_memory: True
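
For context, a minimal sketch of how Hydra would resolve and instantiate this config (not part of the commit; the `ASVSPOOF_PATH` value below is a placeholder assumption):

```python
import os

from hydra.utils import instantiate
from omegaconf import OmegaConf

os.environ.setdefault("ASVSPOOF_PATH", "/path/to/ASVspoof_database/")  # placeholder

cfg = OmegaConf.create(
    {
        "_target_": "src.data.asvspoof_conformertcm_reproduce_datamodule.ASVSpoofDataModule",
        "data_dir": "${oc.env:ASVSPOOF_PATH}",
        "batch_size": 20,
        "num_workers": 10,
        "pin_memory": True,
    }
)
datamodule = instantiate(cfg)  # resolves ${oc.env:...} and builds the datamodule
```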
@@ -0,0 +1,320 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from typing import Any, Dict, Optional

import librosa
import numpy as np
from lightning import LightningDataModule
from torch import Tensor
from torch.utils.data import DataLoader, Dataset

from src.data.components.RawBoost import process_Rawboost_feature

'''
RawBoost reference:
Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans.
"RawBoost: A Raw Data Boosting and Augmentation Method Applied to Automatic
Speaker Verification Anti-Spoofing." In Proc. ICASSP 2022, pp. 6382-6386.
'''

def pad(x, max_len):
    """Repeat-pad a 1-D waveform to `max_len` samples, or truncate it."""
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    # Tile the waveform until it covers max_len samples, then trim.
    num_repeats = int(max_len / x_len) + 1
    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
    return padded_x
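
# Illustrative example (added for this write-up, not part of the commit):
# repeat-padding a 3-sample signal to 8 samples tiles it, while a longer
# signal is truncated:
#
#   >>> pad(np.array([1, 2, 3]), 8)
#   array([1, 2, 3, 1, 2, 3, 1, 2])
#   >>> pad(np.arange(10), 4)
#   array([0, 1, 2, 3])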

class Dataset_ASVspoof2019_train(Dataset):
    def __init__(self, args, list_IDs, labels, base_dir, algo):
        '''self.list_IDs : list of strings (each string: utt key),
        self.labels : dictionary (key: utt key, value: label integer)'''
        self.list_IDs = list_IDs
        self.labels = labels
        self.base_dir = base_dir
        self.algo = algo
        self.args = args
        self.cut = 66800  # ~4.175 s of audio at 16 kHz

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_id = self.list_IDs[index]
        # Load at 16 kHz, apply RawBoost augmentation, then pad/truncate.
        X, fs = librosa.load(self.base_dir + utt_id + '.flac', sr=16000)
        Y = process_Rawboost_feature(X, fs, self.args, self.algo)
        X_pad = pad(Y, self.cut)
        x_inp = Tensor(X_pad)
        target = self.labels[utt_id]
        return x_inp, target
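
# Illustrative usage (added for this write-up, not part of the commit):
# indexing the train dataset yields an augmented, fixed-length waveform
# tensor and its integer label.
#
#   ds = Dataset_ASVspoof2019_train(args, list_IDs, labels, base_dir, algo=3)
#   x, y = ds[0]  # x: torch.Size([66800]), y: 1 (bonafide) or 0 (spoof)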

class Dataset_ASVspoof2021_eval(Dataset):
    def __init__(self, list_IDs, base_dir):
        self.list_IDs = list_IDs
        self.base_dir = base_dir
        self.cut = 66800  # ~4.175 s of audio at 16 kHz

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        utt_id = self.list_IDs[index]
        # No augmentation at evaluation time; the utterance key is returned
        # instead of a label.
        X, fs = librosa.load(self.base_dir + utt_id + '.flac', sr=16000)
        X_pad = pad(X, self.cut)
        x_inp = Tensor(X_pad)
        return x_inp, utt_id

class Args:
    """Default hyper-parameters, including the RawBoost configuration."""

    def __init__(self):
        self.database_path = '/your/path/to/data/ASVspoof_database/DF/'
        self.protocols_path = '/data/hungdx/Datasets/protocols/database/'
        self.batch_size = 20
        self.num_epochs = 100
        self.lr = 0.000001
        self.weight_decay = 0.0001
        self.loss = 'weighted_CCE'
        self.seed = 1234
        self.model_path = None
        self.comment = None
        self.track = 'DF'
        self.eval_output = None
        self.eval = False
        self.is_eval = False
        self.eval_part = 0
        self.cudnn_deterministic_toggle = True
        self.cudnn_benchmark_toggle = False
        # RawBoost augmentation parameters; `algo` selects the augmentation type.
        self.algo = 3
        self.nBands = 5
        self.minF = 20
        self.maxF = 8000
        self.minBW = 100
        self.maxBW = 1000
        self.minCoeff = 10
        self.maxCoeff = 100
        self.minG = 0
        self.maxG = 0
        self.minBiasLinNonLin = 5
        self.maxBiasLinNonLin = 20
        self.N_f = 5
        self.P = 10
        self.g_sd = 2
        self.SNRmin = 10
        self.SNRmax = 40

class ASVSpoofDataModule(LightningDataModule):
    """`LightningDataModule` for the ASVspoof dataset.

    The ASVspoof 2019 database for logical access is based on a standard
    multi-speaker speech synthesis database called VCTK. Genuine speech is
    collected from 107 speakers (46 male, 61 female) with no significant
    channel or background noise effects. Spoofed speech is generated from the
    genuine data using a number of different spoofing algorithms. The full
    dataset is partitioned into three subsets: training, development, and
    evaluation. There is no speaker overlap across the three subsets regarding
    target speakers used in voice conversion or TTS adaptation.

    A `LightningDataModule` implements 7 key methods:

    ```python
    def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

    def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

    def train_dataloader(self):
        # return train dataloader

    def val_dataloader(self):
        # return validation dataloader

    def test_dataloader(self):
        # return test dataloader

    def predict_dataloader(self):
        # return predict dataloader

    def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> None:
        """Initialize an `ASVSpoofDataModule`.

        :param data_dir: The data directory. Defaults to `"data/"`.
        :param batch_size: The batch size. Defaults to `64`.
        :param num_workers: The number of workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        """
        super().__init__()

        # This call makes the init params accessible via the `self.hparams`
        # attribute and ensures they are stored in the checkpoint.
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size
        self.data_dir = data_dir

    @property
    def num_classes(self) -> int:
        """Get the number of classes.

        :return: The number of ASVspoof classes (2: bonafide and spoof).
        """
        return 2

    def prepare_data(self) -> None:
        """Nothing to download; the ASVspoof data is expected to already be on disk."""
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to set up. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by "
                    f"the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # Load the datasets only if they have not been loaded already.
        if not self.data_train and not self.data_val and not self.data_test:
            args = Args()
            args.database_path = self.data_dir
            track = args.track
            prefix_2021 = 'ASVspoof2021.{}'.format(track)

            # Train/dev labels come from the ASVspoof 2019 LA protocols; the
            # eval file list comes from the ASVspoof 2021 protocol for `track`.
            d_label_trn, file_train = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),
                is_train=True,
                is_eval=False,
            )
            d_label_dev, file_dev = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),
                is_train=False,
                is_eval=False,
            )
            file_eval = self.genSpoof_list(
                dir_meta=os.path.join(args.protocols_path, 'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track, prefix_2021)),
                is_train=False,
                is_eval=True,
            )

            self.data_train = Dataset_ASVspoof2019_train(
                args,
                list_IDs=file_train,
                labels=d_label_trn,
                base_dir=os.path.join(args.database_path, 'ASVspoof2019_LA_train/'),
                algo=args.algo,
            )
            self.data_val = Dataset_ASVspoof2019_train(
                args,
                list_IDs=file_dev,
                labels=d_label_dev,
                base_dir=os.path.join(args.database_path, 'ASVspoof2019_LA_dev/'),
                algo=args.algo,
            )
            self.data_test = Dataset_ASVspoof2021_eval(
                list_IDs=file_eval,
                base_dir=os.path.join(args.database_path, 'ASVspoof2021_{}_eval/'.format(args.track)),
            )
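
    # Worked example of the division above (added for this write-up, not part
    # of the commit): with the configured `batch_size: 20` on 4 devices
    # (world_size == 4), each device loads 20 // 4 == 5 samples per step, so
    # the effective global batch size stays 20. The same config on 3 devices
    # would raise the RuntimeError, since 20 % 3 != 0.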

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass

    def genSpoof_list(self, dir_meta, is_train=False, is_eval=False):
        """Parse an ASVspoof CM protocol file into labels and/or a file list.

        Adapted from: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17
        Official source: https://arxiv.org/abs/2202.12233
        "Automatic speaker verification spoofing and deepfake detection using
        wav2vec 2.0 and data augmentation"
        """
        d_meta = {}
        file_list = []
        with open(dir_meta, 'r') as f:
            l_meta = f.readlines()

        if is_train:
            # Train protocol lines look like:
            #   <speaker_id> <utt_key> - <attack_id> <bonafide|spoof>
            for line in l_meta:
                _, key, _, _, label = line.strip().split()
                file_list.append(key)
                d_meta[key] = 1 if label == 'bonafide' else 0
            return d_meta, file_list
        elif is_eval:
            # Eval protocol lines contain only the utterance key.
            for line in l_meta:
                key = line.strip()
                file_list.append(key)
            return file_list
        else:
            # Dev protocol lines share the train format.
            for line in l_meta:
                _, key, _, _, label = line.strip().split()
                file_list.append(key)
                d_meta[key] = 1 if label == 'bonafide' else 0
            return d_meta, file_list

if __name__ == "__main__":
    _ = ASVSpoofDataModule()
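
For reference, a minimal usage sketch of the new datamodule (assumptions: the ASVspoof audio and protocol files are in place and `Args.protocols_path` points at the protocol directory; the path below is a placeholder):

```python
from src.data.asvspoof_conformertcm_reproduce_datamodule import ASVSpoofDataModule

dm = ASVSpoofDataModule(
    data_dir="/path/to/ASVspoof_database/",  # placeholder path
    batch_size=20,
    num_workers=10,
    pin_memory=True,
)
dm.setup("fit")
x, y = next(iter(dm.train_dataloader()))
print(x.shape)  # torch.Size([20, 66800]) -- padded 16 kHz waveforms
print(y.shape)  # torch.Size([20]) -- 1 = bonafide, 0 = spoof
```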