From f0aa8dafb56cb9a509cc1f78f366e0ea588f3598 Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Sat, 28 Sep 2024 00:50:53 +0900 Subject: [PATCH 1/8] Update cyberaicup datamodule --- configs/data/asvspoof.yaml | 8 +- configs/data/asvspoof_reproduce.yaml | 5 + configs/data/cyberaicup.yaml | 10 + configs/experiment/test_aasist.yaml | 4 +- configs/experiment/test_ssl.yaml | 4 +- src/data/asvspoof_datamodule.py | 95 +-- src/data/components/augwrapper.py | 1142 ++++++++++++++++++++++++++ src/data/components/baseloader.py | 50 ++ src/data/components/dataio.py | 142 ++++ src/data/cyberaicup_datamodule.py | 430 ++++------ 10 files changed, 1550 insertions(+), 340 deletions(-) create mode 100644 configs/data/asvspoof_reproduce.yaml create mode 100644 configs/data/cyberaicup.yaml create mode 100644 src/data/components/augwrapper.py create mode 100644 src/data/components/baseloader.py create mode 100644 src/data/components/dataio.py diff --git a/configs/data/asvspoof.yaml b/configs/data/asvspoof.yaml index 1f8680c..ab2c23f 100644 --- a/configs/data/asvspoof.yaml +++ b/configs/data/asvspoof.yaml @@ -1,5 +1,11 @@ -_target_: src.data.asvspoof_aasistssl_reproduce_datamodule.ASVSpoofDataModule +_target_: src.data.asvspoof_datamodule.ASVSpoofDataModule data_dir: ${oc.env:ASVSPOOF_PATH} batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) num_workers: 4 pin_memory: False +args: + # The sampling rate of the audio files + sampling_rate: 16000 + cut: 64000 + padding_type: zero + random_start: True \ No newline at end of file diff --git a/configs/data/asvspoof_reproduce.yaml b/configs/data/asvspoof_reproduce.yaml new file mode 100644 index 0000000..97814e2 --- /dev/null +++ b/configs/data/asvspoof_reproduce.yaml @@ -0,0 +1,5 @@ +_target_: src.data.asvspoof_aasistssl_reproduce_datamodule.ASVSpoofDataModule +data_dir: ${oc.env:ASVSPOOF_PATH} +batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +num_workers: 4 +pin_memory: False \ No newline at end of file diff --git a/configs/data/cyberaicup.yaml b/configs/data/cyberaicup.yaml new file mode 100644 index 0000000..28aea53 --- /dev/null +++ b/configs/data/cyberaicup.yaml @@ -0,0 +1,10 @@ +_target_: src.data.asvspoof_aasistssl_reproduce_datamodule.ASVSpoofDataModule +data_dir: ${oc.env:ASVSPOOF_PATH} +batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +num_workers: 4 +pin_memory: False +args: + # The number of seconds to consider for each audio file + duration: 3 + # The sampling rate of the audio files + sampling_rate: 16000 \ No newline at end of file diff --git a/configs/experiment/test_aasist.yaml b/configs/experiment/test_aasist.yaml index f1781ee..5f1061f 100644 --- a/configs/experiment/test_aasist.yaml +++ b/configs/experiment/test_aasist.yaml @@ -4,7 +4,7 @@ # python train.py experiment=example defaults: - - override /data: asvspoof + - override /data: asvspoof_reproduce - override /model: aasist - override /callbacks: default - override /trainer: default @@ -12,7 +12,7 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -tags: ["asvspoof", "aasist"] +tags: ["asvspoof_reproduce", "aasist"] seed: 12345 diff --git a/configs/experiment/test_ssl.yaml b/configs/experiment/test_ssl.yaml index f702251..366c766 100644 --- a/configs/experiment/test_ssl.yaml +++ b/configs/experiment/test_ssl.yaml @@ -4,7 +4,7 @@ # python 
train.py experiment=example defaults: - - override /data: asvspoof + - override /data: asvspoof_reproduce - override /model: xlsr_aasist - override /callbacks: default - override /trainer: default @@ -12,7 +12,7 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -tags: ["asvspoof", "xlsr_aasist"] +tags: ["asvspoof_reproduce", "xlsr_aasist"] seed: 12345 diff --git a/src/data/asvspoof_datamodule.py b/src/data/asvspoof_datamodule.py index bdb1d4d..dac0538 100644 --- a/src/data/asvspoof_datamodule.py +++ b/src/data/asvspoof_datamodule.py @@ -10,22 +10,13 @@ from scipy import signal import copy from src.data.components.RawBoost import process_Rawboost_feature - +from src.data.components.dataio import load_audio, pad ''' Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. In Proc. ICASSP 2022, pp:6382--6386. ''' -def pad(x, max_len=64600): - x_len = x.shape[0] - if x_len >= max_len: - return x[:max_len] - # need to pad - num_repeats = int(max_len / x_len)+1 - padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] - return padded_x - class Dataset_ASVspoof2019_train(Dataset): def __init__(self,args,list_IDs, labels, base_dir,algo): '''self.list_IDs : list of strings (each string: utt key), @@ -36,7 +27,13 @@ def __init__(self,args,list_IDs, labels, base_dir,algo): self.base_dir = base_dir self.algo=algo self.args=args - self.cut=64600 # take ~4 sec audio (64600 samples) + + # Sampling rate and cut-off + print('args:',args) + self.fs = args.get('sampling_rate', 16000) if args is not None else 16000 + self.cut = args.get('cut', 64600) if args is not None else 64600 + self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero' + self.random_start = args.get('random_start', False) if args is not None else False def __len__(self): return len(self.list_IDs) @@ -44,66 +41,35 @@ def __len__(self): def __getitem__(self, index): utt_id = self.list_IDs[index] - X,fs = librosa.load(self.base_dir+utt_id+'.flac', sr=16000) + X,fs = librosa.load(self.base_dir+utt_id+'.flac', sr=self.fs) Y=process_Rawboost_feature(X,fs,self.args,self.algo) - X_pad= pad(Y,self.cut) + X_pad= pad(Y, self.padding_type, self.cut, self.random_start) x_inp= Tensor(X_pad) target = self.labels[utt_id] return x_inp, target class Dataset_ASVspoof2021_eval(Dataset): - def __init__(self, list_IDs, base_dir): + def __init__(self, args, list_IDs, base_dir): self.list_IDs = list_IDs self.base_dir = base_dir - self.cut=64600 # take ~4 sec audio (64600 samples) + + # Sampling rate and cut-off + self.fs = args.get('sampling_rate', 16000) if args is not None else 16000 + self.cut = args.get('cut', 64600) if args is not None else 64600 + self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero' + self.random_start = args.get('random_start', False) if args is not None else False def __len__(self): return len(self.list_IDs) def __getitem__(self, index): utt_id = self.list_IDs[index] - X, fs = librosa.load(self.base_dir+utt_id+'.flac', sr=16000) - X_pad = pad(X,self.cut) + X, fs = librosa.load(self.base_dir+utt_id+'.flac', sr=self.fs) + X_pad = pad(X,self.padding_type, self.cut, self.random_start) x_inp = Tensor(X_pad) return x_inp,utt_id -class Args: - def __init__(self): - self.database_path = '/your/path/to/data/ASVspoof_database/DF/' - 
self.protocols_path = '/data/hungdx/Datasets/protocols/database/' - self.batch_size = 14 - self.num_epochs = 100 - self.lr = 0.000001 - self.weight_decay = 0.0001 - self.loss = 'weighted_CCE' - self.seed = 1234 - self.model_path = None - self.comment = None - self.track = 'DF' - self.eval_output = None - self.eval = False - self.is_eval = False - self.eval_part = 0 - self.cudnn_deterministic_toggle = True - self.cudnn_benchmark_toggle = False - self.algo = 3 - self.nBands = 5 - self.minF = 20 - self.maxF = 8000 - self.minBW = 100 - self.maxBW = 1000 - self.minCoeff = 10 - self.maxCoeff = 100 - self.minG = 0 - self.maxG = 0 - self.minBiasLinNonLin = 5 - self.maxBiasLinNonLin = 20 - self.N_f = 5 - self.P = 10 - self.g_sd = 2 - self.SNRmin = 10 - self.SNRmax = 40 class ASVSpoofDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. @@ -155,6 +121,7 @@ def __init__( batch_size: int = 64, num_workers: int = 0, pin_memory: bool = False, + args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -175,6 +142,7 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir + self.args = args @property def num_classes(self) -> int: @@ -209,18 +177,21 @@ def setup(self, stage: Optional[str] = None) -> None: if not self.data_train and not self.data_val and not self.data_test: # define train dataloader - args = Args() - args.database_path = self.data_dir - track = args.track + + self.database_path = self.data_dir + track = 'DF' + prefix_2021 = 'ASVspoof2021.{}'.format(track) + self.protocols_path = '/data/hungdx/Datasets/protocols/database/' + self.algo = self.args.get('algo', -1) if self.args is not None else -1 - d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) - d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) - file_eval = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) + d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) + d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) + file_eval = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) - self.data_train = Dataset_ASVspoof2019_train(args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_train/'),algo=args.algo) - self.data_val = Dataset_ASVspoof2019_train(args,list_IDs = file_dev,labels = d_label_dev,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_dev/'),algo=args.algo) - self.data_test = Dataset_ASVspoof2021_eval(list_IDs = file_eval,base_dir = os.path.join(args.database_path+'ASVspoof2021_{}_eval/'.format(args.track))) + self.data_train = Dataset_ASVspoof2019_train(self.args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_train/'),algo=self.algo) + self.data_val = Dataset_ASVspoof2019_train(self.args,list_IDs = 
file_dev,labels = d_label_dev,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_dev/'),algo=self.algo) + self.data_test = Dataset_ASVspoof2021_eval(self.args, list_IDs = file_eval,base_dir = os.path.join(self.database_path+'ASVspoof2021_{}_eval/'.format(track))) def train_dataloader(self) -> DataLoader[Any]: diff --git a/src/data/components/augwrapper.py b/src/data/components/augwrapper.py new file mode 100644 index 0000000..0110014 --- /dev/null +++ b/src/data/components/augwrapper.py @@ -0,0 +1,1142 @@ +import os +import numpy as np +import torch +import torchaudio +from torch import Tensor +import librosa +from datautils.RawBoost import ISD_additive_noise, LnL_convolutive_noise, SSI_additive_noise, normWav +from datautils.audio_augmentor import BackgroundNoiseAugmentor, PitchAugmentor, ReverbAugmentor, SpeedAugmentor, VolumeAugmentor, TelephoneEncodingAugmentor, GaussianAugmentor, CopyPasteAugmentor, BaseAugmentor, TimeMaskingAugmentor, FrequencyMaskingAugmentor, MaskingAugmentor, TimeSwapAugmentor, FrequencySwapAugmentor, SwappingAugmentor, LinearFilterAugmentor, BandpassAugmentor +from datautils.audio_augmentor.utils import pydub_to_librosa, librosa_to_pydub +import soundfile as sf +import random + +SUPPORTED_AUGMENTATION = [ + 'background_noise_5_15', 'pitch_1', 'volume_10', 'reverb_1', 'speed_01', 'telephone_g722', 'gaussian_1', 'gaussian_2', 'gaussian_2_5', 'gaussian_3', + 'RawBoostdf', 'RawBoost12', 'copy_paste_80', 'copy_paste_r', 'time_masking', 'masking', 'time_swap', + 'freq_swap', 'swapping', 'frequency_masking', 'linear_filter', 'mp32flac', 'ogg2flac', 'nonspeechtrim', + 'bandpass_0_4000', 'griffinlim_downsample', 'lowpass_hifigan_asvspoof5', 'lowpass_hifigan', 'librosa_downsample'] + + +def audio_transform(filepath: str, aug_type: BaseAugmentor, config: dict, online: bool = False): + """ + filepath: str, input audio file path + aug_type: BaseAugmentor, augmentation type object + config: dict, configuration dictionary + online: bool, if True, return the augmented audio waveform, else save the augmented audio file + """ + at = aug_type(config) + at.load(filepath) + at.transform() + if online: + audio = at.augmented_audio + return pydub_to_librosa(audio) + else: + at.save() + + +def background_noise_5_15(x, args, sr=16000, audio_path=None): + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + args.input_path = os.path.dirname(audio_path) + aug_audio_path = os.path.join(aug_dir, 'background_noise', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'background_noise') + args.out_format = 'wav' + config = { + "aug_type": "background_noise", + "output_path": args.output_path, + "out_format": args.out_format, + "noise_path": args.noise_path, + "min_SNR_dB": 5, + "max_SNR_dB": 15 + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=BackgroundNoiseAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=BackgroundNoiseAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def pitch_1(x, args, sr=16000, audio_path=None): + """ + Augment the audio with pitch shift of -1 to 1 + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = 
os.path.join(aug_dir, 'pitch', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'pitch') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "pitch", + "output_path": args.output_path, + "out_format": args.out_format, + "min_pitch_shift": -1, + "max_pitch_shift": 1 + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=PitchAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform(filepath=audio_path, + aug_type=PitchAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def volume_10(x, args, sr=16000, audio_path=None): + """ + Augment the audio with volume change of -10 to 10 dBFS + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'volume', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'volume') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "volume", + "output_path": args.output_path, + "out_format": args.out_format, + "min_volume_dBFS": -10, + "max_volume_dBFS": 10 + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=VolumeAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=VolumeAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def reverb_1(x, args, sr=16000, audio_path=None): + """ + Augment the audio with reverb effect + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'reverb', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'reverb') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "reverb", + "output_path": args.output_path, + "out_format": args.out_format, + "rir_path": args.rir_path, + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=ReverbAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=ReverbAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def speed_01(x, args, sr=16000, audio_path=None): + """ + Augment the audio with speed change of 0.9 to 1.1 + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'speed', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'speed') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "speed", + "output_path": args.output_path, + "out_format": args.out_format, + "min_speed_factor": 0.9, + "max_speed_factor": 1.1 + } + if 
(args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=SpeedAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform(filepath=audio_path, + aug_type=SpeedAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def telephone_g722(x, args, sr=16000, audio_path=None): + """ + Augment the audio with telephone encoding g722 and bandpass filter + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'telephone', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'telephone') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "telephone", + "output_path": args.output_path, + "out_format": args.out_format, + "encoding": "g722", + "bandpass": { + "lowpass": "3400", + "highpass": "400" + } + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=TelephoneEncodingAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=TelephoneEncodingAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def bandpass_0_4000(x, args, sr=16000, audio_path=None): + """ + Augment the audio with bandpass filter of 0 to 4000 Hz + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'bandpass', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'bandpass') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "bandpass", + "output_path": args.output_path, + "out_format": args.out_format, + "lowpass": "4000", + "highpass": "0" + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=BandpassAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=BandpassAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def gaussian_1(x, args, sr=16000, audio_path=None): + """ + Augment the audio with gaussian noise in the range of 0.001 to 0.015 + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'gaussian_noise', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'gaussian_noise') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "guassian_noise", + "output_path": args.output_path, + "out_format": args.out_format, + "min_amplitude": 0.001, + "max_amplitude": 0.015 + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, 
mono=True)
+        return waveform
+    else:
+        if os.path.exists(aug_audio_path):
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+        else:
+            audio_transform(
+                filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=False)
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+
+
+def gaussian_2(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with gaussian noise in the range of 0.00001 to 0.00015
+    """
+    aug_dir = args.aug_dir
+    utt_id = os.path.basename(audio_path).split('.')[0]
+    aug_audio_path = os.path.join(aug_dir, 'gaussian_noise', utt_id + '.wav')
+    args.output_path = os.path.join(aug_dir, 'gaussian_noise')
+    args.out_format = 'wav'
+    args.input_path = os.path.dirname(audio_path)
+    config = {
+        "aug_type": "guassian_noise",
+        "output_path": args.output_path,
+        "out_format": args.out_format,
+        "min_amplitude": 0.00001,
+        "max_amplitude": 0.00015
+    }
+    if (args.online_aug):
+        waveform = audio_transform(
+            filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=True)
+        # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True)
+        return waveform
+    else:
+        if os.path.exists(aug_audio_path):
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+        else:
+            audio_transform(
+                filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=False)
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+
+
+def gaussian_2_5(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with gaussian noise in the range of 0.000002 to 0.00003
+    """
+    aug_dir = args.aug_dir
+    utt_id = os.path.basename(audio_path).split('.')[0]
+    aug_audio_path = os.path.join(aug_dir, 'gaussian_noise', utt_id + '.wav')
+    args.output_path = os.path.join(aug_dir, 'gaussian_noise')
+    args.out_format = 'wav'
+    args.input_path = os.path.dirname(audio_path)
+    config = {
+        "aug_type": "guassian_noise",
+        "output_path": args.output_path,
+        "out_format": args.out_format,
+        "min_amplitude": 0.000002,
+        "max_amplitude": 0.00003
+    }
+    if (args.online_aug):
+        waveform = audio_transform(
+            filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=True)
+        # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True)
+        return waveform
+    else:
+        if os.path.exists(aug_audio_path):
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+        else:
+            audio_transform(
+                filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=False)
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+
+
+def gaussian_3(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with gaussian noise in the range of 0.000001 to 0.000015
+    """
+    aug_dir = args.aug_dir
+    utt_id = os.path.basename(audio_path).split('.')[0]
+    aug_audio_path = os.path.join(aug_dir, 'gaussian_noise', utt_id + '.wav')
+    args.output_path = os.path.join(aug_dir, 'gaussian_noise')
+    args.out_format = 'wav'
+    args.input_path = os.path.dirname(audio_path)
+    config = {
+        "aug_type": "guassian_noise",
+        "output_path": args.output_path,
+        "out_format": args.out_format,
+        "min_amplitude": 0.000001,
+        "max_amplitude": 0.000015
+    }
+    if (args.online_aug):
+        waveform = audio_transform(
+            filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=True)
+        # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True)
+        return waveform
+    else:
+        if os.path.exists(aug_audio_path):
+            waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
+            return waveform
+        
else: + audio_transform( + filepath=audio_path, aug_type=GaussianAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def copy_paste_r(x, args, sr=16000, audio_path=None): + """ + Augment the audio with copy paste of 80% of the audio, frame size is 800 samples + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'copy_paste', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'copy_paste') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "copy_paste", + "output_path": args.output_path, + "out_format": args.out_format, + "frame_size": 0, + "shuffle_ratio": random.uniform(0.3, 1) + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=CopyPasteAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=CopyPasteAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def copy_paste_80(x, args, sr=16000, audio_path=None): + """ + Augment the audio with copy paste of 80% of the audio, frame size is 800 samples + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'copy_paste', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'copy_paste') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + config = { + "aug_type": "copy_paste", + "output_path": args.output_path, + "out_format": args.out_format, + "frame_size": 800, + "shuffle_ratio": 0.8 + } + if (args.online_aug): + waveform = audio_transform( + filepath=audio_path, aug_type=CopyPasteAugmentor, config=config, online=True) + # waveform,_ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + audio_transform( + filepath=audio_path, aug_type=CopyPasteAugmentor, config=config, online=False) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def masking(x, args, sr=16000, audio_path=None): + """ + Apply frequency and time masking to the audio file. 
+ """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'masking', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'masking') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "masking", + "output_path": args.output_path, + "out_format": args.out_format, + "F": 27, + "num_frequency_masks": 2, + "T": 100, + "p": 1.0, + "num_time_masks": 2 + } + + if args.online_aug: + waveform = audio_transform( + filepath=audio_path, aug_type=MaskingAugmentor, config=config, online=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + augmentor = MaskingAugmentor(config) + augmentor.load(audio_path) + augmentor.transform() + augmentor.save(aug_audio_path) + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def time_swap(x, args, sr=16000, audio_path=None): + """ + Apply time swapping to the audio file. + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'time_swap', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'time_swap') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "time_swap", + "output_path": args.output_path, + "out_format": args.out_format, + "T": 40, + "num_swaps": 1 + } + + if args.online_aug: + waveform = audio_transform( + filepath=audio_path, aug_type=TimeSwapAugmentor, config=config, online=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + augmentor = TimeSwapAugmentor(config) + augmentor.load(audio_path) + augmentor.transform() + augmentor.save() + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def freq_swap(x, args, sr=16000, audio_path=None): + """ + Apply frequency swapping to the audio file. + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'freq_swap', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'freq_swap') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "freq_swap", + "output_path": args.output_path, + "out_format": args.out_format, + "F": 7, + "num_swaps": 1 + } + + if args.online_aug: + waveform = audio_transform( + filepath=audio_path, aug_type=FrequencySwapAugmentor, config=config, online=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + augmentor = FrequencySwapAugmentor(config) + augmentor.load(audio_path) + augmentor.transform() + augmentor.save() + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def swapping(x, args, sr=16000, audio_path=None): + """ + Apply time and frequency swapping to the audio file. 
+ """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'swapping', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'swapping') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "swapping", + "output_path": args.output_path, + "out_format": args.out_format, + "T": 40, + "F": 7 + } + + if args.online_aug: + waveform = audio_transform( + filepath=audio_path, aug_type=SwappingAugmentor, config=config, online=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + augmentor = SwappingAugmentor(config) + augmentor.load(audio_path) + augmentor.transform() + augmentor.save() + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + + +def linear_filter(x, args, sr=16000, audio_path=None): + """ + Apply linear filter augmentation to the audio file. + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'linear_filter', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'linear_filter') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "linear_filter", + "output_path": args.output_path, + "out_format": args.out_format, + "db_range": [-6, 6], + "n_band": [3, 6], + "min_bw": 6 + } + + if args.online_aug: + waveform = audio_transform( + filepath=audio_path, aug_type=LinearFilterAugmentor, config=config, online=True) + return waveform + else: + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + augmentor = LinearFilterAugmentor(config) + augmentor.load(audio_path) + augmentor.transform() + augmentor.save() + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform +# --------------Offline only augmentation algorithms---------------------------## + + +def time_masking(x, args, sr=16000, audio_path=None): + """ + Apply time masking to the audio file. + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'time_masking', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'time_masking') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "time_masking", + "output_path": args.output_path, + "out_format": args.out_format, + "T": 100, + "p": 1.0, + "num_masks": 2 + } + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + # augmentor = TimeMaskingAugmentor(config) + # augmentor.load(audio_path) + # augmentor.transform() + # augmentor.save() + # waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + waveform, _ = librosa.load(audio_path, sr=sr, mono=True) + return waveform + + +def frequency_masking(x, args, sr=16000, audio_path=None): + """ + Apply frequency masking to the audio file. 
+ """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join( + aug_dir, 'frequency_masking', utt_id + '.wav') + args.output_path = os.path.join(aug_dir, 'frequency_masking') + args.out_format = 'wav' + args.input_path = os.path.dirname(audio_path) + + config = { + "aug_type": "frequency_masking", + "output_path": args.output_path, + "out_format": args.out_format, + "F": 27, + "num_masks": 2 + } + # this aug do not support online augmentation + if os.path.exists(aug_audio_path): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + # augmentor = FrequencyMaskingAugmentor(config) + # augmentor.load(audio_path) + # augmentor.transform() + # augmentor.save() + # waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + waveform, _ = librosa.load(audio_path, sr=sr, mono=True) + return waveform + + +def mp32flac(x, args, sr=16000, audio_path=None): + """ + Augment the audio with mp3 codec + This codec is only available offline + """ + # print('mp32flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'mp32flac', utt_id + '_from_mp3.flac') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # Alert ERROR and stop the process + # print("Error: mp3 to flac conversion is only available offline\n Please convert to mp3 your data before running the script") + # using ffmpeg to convert original audio to mp3 + os.system( + f'ffmpeg -loglevel quiet -i {audio_path} {aug_path.replace(".flac", ".mp3")} -y') + # using ffmpeg to convert mp3 to flac + os.system( + f'ffmpeg -loglevel quiet -i {aug_path.replace(".flac", ".mp3")} {aug_path} -y') + # remove the mp3 file + os.system(f'rm {aug_path.replace(".flac", ".mp3")}') + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + +## Speex codec +def speex2flac_high_band_varied_bitrate(x, args, sr=16000, audio_path=None): + """ + Augment the audio with speex codec + This codec is only available offline + """ + # print('speex2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_folder_path = os.path.join(aug_dir, 'speex2flac_high_band_varied_bitrate') + + os.makedirs(aug_folder_path, exist_ok=True) + + aug_path = os.path.join(aug_folder_path, utt_id + '_from_speex_.flac') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # Choice a random bitrate between 1.5 and 128 kbs + bitrate = random.randrange(1500, 128000) / 1000 + spx_aug_path = aug_path.replace(".flac", ".ogg") + os.system( + f'ffmpeg -loglevel quiet -i {audio_path} -acodec libspeex -b:a {bitrate}k {spx_aug_path} -y') + # using ffmpeg to convert speex to flac + os.system( + f'ffmpeg -loglevel quiet -i {spx_aug_path} {aug_path} -y') + # remove the speex file + os.system(f'rm {spx_aug_path}') + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + +def speex2flac_low_band_varied_bitrate(x, args, sr=16000, audio_path=None): + """ + Augment the audio with speex codec + This codec is only available offline + """ + # print('speex2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_folder_path = os.path.join(aug_dir, 
'speex2flac_low_band_varied_bitrate')
+    downsample_sr = 8000
+    upsample_sr = 16000
+
+    os.makedirs(aug_folder_path, exist_ok=True)
+
+    aug_path = os.path.join(aug_folder_path, utt_id + '_from_speex_.flac')
+    # check if the augmented file exists
+    if (os.path.exists(aug_path)):
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+    else:
+        # Choose a random bitrate between 1.5 and 128 kbps
+        bitrate = random.randrange(1500, 128000) / 1000
+
+        # Apply the speex codec
+        spx_aug_path = aug_path.replace(".flac", ".ogg")
+        os.system(
+            f'ffmpeg -loglevel quiet -i {audio_path} -ar {downsample_sr} -acodec libspeex -b:a {bitrate}k {spx_aug_path} -y')
+        os.system(
+            f'ffmpeg -loglevel quiet -i {spx_aug_path} -ar {upsample_sr} {aug_path} -y')
+
+        # remove the speex file
+        os.system(f'rm {spx_aug_path}')
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+
+
+### END Speex codec
+
+### OPUS codec
+def opus2flac_high_band_varied_bitrate(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with opus codec
+    This codec is only available offline
+    """
+    # print('opus2flac: audio_path:', audio_path)
+    aug_dir = args.aug_dir
+    utt_id = os.path.basename(audio_path).split('.')[0]
+    aug_folder_path = os.path.join(aug_dir, 'opus2flac_high_band_varied_bitrate')
+    upsample_sr = 16000
+    os.makedirs(aug_folder_path, exist_ok=True)
+
+    aug_path = os.path.join(aug_folder_path, utt_id + '_from_opus_.flac')
+    # check if the augmented file exists
+    if (os.path.exists(aug_path)):
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+    else:
+        # Choose a random bitrate between 1.5 and 128 kbps
+        bitrate = random.randrange(1500, 128000) / 1000
+        opus_aug_path = aug_path.replace(".flac", ".opus")
+        os.system(
+            f'ffmpeg -loglevel quiet -i {audio_path} -acodec libopus -b:a {bitrate}k {opus_aug_path} -y')
+        # using ffmpeg to convert opus to flac
+        os.system(
+            f'ffmpeg -loglevel quiet -i {opus_aug_path} -ar {upsample_sr} -ac 1 -sample_fmt s16 -c:a flac {aug_path} -y')
+        # remove the opus file
+        os.system(f'rm {opus_aug_path}')
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+
+def opus2flac_low_band_varied_bitrate(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with opus codec
+    This codec is only available offline
+    """
+    # print('opus2flac: audio_path:', audio_path)
+    aug_dir = args.aug_dir
+    utt_id = os.path.basename(audio_path).split('.')[0]
+    aug_folder_path = os.path.join(aug_dir, 'opus2flac_low_band_varied_bitrate')
+    downsample_sr = 8000
+    upsample_sr = 16000
+
+    os.makedirs(aug_folder_path, exist_ok=True)
+
+    aug_path = os.path.join(aug_folder_path, utt_id + '_from_opus_.flac')
+    # check if the augmented file exists
+    if (os.path.exists(aug_path)):
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+    else:
+        # Choose a random bitrate between 1.5 and 128 kbps
+        bitrate = random.randrange(1500, 128000) / 1000
+
+        # Apply the opus codec
+        opus_aug_path = aug_path.replace(".flac", ".opus")
+        os.system(
+            f'ffmpeg -loglevel quiet -i {audio_path} -ar {downsample_sr} -acodec libopus -b:a {bitrate}k {opus_aug_path} -y')
+        os.system(
+            f'ffmpeg -loglevel quiet -i {opus_aug_path} -ar {upsample_sr} -ac 1 -sample_fmt s16 -c:a flac {aug_path} -y')
+
+        # remove the opus file
+        os.system(f'rm {opus_aug_path}')
+        waveform, _ = librosa.load(aug_path, sr=sr, mono=True)
+        return waveform
+### END OPUS codec
+
+def ogg2flac(x, args, sr=16000, audio_path=None):
+    """
+    Augment the audio with ogg codec
+ This codec is only available offline + """ + # print('ogg2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'ogg2flac', utt_id + '_from_ogg.flac') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # Alert ERROR and stop the process + # print("Error: ogg to flac conversion is only available offline\n Please convert to ogg your data before running the script") + # using ffmpeg to convert original audio to ogg + os.system( + f'ffmpeg -loglevel quiet -i {audio_path} {aug_path.replace(".flac", ".ogg")} -y') + # using ffmpeg to convert ogg to flac + os.system( + f'ffmpeg -loglevel quiet -i {aug_path.replace(".flac", ".ogg")} {aug_path} -y') + # remove the ogg file + os.system(f'rm {aug_path.replace(".flac", ".ogg")}') + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + + +def nonspeechtrim(x, args, sr=16000, audio_path=None): + """ + Augment the audio with non-speech trimming beginning and end + This augmentation is only available offline + """ + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'nonspeechtrim', utt_id + '.wav') + + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # Alert ERROR and stop the process + # print("Error: nonspeech trimming is only available offline\n Please convert to oog your data before running the script") + y, _ = librosa.load(audio_path, sr=sr, mono=True) + waveform = librosa.effects.trim(y, top_db=10)[0] + sf.write(aug_path, waveform, sr, subtype='PCM_16') + return waveform + + +def griffinlim_downsample(x, args, sr=16000, audio_path=None): + """ + Augment the audio with neural codec + This codec is only available offline + """ + # print('ogg2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'griffinlim_downsample', utt_id + '.wav') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + y, sr = librosa.load(audio_path, sr=sr, mono=True) + # downsample to 8000 + y = librosa.resample(y, orig_sr=sr, target_sr=8000) + y = librosa.resample(y, orig_sr=8000, target_sr=16000) + # Griffin-Lim + # print("D1") + S = np.abs(librosa.stft(y)) + # print("D2") + y_inv = librosa.griffinlim(S) + sf.write(aug_path, y_inv, sr, subtype='PCM_16') + return y_inv + + +def librosa_downsample(x, args, sr=16000, audio_path=None): + """ + Augment the audio with neural codec + This codec is only available offline + """ + # print('ogg2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'librosa_downsample', utt_id + '.wav') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + y, sr = librosa.load(audio_path, sr=sr, mono=True) + # downsample to 8000 + y = librosa.resample(y, orig_sr=sr, target_sr=8000) + y = librosa.resample(y, orig_sr=8000, target_sr=16000) + sf.write(aug_path, y, sr, subtype='PCM_16') + return y + + +def lowpass_hifigan_asvspoof5(x, args, sr=16000, audio_path=None): + """ + Augment the audio with neural codec + This codec is only 
available offline + """ + # print('ogg2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join( + aug_dir, 'lowpass_hifigan_asvspoof5', utt_id + '.wav') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # just load the original audio + return x + + +def lowpass_hifigan(x, args, sr=16000, audio_path=None): + """ + Augment the audio with neural codec + This codec is only available offline + """ + # print('ogg2flac: audio_path:', audio_path) + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_path = os.path.join(aug_dir, 'lowpass_hifigan', utt_id + '.wav') + # check if the augmented file exists + if (os.path.exists(aug_path)): + waveform, _ = librosa.load(aug_path, sr=sr, mono=True) + return waveform + else: + # just load the original audio + return x + +# --------------RawBoost data augmentation algorithms---------------------------## + + +def RawBoost12(x, args, sr=16000, audio_path=None): + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'RawBoost12', utt_id + '.wav') + if args.online_aug: + return process_Rawboost_feature(x, sr, args, algo=5) + else: + # check if the augmented file exists + if (os.path.exists(aug_audio_path)): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + waveform = process_Rawboost_feature(x, sr, args, algo=5) + # save the augmented file,waveform in np array + sf.write(aug_audio_path, waveform, sr, subtype='PCM_16') + return waveform + + +def RawBoostdf(x, args, sr=16000, audio_path=None): + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'RawBoostdf', utt_id + '.wav') + if args.online_aug: + return process_Rawboost_feature(x, sr, args, algo=3) + else: + # check if the augmented file exists + if (os.path.exists(aug_audio_path)): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + waveform = process_Rawboost_feature(x, sr, args, algo=3) + # save the augmented file,waveform in np array + sf.write(aug_audio_path, waveform, sr, subtype='PCM_16') + return waveform + + +def RawBoostFull(x, args, sr=16000, audio_path=None): + ''' + Apply all 3 algo. 
together in series (1+2+3) + ''' + aug_dir = args.aug_dir + utt_id = os.path.basename(audio_path).split('.')[0] + aug_audio_path = os.path.join(aug_dir, 'RawBoostFull', utt_id + '.wav') + if args.online_aug: + return process_Rawboost_feature(x, sr, args, algo=4) + else: + # check if the augmented file exists + if (os.path.exists(aug_audio_path)): + waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True) + return waveform + else: + waveform = process_Rawboost_feature(x, sr, args, algo=4) + # save the augmented file,waveform in np array + sf.write(aug_audio_path, waveform, sr, subtype='PCM_16') + return waveform + + +def process_Rawboost_feature(feature, sr, args, algo): + + # Data process by Convolutive noise (1st algo) + if algo == 1: + + feature = LnL_convolutive_noise(feature, args.N_f, args.nBands, args.minF, args.maxF, args.minBW, args.maxBW, + args.minCoeff, args.maxCoeff, args.minG, args.maxG, args.minBiasLinNonLin, args.maxBiasLinNonLin, sr) + + # Data process by Impulsive noise (2nd algo) + elif algo == 2: + + feature = ISD_additive_noise(feature, args.P, args.g_sd) + + # Data process by coloured additive noise (3rd algo) + elif algo == 3: + + feature = SSI_additive_noise(feature, args.SNRmin, args.SNRmax, args.nBands, args.minF, + args.maxF, args.minBW, args.maxBW, args.minCoeff, args.maxCoeff, args.minG, args.maxG, sr) + + # Data process by all 3 algo. together in series (1+2+3) + elif algo == 4: + + feature = LnL_convolutive_noise(feature, args.N_f, args.nBands, args.minF, args.maxF, args.minBW, args.maxBW, + args.minCoeff, args.maxCoeff, args.minG, args.maxG, args.minBiasLinNonLin, args.maxBiasLinNonLin, sr) + feature = ISD_additive_noise(feature, args.P, args.g_sd) + feature = SSI_additive_noise(feature, args.SNRmin, args.SNRmax, args.nBands, args.minF, + args.maxF, args.minBW, args.maxBW, args.minCoeff, args.maxCoeff, args.minG, args.maxG, sr) + + # Data process by 1st two algo. together in series (1+2) + elif algo == 5: + + feature = LnL_convolutive_noise(feature, args.N_f, args.nBands, args.minF, args.maxF, args.minBW, args.maxBW, + args.minCoeff, args.maxCoeff, args.minG, args.maxG, args.minBiasLinNonLin, args.maxBiasLinNonLin, sr) + feature = ISD_additive_noise(feature, args.P, args.g_sd) + + # Data process by 1st and 3rd algo. together in series (1+3) + elif algo == 6: + + feature = LnL_convolutive_noise(feature, args.N_f, args.nBands, args.minF, args.maxF, args.minBW, args.maxBW, + args.minCoeff, args.maxCoeff, args.minG, args.maxG, args.minBiasLinNonLin, args.maxBiasLinNonLin, sr) + feature = SSI_additive_noise(feature, args.SNRmin, args.SNRmax, args.nBands, args.minF, + args.maxF, args.minBW, args.maxBW, args.minCoeff, args.maxCoeff, args.minG, args.maxG, sr) + + # Data process by 2nd and 3rd algo. together in series (2+3) + elif algo == 7: + + feature = ISD_additive_noise(feature, args.P, args.g_sd) + feature = SSI_additive_noise(feature, args.SNRmin, args.SNRmax, args.nBands, args.minF, + args.maxF, args.minBW, args.maxBW, args.minCoeff, args.maxCoeff, args.minG, args.maxG, sr) + + # Data process by 1st two algo. 
together in Parallel (1||2) + elif algo == 8: + + feature1 = LnL_convolutive_noise(feature, args.N_f, args.nBands, args.minF, args.maxF, args.minBW, args.maxBW, + args.minCoeff, args.maxCoeff, args.minG, args.maxG, args.minBiasLinNonLin, args.maxBiasLinNonLin, sr) + feature2 = ISD_additive_noise(feature, args.P, args.g_sd) + + feature_para = feature1+feature2 + feature = normWav(feature_para, 0) # normalized resultant waveform + + # original data without Rawboost processing + else: + + feature = feature + + return feature diff --git a/src/data/components/baseloader.py b/src/data/components/baseloader.py new file mode 100644 index 0000000..de96c38 --- /dev/null +++ b/src/data/components/baseloader.py @@ -0,0 +1,50 @@ +import os +from torch.utils.data import Dataset + + +class Dataset_base(Dataset): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): + """ + Args: + list_IDs (string): Path to the .lst file with real audio filenames. + vocoders (list): list of vocoder names. + augmentation_methods (list): List of augmentation methods to apply. + """ + self.args = args + self.list_IDs = list_IDs + self.labels = labels + self.base_dir = base_dir + self.bonafide_dir = os.path.join(base_dir, 'bonafide') + self.vocoded_dir = os.path.join(base_dir, 'vocoded') + self.algo = algo + self.vocoders = vocoders + print("vocoders:", vocoders) + + self.augmentation_methods = augmentation_methods + if len(augmentation_methods) < 1: + # using default augmentation method RawBoostWrapper12 + # self.augmentation_methods = ["RawBoost12"] + print("No augmentation method provided") + self.eval_augment = eval_augment + self.num_additional_real = num_additional_real + self.num_additional_spoof = num_additional_spoof + self.trim_length = trim_length + self.sample_rate = wav_samp_rate + + self.args.noise_path = noise_path + self.args.rir_path = rir_path + self.args.aug_dir = aug_dir + self.args.online_aug = online_aug + self.repeat_pad = repeat_pad + self.is_train = is_train + + def __len__(self): + return len(self.list_IDs) + + def __getitem__(self, idx): + # to be implemented in child classes + pass + diff --git a/src/data/components/dataio.py b/src/data/components/dataio.py new file mode 100644 index 0000000..c7c6dcf --- /dev/null +++ b/src/data/components/dataio.py @@ -0,0 +1,142 @@ +import os +import numpy as np +import librosa +import soundfile as sf +import torch +import torchaudio + + +def load_audio(file_path: str, sr: int=16000) -> np.ndarray: + ''' + Load audio file + file_path: path to the audio file + sr: sampling rate, default 16000 + ''' + if not os.path.exists(file_path): + raise FileNotFoundError(f"File {file_path} not found") + audio, _ = librosa.load(file_path, sr=sr) + return audio + +def load_torchaudio(file_path: str, sr: int=16000) -> torch.Tensor: + ''' + Load audio file + file_path: path to the audio file + sr: sampling rate, default 16000 + ''' + if not os.path.exists(file_path): + raise FileNotFoundError(f"File {file_path} not found") + audio, sample_rate = torchaudio.load(file_path) + if sample_rate != sr: + raise ValueError(f"Sample rate mismatch: {sample_rate} != {sr}") + return audio + +def save_audio(file_path: str, audio: np.ndarray, sr: int=16000): + ''' + Save audio file + file_path: path to save the audio file + audio: audio 
signal + sr: sampling rate, default 16000 + ''' + sf.write(file_path, audio, sr, subtype='PCM_16') + +def npwav2torch(waveform: np.ndarray) -> torch.Tensor: + ''' + Convert numpy array to torch tensor + waveform: audio signal + ''' + return torch.from_numpy(waveform).float() + +def pad(x:np.ndarray, padding_type:str='zero', max_len=64000, random_start=False) -> np.ndarray: + ''' + pad audio signal to max_len + x: audio signal + padding_type: 'zero' or 'repeat' when len(X) < max_len, default 'zero' + zero: pad with zeros + repeat: repeat the signal until it reaches max_len + max_len: max length of the audio, default 64000 + random_start: if True, randomly choose the start point of the audio + ''' + x_len = x.shape[0] + padded_x = None + if max_len == 0: + # no padding + padded_x = x + elif max_len > 0: + if x_len >= max_len: + if random_start: + start = np.random.randint(0, x_len - max_len+1) + padded_x = x[start:start + max_len] + # logger.debug("padded_x1: {}".format(padded_x.shape)) + else: + padded_x = x[:max_len] + # logger.debug("padded_x2: {}".format(padded_x.shape)) + else: + if random_start: + # keep at least half of the signal + start = np.random.randint(0, int((x_len+1)/2)) + x_new = x[start:] + else: + x_new = x + + if padding_type == "repeat": + num_repeats = int(max_len / len(x_new)) + 1 + padded_x = np.tile(x_new, (1, num_repeats))[:, :max_len][0] + + elif padding_type == "zero": + padded_x = np.zeros(max_len) + padded_x[:len(x_new)] = x_new + + else: + raise ValueError("max_len must be >= 0") + # logger.debug("padded_x: {}".format(padded_x.shape)) + return padded_x + + +def pad_tensor(x: torch.Tensor, padding_type: str = 'zero', max_len: int = 64000, random_start: bool = False) -> torch.Tensor: + ''' + Pad audio signal to max_len. + + Args: + x: audio signal + padding_type: 'zero' or 'repeat' when len(X) < max_len, default 'zero' + zero: pad with zeros + repeat: repeat the signal until it reaches max_len + max_len: max length of the audio, default 64000 + random_start: if True, randomly choose the start point of the audio + + Returns: + padded_x: Padded audio signal + ''' + x_len = x.shape[0] + padded_x = None + + if max_len == 0: + # no padding + padded_x = x + elif max_len > 0: + if x_len >= max_len: + if random_start: + start = torch.randint(0, x_len - max_len + 1, (1,)).item() + padded_x = x[start:start + max_len] + else: + padded_x = x[:max_len] + else: + if random_start: + start = torch.randint(0, max_len - x_len + 1, (1,)).item() + if padding_type == "repeat": + num_repeats = (max_len // x_len) + 1 + padded_x = x.repeat(num_repeats)[start:start + max_len] + elif padding_type == "zero": + padded_x = torch.zeros(max_len, dtype=x.dtype) + padded_x[start:start + x_len] = x + else: + if padding_type == "repeat": + num_repeats = (max_len // x_len) + 1 + padded_x = x.repeat(num_repeats)[:max_len] + elif padding_type == "zero": + padded_x = torch.zeros(max_len, dtype=x.dtype) + padded_x[:x_len] = x + else: + raise ValueError("max_len must be >= 0") + + return padded_x \ No newline at end of file diff --git a/src/data/cyberaicup_datamodule.py b/src/data/cyberaicup_datamodule.py index cf906d5..a1b817f 100644 --- a/src/data/cyberaicup_datamodule.py +++ b/src/data/cyberaicup_datamodule.py @@ -9,250 +9,100 @@ import os from scipy import signal import copy - -''' - Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. - RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. - In Proc. 
ICASSP 2022, pp:6382--6386. -''' - -def randRange(x1, x2, integer): - y = np.random.uniform(low=x1, high=x2, size=(1,)) - if integer: - y = int(y) - return y - -def normWav(x,always): - if always: - x = x/np.amax(abs(x)) - elif np.amax(abs(x)) > 1: - x = x/np.amax(abs(x)) - return x - -def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): - b = 1 - for i in range(0, nBands): - fc = randRange(minF,maxF,0) - bw = randRange(minBW,maxBW,0) - c = randRange(minCoeff,maxCoeff,1) - - if c/2 == int(c/2): - c = c + 1 - f1 = fc - bw/2 - f2 = fc + bw/2 - if f1 <= 0: - f1 = 1/1000 - if f2 >= fs/2: - f2 = fs/2-1/1000 - b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b) - - G = randRange(minG,maxG,0) - _, h = signal.freqz(b, 1, fs=fs) - b = pow(10, G/20)*b/np.amax(abs(h)) - return b - - -def filterFIR(x,b): - N = b.shape[0] + 1 - xpad = np.pad(x, (0, N), 'constant') - y = signal.lfilter(b, 1, xpad) - y = y[int(N/2):int(y.shape[0]-N/2)] - return y - -# Linear and non-linear convolutive noise -def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs): - y = [0] * x.shape[0] - for i in range(0, N_f): - if i == 1: - minG = minG-minBiasLinNonLin - maxG = maxG-maxBiasLinNonLin - b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) - y = y + filterFIR(np.power(x, (i+1)), b) - y = y - np.mean(y) - y = normWav(y,0) - return y - - -# Impulsive signal dependent noise -def ISD_additive_noise(x, P, g_sd): - beta = randRange(0, P, 0) - - y = copy.deepcopy(x) - x_len = x.shape[0] - n = int(x_len*(beta/100)) - p = np.random.permutation(x_len)[:n] - f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1)) - r = g_sd * x[p] * f_r - y[p] = x[p] + r - y = normWav(y,0) - return y - - -# Stationary signal independent noise - -def SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): - noise = np.random.normal(0, 1, x.shape[0]) - b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) - noise = filterFIR(noise, b) - noise = normWav(noise,1) - SNR = randRange(SNRmin, SNRmax, 0) - noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR) - x = x + noise - return x - -#--------------RawBoost data augmentation algorithms---------------------------## -def process_Rawboost_feature(feature, sr,args,algo): - - # Data process by Convolutive noise (1st algo) - if algo==1: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - - # Data process by Impulsive noise (2nd algo) - elif algo==2: - - feature=ISD_additive_noise(feature, args.P, args.g_sd) - - # Data process by coloured additive noise (3rd algo) - elif algo==3: - - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by all 3 algo. 
together in series (1+2+3) - elif algo==4: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=ISD_additive_noise(feature, args.P, args.g_sd) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF, - args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 1st two algo. together in series (1+2) - elif algo==5: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=ISD_additive_noise(feature, args.P, args.g_sd) - - - # Data process by 1st and 3rd algo. together in series (1+3) - elif algo==6: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 2nd and 3rd algo. together in series (2+3) - elif algo==7: - - feature=ISD_additive_noise(feature, args.P, args.g_sd) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 1st two algo. together in Parallel (1||2) - elif algo==8: - - feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature2=ISD_additive_noise(feature, args.P, args.g_sd) - - feature_para=feature1+feature2 - feature=normWav(feature_para,0) #normalized resultant waveform - - # original data without Rawboost processing - else: +import random +from src.data.components.RawBoost import process_Rawboost_feature +from src.data.components.dataio import load_audio, pad +from src.data.components.baseloader import Dataset_base + +# augwrapper +from src.data.components.augwrapper import SUPPORTED_AUGMENTATION + +# dynamic import of augmentation methods +for aug in SUPPORTED_AUGMENTATION: + exec(f"from src.data.components.augwrapper import {aug}") + +class Dataset_for(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): + super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train) + print("train repeat_pad:", repeat_pad) + #self.repeat_pad = False # temporary fix for the padding issue + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X = load_audio(filepath, self.sample_rate) - feature=feature - - return feature - - -def pad(x, max_len=64600): - x_len = x.shape[0] - if x_len >= max_len: - return x[:max_len] - # need to pad - num_repeats = int(max_len / x_len)+1 - 
padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] - return padded_x - -class Dataset_ASVspoof2019_train(Dataset): - def __init__(self,args,list_IDs, labels, base_dir,algo): - '''self.list_IDs : list of strings (each string: utt key), - self.labels : dictionary (key: utt key, value: label integer)''' - - self.list_IDs = list_IDs - self.labels = labels - self.base_dir = base_dir - self.algo=algo - self.args=args - self.cut=64600 # take ~4 sec audio (64600 samples) - - def __len__(self): - return len(self.list_IDs) - - - def __getitem__(self, index): - utt_id = self.list_IDs[index] - X,fs = librosa.load(self.base_dir+utt_id+'.flac', sr=16000) - Y=process_Rawboost_feature(X,fs,self.args,self.algo) - X_pad= pad(Y,self.cut) + # apply augmentation + # randomly choose an augmentation method + if self.is_train: + augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(self.augmentation_methods) > 0 else -1 + if augmethod_index >= 0: + X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate, + audio_path = filepath) + + X_pad= pad(X,padding_type="repeat" if self.repeat_pad else "zero", max_len=self.trim_length, random_start=True) x_inp= Tensor(X_pad) target = self.labels[utt_id] - - return x_inp, target - -class Dataset_ASVspoof2021_eval(Dataset): - def __init__(self, list_IDs, base_dir): - self.list_IDs = list_IDs - self.base_dir = base_dir - self.cut=64600 # take ~4 sec audio (64600 samples) - - def __len__(self): - return len(self.list_IDs) - + return idx, x_inp, target + +class Dataset_for_dev(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=False, is_train=True): + super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train) + #repeat_pad = False + print("dev repeat_pad:", repeat_pad) + + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + def __getitem__(self, index): utt_id = self.list_IDs[index] - X, fs = librosa.load(self.base_dir+utt_id+'.flac', sr=16000) - X_pad = pad(X,self.cut) + X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) + X_pad = pad(X,self.padding_type,self.trim_length, random_start=True) x_inp = Tensor(X_pad) - return x_inp,utt_id - -class Args: - def __init__(self): - self.database_path = '/your/path/to/data/ASVspoof_database/DF/' - self.protocols_path = '/data/hungdx/Datasets/protocols/database/' - self.batch_size = 14 - self.num_epochs = 100 - self.lr = 0.000001 - self.weight_decay = 0.0001 - self.loss = 'weighted_CCE' - self.seed = 1234 - self.model_path = None - self.comment = None - self.track = 'DF' - self.eval_output = None - self.eval = False - self.is_eval = False - self.eval_part = 0 - self.cudnn_deterministic_toggle = True - self.cudnn_benchmark_toggle = False - self.algo = 3 - self.nBands = 5 - self.minF = 20 - self.maxF = 8000 - self.minBW = 100 - self.maxBW = 1000 - self.minCoeff = 10 - self.maxCoeff = 100 - self.minG = 0 - self.maxG = 0 - self.minBiasLinNonLin = 5 - self.maxBiasLinNonLin = 20 - self.N_f = 5 - self.P = 10 - self.g_sd = 2 - self.SNRmin = 10 - self.SNRmax = 40 + target = self.labels[utt_id] + 
return index, x_inp, target + +class Dataset_for_eval(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False + ): + super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train) + self.enable_chunking = enable_chunking + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X, _ = librosa.load(filepath, sr=16000) + # apply augmentation at inference time + if self.eval_augment is not None: + # print("eval_augment:", self.eval_augment) + X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath) + if not self.enable_chunking: + X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=True) + x_inp = Tensor(X) + return x_inp, utt_id class CyberAiCupDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. @@ -304,6 +154,11 @@ def __init__( batch_size: int = 64, num_workers: int = 0, pin_memory: bool = False, + args: Optional[Dict[str, Any]] = None, + list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False ) -> None: """Initialize a `ASVSpoofDataModule`. 
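For reference, a minimal sketch of how the new pad helper from src/data/components/dataio.py behaves; the datasets above call it with padding_type="repeat" or "zero" (driven by the repeat_pad flag) and random_start=True, and pad_tensor applies the same logic to torch tensors. The sample values below are made up purely for illustration:

import numpy as np
from src.data.components.dataio import pad  # helper added in this patch

x = np.arange(5, dtype=np.float32)  # toy 5-sample "waveform"

# 'zero': copy the signal to the front and fill the remainder with zeros
print(pad(x, padding_type="zero", max_len=8, random_start=False))
# [0. 1. 2. 3. 4. 0. 0. 0.]

# 'repeat': tile the signal until max_len is reached, then truncate
print(pad(x, padding_type="repeat", max_len=8, random_start=False))
# [0. 1. 2. 3. 4. 0. 1. 2.]

# if the clip is already longer than max_len, random_start picks a random crop window
long_x = np.arange(12, dtype=np.float32)
assert pad(long_x, padding_type="zero", max_len=8, random_start=True).shape == (8,)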
@@ -324,6 +179,7 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir + self.args = args @property def num_classes(self) -> int: @@ -358,18 +214,39 @@ def setup(self, stage: Optional[str] = None) -> None: if not self.data_train and not self.data_val and not self.data_test: # define train dataloader - args = Args() - args.database_path = self.data_dir - track = args.track - prefix_2021 = 'ASVspoof2021.{}'.format(track) - - d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) - d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) - file_eval = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) + + d_label_trn, file_train = self.genList(dir_meta=os.path.join( + self.data_dir), is_train=True, is_eval=False, is_dev=False) + + # if 'portion' in config['data']: + # idx = range(len(file_train)) + # idx = np.random.choice( + # idx, int(len(file_train)*config['data']['portion']), replace=False) + # file_train = [file_train[i] for i in idx] + # if len(d_label_trn) > 0: # the case of train without label + # d_label_trn = {k: d_label_trn[k] for k in file_train} - self.data_train = Dataset_ASVspoof2019_train(args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_train/'),algo=args.algo) - self.data_val = Dataset_ASVspoof2019_train(args,list_IDs = file_dev,labels = d_label_dev,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_dev/'),algo=args.algo) - self.data_test = Dataset_ASVspoof2021_eval(list_IDs = file_eval,base_dir = os.path.join(args.database_path+'ASVspoof2021_{}_eval/'.format(args.track))) + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) + + # if 'portion' in config['data']: + # idx = range(len(file_dev)) + # idx = np.random.choice( + # idx, int(len(file_dev)*config['data']['portion']), replace=False) + # file_dev = [file_dev[i] for i in idx] + # if len(d_label_dev) > 0: # the case of train without label + # d_label_dev = {k: d_label_dev[k] for k in file_dev} + + self.data_train = Dataset_for(args, list_IDs=file_train, labels=d_label_trn, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs']) + + self.data_val = Dataset_for_dev(args, list_IDs=file_dev, labels=d_label_dev, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, is_train=False, **config['data']['kwargs']) + + self.data_test = Dataset_for_eval(args, list_IDs=file_eval, labels=None, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs'], + enable_chunking=config.get( + 'eval_chunking', False) + ) def train_dataloader(self) -> DataLoader[Any]: @@ -435,37 +312,44 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ pass - def genSpoof_list(self, dir_meta, is_train=False, is_eval=False): - """ - This function is from the following source: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17 - Official source: https://arxiv.org/abs/2202.12233 - Automatic speaker verification spoofing and deepfake detection using wav2vec 2.0 and data augmentation - """ + def 
genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): + # bonafide: 1, spoof: 0 d_meta = {} file_list=[] - with open(dir_meta, 'r') as f: - l_meta = f.readlines() + protocol = os.path.join(dir_meta, "protocol.txt") if (is_train): + with open(protocol, 'r') as f: + l_meta = f.readlines() for line in l_meta: - _,key,_,_,label = line.strip().split() - - file_list.append(key) - d_meta[key] = 1 if label == 'bonafide' else 0 - return d_meta,file_list - - elif(is_eval): + utt, subset, label = line.strip().split() + if subset == 'train': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + + return d_meta, file_list + if (is_dev): + with open(protocol, 'r') as f: + l_meta = f.readlines() for line in l_meta: - key= line.strip() - file_list.append(key) - return file_list - else: + utt, subset, label = line.strip().split() + if subset == 'dev': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + return d_meta, file_list + + if (is_eval): + # no eval protocol yet + with open(protocol, 'r') as f: + l_meta = f.readlines() for line in l_meta: - _,key,_,_,label = line.strip().split() - - file_list.append(key) - d_meta[key] = 1 if label == 'bonafide' else 0 - return d_meta,file_list + utt, subset, label = line.strip().split() + if subset == 'dev': + file_list.append(utt) + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + # return d_meta, file_list + return d_meta, file_list if __name__ == "__main__": _ = CyberAiCupDataModule() From d2fd24b85d7b1f171079c123957c3eb2045c9f8b Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Sat, 28 Sep 2024 01:20:52 +0900 Subject: [PATCH 2/8] Define wavlm+conformertcm, xlsr+conformertcm, wavlm+vib, xlsr+vib --- .gitignore | 3 + configs/logger/comet.yaml | 2 +- requirements.txt | 10 +- src/data/scl_datamodule.py | 263 +----- src/models/components/WavLM/__intit__.py | 1 + src/models/components/WavLM/fe.py | 126 +++ src/models/components/WavLM/model.py | 772 ++++++++++++++++ src/models/components/WavLM/modules.py | 827 ++++++++++++++++++ .../components/conformer_tcm/__init__.py | 0 .../components/conformer_tcm/conformer.py | 331 +++++++ src/models/components/conformer_tcm/model.py | 56 ++ src/models/components/loss_metrics.py | 598 +++++++++++++ .../components/wavlmbase_conformertcm.py | 147 ++++ src/models/components/wavlmbase_vib.py | 281 ++++++ src/models/components/xlsr_conformertcm.py | 206 +++++ src/models/components/xlsr_vib.py | 224 +++++ 16 files changed, 3618 insertions(+), 229 deletions(-) create mode 100644 src/models/components/WavLM/__intit__.py create mode 100644 src/models/components/WavLM/fe.py create mode 100644 src/models/components/WavLM/model.py create mode 100644 src/models/components/WavLM/modules.py create mode 100644 src/models/components/conformer_tcm/__init__.py create mode 100644 src/models/components/conformer_tcm/conformer.py create mode 100644 src/models/components/conformer_tcm/model.py create mode 100644 src/models/components/loss_metrics.py create mode 100644 src/models/components/wavlmbase_conformertcm.py create mode 100644 src/models/components/wavlmbase_vib.py create mode 100644 src/models/components/xlsr_conformertcm.py create mode 100644 src/models/components/xlsr_vib.py diff --git a/.gitignore b/.gitignore index 04a0648..29ec039 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,6 @@ configs/local/default.yaml # Aim logging .aim + +# Neptune logging +.neptune \ No newline at end of file diff --git a/configs/logger/comet.yaml b/configs/logger/comet.yaml 
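As a rough illustration of the input expected by the genList parser added to the CyberAiCup datamodule above: each line of protocol.txt carries an utterance name, a subset tag and a label, and the eval branch currently reuses the dev subset as a placeholder, as its own comment notes. The file names below are hypothetical:

# Hypothetical protocol.txt contents (utt, subset, label per line):
#   audio_0001.wav train bonafide
#   audio_0002.wav train spoof
#   audio_0003.wav dev   bonafide

line = "audio_0002.wav train spoof"
utt, subset, label = line.strip().split()
target = 1 if label == 'bonafide' else 0  # bonafide -> 1, spoof -> 0
assert (utt, subset, target) == ("audio_0002.wav", "train", 0)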
index e078927..c35ea06 100644 --- a/configs/logger/comet.yaml +++ b/configs/logger/comet.yaml @@ -4,7 +4,7 @@ comet: _target_: lightning.pytorch.loggers.comet.CometLogger api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable save_dir: "${paths.output_dir}" - project_name: "lightning-hydra-template" + project_name: ${oc.env:COMET_PROJECT_NAME} rest_api_key: null # experiment_name: "" experiment_key: null # set to resume experiment diff --git a/requirements.txt b/requirements.txt index 73bf981..84a785d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,11 +10,11 @@ hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 # --------- loggers --------- # -# wandb -# neptune-client -# mlflow -# comet-ml -# aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 +wandb +neptune-client +mlflow +comet-ml +aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 # --------- others --------- # rootutils # standardizing the project root setup diff --git a/src/data/scl_datamodule.py b/src/data/scl_datamodule.py index e69e35d..76e1137 100644 --- a/src/data/scl_datamodule.py +++ b/src/data/scl_datamodule.py @@ -12,216 +12,8 @@ import random from src.core_scripts.data_io import wav_augmentation as nii_wav_aug from src.core_scripts.data_io import wav_tools as nii_wav_tools -''' - Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. - RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. - In Proc. ICASSP 2022, pp:6382--6386. -''' - -def randRange(x1, x2, integer): - y = np.random.uniform(low=x1, high=x2, size=(1,)) - if integer: - y = int(y) - return y - -def normWav(x,always): - if always: - x = x/np.amax(abs(x)) - elif np.amax(abs(x)) > 1: - x = x/np.amax(abs(x)) - return x - -def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): - b = 1 - for i in range(0, nBands): - fc = randRange(minF,maxF,0) - bw = randRange(minBW,maxBW,0) - c = randRange(minCoeff,maxCoeff,1) - - if c/2 == int(c/2): - c = c + 1 - f1 = fc - bw/2 - f2 = fc + bw/2 - if f1 <= 0: - f1 = 1/1000 - if f2 >= fs/2: - f2 = fs/2-1/1000 - b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b) - - G = randRange(minG,maxG,0) - _, h = signal.freqz(b, 1, fs=fs) - b = pow(10, G/20)*b/np.amax(abs(h)) - return b - - -def filterFIR(x,b): - N = b.shape[0] + 1 - xpad = np.pad(x, (0, N), 'constant') - y = signal.lfilter(b, 1, xpad) - y = y[int(N/2):int(y.shape[0]-N/2)] - return y - -# Linear and non-linear convolutive noise -def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs): - y = [0] * x.shape[0] - for i in range(0, N_f): - if i == 1: - minG = minG-minBiasLinNonLin - maxG = maxG-maxBiasLinNonLin - b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) - y = y + filterFIR(np.power(x, (i+1)), b) - y = y - np.mean(y) - y = normWav(y,0) - return y - - -# Impulsive signal dependent noise -def ISD_additive_noise(x, P, g_sd): - beta = randRange(0, P, 0) - - y = copy.deepcopy(x) - x_len = x.shape[0] - n = int(x_len*(beta/100)) - p = np.random.permutation(x_len)[:n] - f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1)) - r = g_sd * x[p] * f_r - y[p] = x[p] + r - y = normWav(y,0) - return y - - -# Stationary signal independent noise - -def 
SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): - noise = np.random.normal(0, 1, x.shape[0]) - b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) - noise = filterFIR(noise, b) - noise = normWav(noise,1) - SNR = randRange(SNRmin, SNRmax, 0) - noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR) - x = x + noise - return x - -#--------------RawBoost data augmentation algorithms---------------------------## -def process_Rawboost_feature(feature, sr,args,algo): - - # Data process by Convolutive noise (1st algo) - if algo==1: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - - # Data process by Impulsive noise (2nd algo) - elif algo==2: - - feature=ISD_additive_noise(feature, args.P, args.g_sd) - - # Data process by coloured additive noise (3rd algo) - elif algo==3: - - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by all 3 algo. together in series (1+2+3) - elif algo==4: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=ISD_additive_noise(feature, args.P, args.g_sd) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF, - args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 1st two algo. together in series (1+2) - elif algo==5: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=ISD_additive_noise(feature, args.P, args.g_sd) - - - # Data process by 1st and 3rd algo. together in series (1+3) - elif algo==6: - - feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 2nd and 3rd algo. together in series (2+3) - elif algo==7: - - feature=ISD_additive_noise(feature, args.P, args.g_sd) - feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) - - # Data process by 1st two algo. 
together in Parallel (1||2) - elif algo==8: - - feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, - args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) - feature2=ISD_additive_noise(feature, args.P, args.g_sd) - - feature_para=feature1+feature2 - feature=normWav(feature_para,0) #normalized resultant waveform - - # original data without Rawboost processing - else: - - feature=feature - - return feature - - -def pad(x, max_len=64600): - x_len = x.shape[0] - if x_len >= max_len: - return x[:max_len] - # need to pad - num_repeats = int(max_len / x_len)+1 - padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] - return padded_x - -class Dataset_base(Dataset): - def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], - augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, - trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): - """ - Args: - list_IDs (string): Path to the .lst file with real audio filenames. - vocoders (list): list of vocoder names. - augmentation_methods (list): List of augmentation methods to apply. - """ - self.args = args - self.list_IDs = list_IDs - self.labels = labels - self.base_dir = base_dir - self.bonafide_dir = os.path.join(base_dir, 'bonafide') - self.vocoded_dir = os.path.join(base_dir, 'vocoded') - self.algo = algo - self.vocoders = vocoders - print("vocoders:", vocoders) - - self.augmentation_methods = augmentation_methods - if len(augmentation_methods) < 1: - # using default augmentation method RawBoostWrapper12 - # self.augmentation_methods = ["RawBoost12"] - print("No augmentation method provided") - self.eval_augment = eval_augment - self.num_additional_real = num_additional_real - self.num_additional_spoof = num_additional_spoof - self.trim_length = trim_length - self.sample_rate = wav_samp_rate - - self.args.noise_path = noise_path - self.args.rir_path = rir_path - self.args.aug_dir = aug_dir - self.args.online_aug = online_aug - self.repeat_pad = repeat_pad - self.is_train = is_train - - def __len__(self): - return len(self.list_IDs) - - def __getitem__(self, idx): - # to be implemented in child classes - pass +from src.data.components.dataio import load_audio, pad +from src.data.components.baseloader import Dataset_base class Dataset_for(Dataset_base): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], @@ -340,7 +132,7 @@ def __init__(self): self.SNRmin = 10 self.SNRmax = 40 -class ASVSpoofDataModule(LightningDataModule): +class SclNormalDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. The ASVspoof 2019 database for logical access is based upon a standard multi-speaker speech synthesis database called VCTK2. 
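A minimal sketch of how this refactor is meant to be consumed: the shared Dataset_base now lives in src.data.components.baseloader, and concrete datasets subclass it and only override __getitem__, relying on the attributes its constructor sets (list_IDs, base_dir, sample_rate, trim_length, repeat_pad). The subclass name below is illustrative and not part of the patch:

import os
from torch import Tensor
from src.data.components.baseloader import Dataset_base  # shared base moved out of this module
from src.data.components.dataio import load_audio, pad

class MyEvalDataset(Dataset_base):  # hypothetical subclass for illustration
    def __getitem__(self, idx):
        utt_id = self.list_IDs[idx]
        X = load_audio(os.path.join(self.base_dir, utt_id), self.sample_rate)
        X = pad(X, padding_type="repeat" if self.repeat_pad else "zero",
                max_len=self.trim_length, random_start=False)
        return Tensor(X), utt_id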
@@ -390,6 +182,11 @@ def __init__( batch_size: int = 64, num_workers: int = 0, pin_memory: bool = False, + args: Optional[Dict[str, Any]] = None, + list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -444,18 +241,38 @@ def setup(self, stage: Optional[str] = None) -> None: if not self.data_train and not self.data_val and not self.data_test: # define train dataloader - args = Args() - args.database_path = self.data_dir - track = args.track - prefix_2021 = 'ASVspoof2021.{}'.format(track) - - d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) - d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) - file_eval = self.genSpoof_list( dir_meta = os.path.join(args.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) + + d_label_trn, file_train = self.genList(dir_meta=os.path.join( + self.data_dir), is_train=True, is_eval=False, is_dev=False) + + # if 'portion' in config['data']: + # idx = range(len(file_train)) + # idx = np.random.choice( + # idx, int(len(file_train)*config['data']['portion']), replace=False) + # file_train = [file_train[i] for i in idx] + # if len(d_label_trn) > 0: # the case of train without label + # d_label_trn = {k: d_label_trn[k] for k in file_train} - self.data_train = Dataset_ASVspoof2019_train(args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_train/'),algo=args.algo) - self.data_val = Dataset_ASVspoof2019_train(args,list_IDs = file_dev,labels = d_label_dev,base_dir = os.path.join(args.database_path+'ASVspoof2019_LA_dev/'),algo=args.algo) - self.data_test = Dataset_ASVspoof2021_eval(list_IDs = file_eval,base_dir = os.path.join(args.database_path+'ASVspoof2021_{}_eval/'.format(args.track))) + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) + + # if 'portion' in config['data']: + # idx = range(len(file_dev)) + # idx = np.random.choice( + # idx, int(len(file_dev)*config['data']['portion']), replace=False) + # file_dev = [file_dev[i] for i in idx] + # if len(d_label_dev) > 0: # the case of train without label + # d_label_dev = {k: d_label_dev[k] for k in file_dev} + + self.data_train = Dataset_for(args, list_IDs=file_train, labels=d_label_trn, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs']) + + self.data_val = Dataset_for_dev(args, list_IDs=file_dev, labels=d_label_dev, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, is_train=False, **config['data']['kwargs']) + + self.data_test = Dataset_for_eval(args, list_IDs=file_eval, labels=None, + base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs'], + enable_chunking=config.get( + 
'eval_chunking', False)) def train_dataloader(self) -> DataLoader[Any]: @@ -554,4 +371,4 @@ def genSpoof_list(self, dir_meta, is_train=False, is_eval=False): return d_meta,file_list if __name__ == "__main__": - _ = ASVSpoofDataModule() + _ = SclNormalDataModule() \ No newline at end of file diff --git a/src/models/components/WavLM/__intit__.py b/src/models/components/WavLM/__intit__.py new file mode 100644 index 0000000..04e876e --- /dev/null +++ b/src/models/components/WavLM/__intit__.py @@ -0,0 +1 @@ +from .model import WavLM, WavLMConfig \ No newline at end of file diff --git a/src/models/components/WavLM/fe.py b/src/models/components/WavLM/fe.py new file mode 100644 index 0000000..b7cd555 --- /dev/null +++ b/src/models/components/WavLM/fe.py @@ -0,0 +1,126 @@ +from .model import WavLM, WavLMConfig +import torch +from torch import nn +import torch.nn.functional as F + +class WavLMFe(nn.Module): + def __init__(self, **kwargs): + super(WavLMFe, self).__init__() + # load the pre-trained checkpoints + checkpoint_path = kwargs.get( + 'checkpoint_path', 'model/pretrained/WavLM-Base.pt') + # 'last', 'all', 'weighted_sum' + checkpoint = torch.load(checkpoint_path) + cfg = WavLMConfig(checkpoint['cfg']) + + self.extract_mode = kwargs.get('extract_mode', 'last') + self.cfg = cfg + self.num_layers = cfg.encoder_layers + 1 + self.model = WavLM(cfg) + self.model.load_state_dict(checkpoint['model']) + print('Loaded pre-trained WavLM from', checkpoint_path) + self.out_dim = cfg.encoder_embed_dim + + if self.extract_mode == 'weighted_sum': + # initialize the weights for weighted sum + print('Initializing weights for weighted sum') + self.weights = nn.Parameter(torch.ones( + self.num_layers), requires_grad=True) + elif self.extract_mode == 'attentive_merging': + print('Initializing weights for attentive merging') + # Initialize components for attentive merging + self.W_sq = nn.Parameter(torch.randn(cfg.encoder_embed_dim, 1)) + self.W_ex1 = nn.Parameter(torch.randn(self.num_layers, self.num_layers // 2)) + self.W_ex2 = nn.Parameter(torch.randn(self.num_layers // 2, self.num_layers)) + self.W_L1 = nn.Parameter(torch.randn(self.num_layers * cfg.encoder_embed_dim, (self.num_layers * cfg.encoder_embed_dim) // 4)) + self.W_L2 = nn.Parameter(torch.randn((self.num_layers * cfg.encoder_embed_dim) // 4, (self.num_layers * cfg.encoder_embed_dim) // 4)) + self.W_L3 = nn.Parameter(torch.randn((self.num_layers * cfg.encoder_embed_dim) // 4, cfg.encoder_embed_dim)) + + + def forward(self, wav_input_16khz): + if self.cfg.normalize: + wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape) + if self.extract_mode == 'weighted_sum': + return self.forward_with_intermediate(wav_input_16khz) + elif self.extract_mode == 'attentive_merging': + return self.forward_with_attentive_merging(wav_input_16khz) + else: + return self.model.extract_features(wav_input_16khz)[0] + + def forward_with_intermediate(self, wav_input_16khz): + rep, layer_results = self.model.extract_features( + wav_input_16khz, output_layer=self.model.cfg.encoder_layers, ret_layer_results=True)[0] + layer_reps = [x.transpose(0, 1) for x, _ in layer_results] + + weighted_layer_reps = [x * w for x, w in zip(layer_reps, self.weights)] + return torch.sum(torch.stack(weighted_layer_reps), dim=0) + + def forward_with_attentive_merging(self, wav_input_16khz): + rep, layer_results = self.model.extract_features( + wav_input_16khz, output_layer=self.model.cfg.encoder_layers, ret_layer_results=True)[0] + layer_reps = [x.transpose(0, 1) for 
x, _ in layer_results] + #print("🐍 File: WavLM/fe.py | Line: 62 | forward_with_attentive_merging ~ layer_reps", len(layer_reps)) + + # Stack all hidden embeddings + X = torch.stack(layer_reps, dim=-1) # Shape: (B, T, H, L) + #print("🐍 File: WavLM/fe.py | Line: 65 | forward_with_attentive_merging ~ X", X.shape) + + # Step 1: Average across the time dimension + X_avg = torch.mean(X, dim=1) # Shape: (B, H, L) + #print("🐍 File: WavLM/fe.py | Line: 68 | forward_with_attentive_merging ~ X_avg", X_avg.shape) + + # Step 2: Fully connected layer with SWISH activation + X_avg_transposed = X_avg.transpose(1, 2) # Shape: (B, L, H) + x_sq = torch.matmul(X_avg_transposed, self.W_sq) # (B, L, H) * (H, 1) Shape: (B, L, 1) + x_sq = torch.nn.functional.silu(x_sq).squeeze(-1) # Shape: (B, L) + #print("🐍 File: WavLM/fe.py | Line: 76 | forward_with_attentive_merging ~ x_sq",x_sq.shape) + + + # Step 3: Obtain attentive weights + x_attW = torch.sigmoid(torch.matmul(x_sq, self.W_ex1)) # Shape: (B, H, L//2) + x_attW = torch.matmul(x_attW, self.W_ex2) # Shape: (B, H, L) + + # Step 4: Apply attentive weights + x_attW = x_attW.unsqueeze(1).unsqueeze(1) # Shape: (B, 1, 1, L) + X_att = X * x_attW # Shape: (B, T, H, L) + + # Step 5: Concatenate all embeddings and merge using 3-layer linear projection network + X_att = X_att.view(X_att.size(0), X_att.size(1), -1) # Shape: (B, T, H * L) + #print("🐍 File: WavLM/fe.py | Line: 89 | forward_with_attentive_merging ~ X_att",X_att.shape) + X_attM = F.linear(X_att, self.W_L1.transpose(0, 1)) + X_attM = F.linear(F.silu(X_attM), self.W_L2) + X_attM = F.linear(F.silu(X_attM), self.W_L3.transpose(0, 1)) # Shape: (B, T, H) + + return X_attM + + +# if __name__ == '__main__': + +# fe = WavLMFe(extract_mode='attentive_merging', checkpoint_path='/datad/pretrained/WavLM-Base.pt') + +# input_tensor = torch.randn(3, 16000) +# out = fe(input_tensor) +# print(out.shape) +# # loss +# ce_loss = nn.CrossEntropyLoss() +# torch.autograd.set_detect_anomaly(True) # for debugging +# # forward +# output = fe(input_tensor) # torch.Size([1, 49, 768]) +# print(output.shape) +# # convert output to 2D tensor +# output = output.view(output.size(0), -1) + +# # full connected layer +# fc = nn.Linear(49*1024, 2) + +# final_output = fc(output) + +# print(final_output.shape) + +# loss = ce_loss(final_output, torch.tensor([1])) + +# print(loss) + +# # loss + +# loss.backward() \ No newline at end of file diff --git a/src/models/components/WavLM/model.py b/src/models/components/WavLM/model.py new file mode 100644 index 0000000..c579ef2 --- /dev/null +++ b/src/models/components/WavLM/model.py @@ -0,0 +1,772 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import logging +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from .modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GradMultiply, + MultiheadAttention, + SamePad, + init_bert_params, + get_activation_fn, + TransposeLast, + GLU_Linear, +) + +logger = logging.getLogger(__name__) + + +def 
compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint( + mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = 
min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.extractor_mode: str = "default" + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + # encoder embedding dimension for FFN + self.encoder_ffn_embed_dim: int = 3072 + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + # apply layernorm first in the transformer + self.layer_norm_first: bool = False + # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" + self.conv_bias: bool = False # include bias in conv encoder + # multiply feature extractor var grads by this + self.feature_grad_mult: float = 1.0 + + # normalize input to have 0 mean and unit variance during training + self.normalize: bool = False + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + # dropout probability for attention weights + self.attention_dropout: float = 0.1 + # dropout probability after activation in FFN + self.activation_dropout: float = 0.0 + # probability of dropping a tarnsformer layer + self.encoder_layerdrop: float = 0.0 + # dropout to apply to the input (after feat extr) + self.dropout_input: float = 0.0 + # dropout to apply to the features (after feat extr) + self.dropout_features: float = 0.0 + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.mask_other: float = 0 + self.no_mask_overlap: bool = False # whether to allow masks to overlap + # min space between spans (if no overlap is enabled) + self.mask_min_space: int = 1 + + # channel masking + # length of the mask for features (channels) + self.mask_channel_length: int = 10 + # probability of replacing a feature with 0 + self.mask_channel_prob: float = 0.0 + # how to choose mask length for channel masking + self.mask_channel_selection: str = "static" + # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.mask_channel_other: float = 0 + # whether to allow channel masks to overlap + self.no_mask_channel_overlap: bool = False + # min space between spans (if no overlap is enabled) + self.mask_channel_min_space: int = 1 + + # positional embeddings + # number of filters for convolutional positional embeddings + self.conv_pos: int = 128 + # number of groups for convolutional positional embedding + 
self.conv_pos_groups: int = 16 + + # relative position embedding + # apply relative position embedding + self.relative_position_embedding: bool = False + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = 
self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, + "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert ( + is_layer_norm and is_group_norm + ) == False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + 
# BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / + (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm( + self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential( + self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=( + self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features( + x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, 
layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, + ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/src/models/components/WavLM/modules.py b/src/models/components/WavLM/modules.py new file mode 100644 index 0000000..1dcfc6f --- /dev/null +++ b/src/models/components/WavLM/modules.py @@ -0,0 +1,827 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple +import torch +from torch import Tensor, nn +from torch.nn import Parameter +import torch.nn.functional as F + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function + """ + + 
def __init__(self):
+        """Construct a Swish activation module."""
+        super(Swish, self).__init__()
+        self.act = torch.nn.Sigmoid()
+
+    def forward(self, x):
+        return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+        super(GLU_Linear, self).__init__()
+
+        self.glu_type = glu_type
+        self.output_dim = output_dim
+
+        if glu_type == "sigmoid":
+            self.glu_act = torch.nn.Sigmoid()
+        elif glu_type == "swish":
+            self.glu_act = Swish()
+        elif glu_type == "relu":
+            self.glu_act = torch.nn.ReLU()
+        elif glu_type == "gelu":
+            self.glu_act = torch.nn.GELU()
+
+        if bias_in_glu:
+            self.linear = nn.Linear(input_dim, output_dim * 2, True)
+        else:
+            self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+    def forward(self, x):
+        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
+        x = self.linear(x)
+
+        if self.glu_type == "bilinear":
+            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
+        else:
+            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
+
+        return x
+
+
+def gelu_accurate(x):
+    if not hasattr(gelu_accurate, "_a"):
+        gelu_accurate._a = math.sqrt(2 / math.pi)
+    return (
+        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+    )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+    """Returns the activation function corresponding to `activation`"""
+
+    if activation == "relu":
+        return F.relu
+    elif activation == "gelu":
+        return gelu
+    elif activation == "gelu_fast":
+        warnings.warn(
+            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
+        )
+        return gelu_accurate
+    elif activation == "gelu_accurate":
+        return gelu_accurate
+    elif activation == "tanh":
+        return torch.tanh
+    elif activation == "linear":
+        return lambda x: x
+    elif activation == "glu":
+        return lambda x: x
+    else:
+        raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def init_bert_params(module):
+    """
+    Initialize the weights specific to the BERT Model.
+    This overrides the default initializations depending on the specified arguments.
+    1. If normal_init_linear_weights is set then weights of linear
+       layer will be initialized using the normal distribution and
+       bias will be set to the specified value.
+    2. If normal_init_embed_weights is set then weights of embedding
+       layer will be initialized using the normal distribution.
+    3. If normal_init_proj_weights is set then weights of
+       in_project_weight for MultiHeadAttention initialized using
+       the normal distribution (to be validated).
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/src/models/components/conformer_tcm/__init__.py b/src/models/components/conformer_tcm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/components/conformer_tcm/conformer.py b/src/models/components/conformer_tcm/conformer.py new file mode 100644 index 0000000..4c41bd7 --- /dev/null +++ b/src/models/components/conformer_tcm/conformer.py @@ -0,0 +1,331 @@ +import math +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange +from einops.layers.torch import Rearrange +from speechbrain.lobes.models.ECAPA_TDNN import Res2NetBlock, SEBlock +# helper functions 
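+#
+# A minimal usage sketch for the ConformerBlock defined later in this file
+# (dim=128 / heads=4 mirrors the MyConformer defaults in model.py; the batch
+# and frame sizes below are only an illustrative assumption):
+#
+#   block = ConformerBlock(dim=128, dim_head=32, heads=4,
+#                          conv_kernel_size=31, type='conv')
+#   x = torch.randn(2, 100, 128)        # (batch, frames, emb_size)
+#   out, attn_weight = block(x)         # out keeps the (2, 100, 128) shape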
+ +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + +# helper classes + +class Swish(nn.Module): + def forward(self, x): + return x * x.sigmoid() + +class GLU(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + out, gate = x.chunk(2, dim=self.dim) + return out * gate.sigmoid() + +class DepthWiseConv1d(nn.Module): + def __init__(self, chan_in, chan_out, kernel_size, padding): + super().__init__() + self.padding = padding + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in) + + def forward(self, x): + x = F.pad(x, self.padding) + return self.conv(x) + +# attention, feedforward, and conv module + +class Scale(nn.Module): + def __init__(self, scale, fn): + super().__init__() + self.fn = fn + self.scale = scale + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +class Attention(nn.Module): + # Head Token attention: https://arxiv.org/pdf/2210.05958.pdf + def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.): + super().__init__() + self.num_heads = heads + inner_dim = dim_head * heads + self.scale = dim_head ** -0.5 + + self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(dropout) + self.proj = nn.Linear(inner_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.act = nn.GELU() + self.ht_proj = nn.Linear(dim_head, dim,bias=True) + self.ht_norm = nn.LayerNorm(dim_head) + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim)) + + def forward(self, x, mask=None): + B, N, C = x.shape + + # head token + head_pos = self.pos_embed.expand(x.shape[0], -1, -1) + x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x_ = x_.mean(dim=2) # now the shape is [B, h, 1, d//h] + x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads) + x_ = self.act(self.ht_norm(x_)).flatten(2) + x_ = x_ + head_pos + x = torch.cat([x, x_], dim=1) + + # normal mhsa + qkv = self.qkv(x).reshape(B, N+self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N+self.num_heads, C) + x = self.proj(x) + + # merge head tokens into cls token + cls, patch, ht = torch.split(x, [1, N-1, self.num_heads], dim=1) + cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True) + x = torch.cat([cls, patch], dim=1) + + x = self.proj_drop(x) + + return x, attn + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + mult = 4, + dropout = 0. + ): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal = False, + expansion_factor = 2, + kernel_size = 31, + dropout = 0. 
+ ): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0) + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), + nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + +class ConformerConvModule(nn.Module): + def __init__( + self, + dim, + causal=False, + expansion_factor=2, + kernel_size=31, + dropout=0., + type='conv' + ): + super().__init__() + + inner_dim = dim * expansion_factor + padding = calc_same_padding( + kernel_size) if not causal else (kernel_size - 1, 0) + + if type == 'conv': + + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, + kernel_size=kernel_size, padding=padding), + nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout) + ) + + elif type == 'conv_res2net': + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, + kernel_size=kernel_size, padding=padding), + nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Res2NetBlock(dim, dim, scale=4, dilation=3), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout)) + + elif type == 'conv_seblock': + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, + kernel_size=kernel_size, padding=padding), + nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + SEBlock(dim, 16, dim), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout)) + + elif type == 'conv_res2net_seblock': + self.net = nn.Sequential( + nn.LayerNorm(dim), + Rearrange('b n c -> b c n'), + nn.Conv1d(dim, inner_dim * 2, 1), + GLU(dim=1), + DepthWiseConv1d(inner_dim, inner_dim, + kernel_size=kernel_size, padding=padding), + nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), + Swish(), + nn.Conv1d(inner_dim, dim, 1), + Res2NetBlock(dim, dim, scale=4, dilation=3), + SEBlock(dim, 16, dim), + Rearrange('b c n -> b n c'), + nn.Dropout(dropout)) + + else: + raise ValueError('Invalid type') + + def forward(self, x): + return self.net(x) + +# Conformer Block + + +class ConformerBlock(nn.Module): + def __init__( + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0., + ff_dropout=0., + conv_dropout=0., + conv_causal=False, + type='conv' + ): + super().__init__() + self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + self.attn = Attention(dim=dim, dim_head=dim_head, + heads=heads, dropout=attn_dropout) + self.conv = ConformerConvModule( + dim=dim, causal=conv_causal, expansion_factor=conv_expansion_factor, kernel_size=conv_kernel_size, dropout=conv_dropout, type=type) + self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + + self.attn = PreNorm(dim, self.attn) + self.ff1 = Scale(0.5, PreNorm(dim, self.ff1)) + self.ff2 = 
Scale(0.5, PreNorm(dim, self.ff2))
+
+        self.post_norm = nn.LayerNorm(dim)
+
+    def forward(self, x, mask=None):
+        x = self.ff1(x) + x
+        attn_x, attn_weight = self.attn(x, mask = mask)
+        x = attn_x + x
+        x = self.conv(x) + x
+        x = self.ff2(x) + x
+        x = self.post_norm(x)
+        return x, attn_weight
+
+class Conformer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        conv_expansion_factor = 2,
+        conv_kernel_size = 31,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
+        conv_dropout = 0.,
+        conv_causal = False
+    ):
+        super().__init__()
+        self.dim = dim
+        self.layers = nn.ModuleList([])
+
+        for _ in range(depth):
+            self.layers.append(ConformerBlock(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                ff_mult = ff_mult,
+                conv_expansion_factor = conv_expansion_factor,
+                conv_kernel_size = conv_kernel_size,
+                conv_causal = conv_causal
+
+            ))
+
+    def forward(self, x):
+
+        for block in self.layers:
+            x, _ = block(x)  # ConformerBlock returns (features, attn_weight); keep only the features
+
+        return x
\ No newline at end of file
diff --git a/src/models/components/conformer_tcm/model.py b/src/models/components/conformer_tcm/model.py
new file mode 100644
index 0000000..ae06cb3
--- /dev/null
+++ b/src/models/components/conformer_tcm/model.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.transformer import _get_clones
+from torch import Tensor
+
+try:
+    from model.conformer_tcm.conformer import ConformerBlock
+except:
+    from .conformer import ConformerBlock
+
+def sinusoidal_embedding(n_channels, dim):
+    pe = torch.FloatTensor([[p / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)]
+                            for p in range(n_channels)])
+    pe[:, 0::2] = torch.sin(pe[:, 0::2])
+    pe[:, 1::2] = torch.cos(pe[:, 1::2])
+    return pe.unsqueeze(0)
+
+class MyConformer(nn.Module):
+    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1, pooling='mean', type='conv', **kwargs):
+        super(MyConformer, self).__init__()
+        self.pooling=pooling
+        self.dim_head=int(emb_size/heads)
+        self.dim=emb_size
+        self.heads=heads
+        self.kernel_size=kernel_size
+        self.n_encoders=n_encoders
+        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
+        self.conv_dropout = kwargs.get('conv_dropout', 0.0)
+        self.ff_dropout = kwargs.get('ff_dropout', 0.0)
+        self.attn_dropout = kwargs.get('attn_dropout', 0.0)
+        self.encoder_blocks=_get_clones(ConformerBlock(dim = emb_size, dim_head=self.dim_head, heads= heads,
+                    ff_mult = ffmult, conv_expansion_factor = exp_fac, conv_kernel_size = kernel_size, type=type,
+                    conv_dropout = self.conv_dropout, ff_dropout = self.ff_dropout, attn_dropout = self.attn_dropout
+                    ),
+                    n_encoders)
+        self.class_token = nn.Parameter(torch.rand(1, emb_size))
+        self.fc5 = nn.Linear(emb_size, 2)
+
+    def forward(self, x):  # x shape [bs, time, frequency]
+        x = x + self.positional_emb[:, :x.size(1), :]
+        x = torch.stack([torch.vstack((self.class_token, x[i])) for i in range(len(x))])  # [bs, 1+time, emb_size]
+        list_attn_weight = []
+        for layer in self.encoder_blocks:
+            x, attn_weight = layer(x)  # [bs, 1+time, emb_size]
+            list_attn_weight.append(attn_weight)
+        if self.pooling=='mean':
+            embedding = x.mean(dim=1)
+        elif self.pooling=='max':
+            embedding = x.max(dim=1)[0]
+        else:
+            # first token
+            embedding=x[:,0,:]  # [bs, emb_size]
+        out=self.fc5(embedding)  # [bs, 2]
+        return out, embedding
+
+
diff --git a/src/models/components/loss_metrics.py b/src/models/components/loss_metrics.py
new file mode 100644
index 0000000..7e854fb
--- /dev/null
+++ b/src/models/components/loss_metrics.py
@@ -0,0 +1,598 @@ +#!/usr/bin/env python +""" +util_loss_metric + +Loss functions or metrics + + +References +[1] Weitang Liu, Xiaoyun Wang, John Owens, and Yixuan Li. + Energy-Based Out-of-Distribution Detection. + In Proc. NIPS, 33:21464–21475. 2020. +[2] Prannay Khosla, Piotr Teterwak, Chen Wang, Aaron Sarna, + Yonglong Tian, Phillip Isola, Aaron Maschinot, Ce Liu, and Dilip Krishnan + Supervised Contrastive Learning. Proc.NIPS. 2020. +[3] Hongyi Zhang, Moustapha Cisse, Yann N Dauphin, and David Lopez-Paz. + Mixup: Beyond Empirical Risk Minimization. In Proc. ICLR. 2018. +""" + +from __future__ import absolute_import +from __future__ import print_function + +import os +import sys +import numpy as np + +import torch +import torch.nn as torch_nn +import torch.nn.functional as torch_nn_func + +__author__ = "Xin Wang" +__email__ = "wangxin@nii.ac.jp" +__copyright__ = "Copyright 2022, Xin Wang" + + +##################### +# negative energy [1] +##################### + +def neg_energy(logits, temperature=1): + """ neg_eng = neg_energy(logits, temperature=1) + + neg_eng[x] = -T \log \sum_y \exp (logits[x, y] / T) + + See [1] + + input + ----- + logits: tensor, shape (batch, dim) + temperature: float, temperature hyperparameter + + output + ------ + neg_eng: tensor, shape (batch,) + """ + eng = - temperature * torch.logsumexp(logits / temperature, dim=1) + return eng + + +def neg_energy_reg_loss(energy, margin_in, margin_out, flag_in): + """ loss = neg_energy_reg_loss(energy, margin_in, margin_out, flag_in) + + See [1] eqs.(8-9) + + input + ----- + energy: tensor, any shape is OK + margin_in: float, margin for the in-dist. data + margin_out: float, margin for the out-dist. data + flag_in: bool, if the input data is in-dist. data + + output + ------ + loss: scalar + """ + if flag_in: + loss = torch.pow(torch_nn_func.relu(energy - margin_in), 2).mean() + else: + loss = torch.pow(torch_nn_func.relu(margin_out - energy), 2).mean() + return loss + + +##################### +# supervised contrastive loss [2] +##################### +def sim_metric_seq(mat1, mat2): + return torch.bmm(mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) +def supcon_loss(input_feat, + labels = None, mask = None, sim_metric = sim_metric_seq, + t=0.07, contra_mode='all', length_norm=False, + weight = None): + """ + loss = SupConLoss(feat, + labels = None, mask = None, sim_metric = None, + t=0.07, contra_mode='all') + input + ----- + feat: tensor, feature vectors z [bsz, n_views, ...]. + labels: ground truth of shape [bsz]. + mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j + has the same class as sample i. Can be asymmetric. + sim_metric: func, function to measure the similarity between two + feature vectors + t: float, temperature + contra_mode: str, default 'all' + 'all': use all data in class i as anchors + 'one': use 1st data in class i as anchors + length_norm: bool, default False + if True, l2 normalize feat along the last dimension + + output + ------ + A loss scalar. + + Based on https://github.com/HobbitLong/SupContrast/blob/master/losses.py + Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
+ + Example: + feature = torch.rand([16, 2, 1000], dtype=torch.float32) + feature = torch_nn_func.normalize(feature, dim=-1) + label = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 5, 1, 1, 1, 1, 1], + dtype=torch.long) + loss = supcon_loss(feature, labels=label) + """ + if length_norm: + feat = torch_nn_func.normalize(input_feat, dim=-1) + else: + feat = input_feat + + # batch size + bs = feat.shape[0] + # device + dc = feat.device + # dtype + dt = feat.dtype + # number of view + nv = feat.shape[1] + + # get the mask + # mask[i][:] indicates the data that has the same class label as data i + if labels is not None and mask is not None: + raise ValueError('Cannot define both `labels` and `mask`') + elif labels is None and mask is None: + mask = torch.eye(bs, dtype=dt, device=dc) + elif labels is not None: + labels = labels.view(-1, 1) + if labels.shape[0] != bs: + raise ValueError('Num of labels does not match num of features') + mask = torch.eq(labels, labels.T).type(dt).to(dc) + else: + mask = mask.type(dt).to(dc) + + # prepare feature matrix + # -> (num_view * batch, feature_dim, ...) + contrast_feature = torch.cat(torch.unbind(feat, dim=1), dim=0) + # + if contra_mode == 'one': + # (batch, feat_dim, ...) + anchor_feature = feat[:, 0] + anchor_count = 1 + elif contra_mode == 'all': + anchor_feature = contrast_feature + anchor_count = nv + else: + raise ValueError('Unknown mode: {}'.format(contra_mode)) + + # compute logits + # logits_mat is a matrix of size [num_view * batch, num_view * batch] + # or [batch, num_view * batch] + + if sim_metric is not None: + logits_mat = torch.div( + sim_metric(anchor_feature, contrast_feature), t) + else: + logits_mat = torch.div( + torch.matmul(anchor_feature, contrast_feature.T), t) + + # print(anchor_feature.shape) + # mask based on the label + # -> same shape as logits_mat + mask_ = mask.repeat(anchor_count, nv) + # mask on each data itself ( + self_mask = torch.scatter( + torch.ones_like(mask_), 1, + torch.arange(bs * anchor_count).view(-1, 1).to(dc), + 0) + # print(self_mask) + # + mask_ = mask_ * self_mask + # print(mask_) + + # for numerical stability, remove the max from logits + # see https://en.wikipedia.org/wiki/LogSumExp trick + # for numerical stability + logits_max, _ = torch.max(logits_mat * self_mask, dim=1, keepdim=True) + logits_mat_ = logits_mat - logits_max.detach() + # compute log_prob + exp_logits = torch.exp(logits_mat_ * self_mask) * self_mask + log_prob = logits_mat_ - torch.log(exp_logits.sum(1, keepdim=True)) + + # print("log_prob.shape", log_prob.shape) + # compute mean of log-likelihood over positive + # print(mask_ * log_prob) + mean_log_prob_pos = (mask_ * log_prob).sum(1) / mask_.sum(1) + # print("mean_log_prob_pos.shape", mean_log_prob_pos.shape) + # print(mean_log_prob_pos) + # loss + loss = - mean_log_prob_pos + loss = loss.view(anchor_count, bs).mean() + + return loss + + +################ +# Mixup +################ + +class MixUpCE(torch_nn.Module): + def __init__(self, weight = None): + super(MixUpCE, self).__init__() + self.m_loss1 = torch_nn.CrossEntropyLoss(weight=weight,reduction='none') + self.m_loss2 = torch_nn.CrossEntropyLoss(weight=weight,reduction='none') + return + + def forward(self, logits, y1, y2=None, gammas=None): + """ loss = MixUpCE.forward(logits, y1, y2, gammas) + + This API computes the mixup cross-entropy. + Logits is assumed to be f( gammas * x1 + (1-gammas) * x2). 
+ Thus, this API only compute the CE: + gammas * Loss(logits, y1) + (1 - gammas) * Loss(logits, y2) + + Note that if y2 and gammas are None, it uses common CE + + input + ----- + logits: tensor, (batch, dim) + y1: tensor, (batch, ) + y2: tensor, (batch, ) + gammas: tensor, (batch, ) + + + output + ------ + loss: scalar + """ + if y2 is None and gammas is None: + loss_val = self.m_loss1(logits, y1) + else: + loss_val = gammas * self.m_loss1(logits, y1) + loss_val += (1-gammas) * self.m_loss2(logits, y2) + return loss_val.mean() + + + +##################### +# Distillation related +##################### + +def kld_distill(logits_s, logits_t, temp=20): + """ KLD-based distillation loss + + input + ----- + logits_s: tensor, (batch, ..., dim), student output logits + where dim is #. output categories + logits_t: tensor, (batch, ..., dim), teacher output logits + temp: float, temperature, default=20 + + output + ------ + loss: scalar + """ + + KD_loss = torch_nn_func.kl_div( + torch_nn_func.log_softmax(logits_s / temp, dim = -1), + torch_nn_func.log_softmax(logits_t / temp, dim = -1), + reduction = 'batchmean', + log_target = True) * temp * temp + + return KD_loss + +##################### +# Rank consistency +##################### +def rank_consistency(x, metric = None, anchor = None, diff_mat = None): + """loss = rank_consistency(x, metric) + + input + ----- + x: tensor, (batch, D1, D2 ...) + metric: a function or None + + This function must be f(x1, x2) -> scalar + where x1 and x2 are in shape (D1, D2 ...) + + if None, negative cosine similarity for + x1 and x2 of shape (D1, ) + anchor: tensor, (batch, D1, D2, ...), as anchor + or None + + If None, one of difference vector in the + matrix will be selected as anchor + + diff_mat: tensor, (batch, batch, D1, D2 ...) + of None + if diff_mat is provided, x will be ignored + output + ------ + loss: scalar, loss value + + Example + ------- + >> x = torch.randn(4, 10) + >> x[1] = x[0] + 1.0 + >> x[2] = x[0] + 2.0 + >> x[3] = x[0] + 3.0 + >> rank_consistency(x) + tensor(-1.) + """ + + if diff_mat is None: + # (batch, batch, dim) + # diff_mat[i, j] = x[j] - x[i] + diff_mat = x - x.unsqueeze(1) + + # batch size + bs = diff_mat.shape[0] + + # loss to be accumulated + loss = 0.0 + + # metric + if metric is None: + # default use negative cosine_similarity + metric = lambda x1, x2: -torch_nn_func.cosine_similarity(x1, x2, dim=0) + + # + if bs < 3: + return loss + + # get anchor + if anchor is None: + # choose the diff_mat[1, 0] as the anchor + anchor_row_idx = 1 + anchor_col_idx = 0 + anchor = diff_mat[anchor_row_idx, anchor_col_idx] + else: + # anchor is provided externally + anchor_row_idx = -1 + anchor_col_idx = -1 + + + # loop over the matrix, compare the off-diagnoal elements + # with the anchor + count = 0.0 + for col_idx in np.arange(bs-1): + for row_idx in np.arange(col_idx+1, bs): + if col_idx == anchor_col_idx and anchor_row_idx == row_idx: + continue + loss += metric(anchor, diff_mat[row_idx, col_idx]) + count += 1 + + loss = loss / count + return loss + + +def rank_consistency_v2(x, metric = None, diff_mat = None): + """loss = rank_consistency_v2(x, metric) + + input + ----- + x: tensor, (batch, D1, D2 ...) + metric: a function or None + + This function must be f(x1, x2) -> scalar + where x1 and x2 are in shape (D1, D2 ...) + + if None, negative cosine similarity for + x1 and x2 of shape (D1, ) + + diff_mat: tensor, (batch, batch, D1, D2 ...) 
+ of None + if diff_mat is provided, x will be ignored + output + ------ + loss: scalar, loss value + + Example + ------- + >> x = torch.randn(4, 10) + >> x[1] = x[0] + 1.0 + >> x[2] = x[0] + 2.0 + >> x[3] = x[0] + 3.0 + >> metric = lambda x1, x2: \ + torch_nn_func.margin_ranking_loss(x1, x2, torch.ones_like(x1), 0.1) + >> rank_consistencyV2(x, metric) + tensor(.0) + """ + + if diff_mat is None: + # (batch, batch, dim) + # diff_mat[i, j] = x[j] - x[i] + diff_mat = x - x.unsqueeze(1) + + # batch size + bs = diff_mat.shape[0] + + # loss to be accumulated + loss = 0.0 + + # metric + if metric is None: + # default use margin_ranking_loss + metric = lambda x1, x2: torch_nn_func.margin_ranking_loss( + x1, x2, torch.ones_like(x1), 0.1) + + # + if bs < 3: + return loss + + count = 0.0 + # loop over the matrix, column first + for col_idx in np.arange(bs-2): + for row_idx in np.arange(col_idx+2, bs): + # assume diff[i, j] should be diff[i-1, j] + loss += metric(diff_mat[row_idx-1, col_idx], + diff_mat[row_idx, col_idx]) + count += 1 + + # loop over the matrix, column first + for row_idx in np.arange(2, bs): + for col_idx in np.arange(1, row_idx): + # assume diff[i, j] should be diff[i, j-1] + loss += metric(diff_mat[row_idx, col_idx], + diff_mat[row_idx, col_idx-1]) + count += 1 + + loss = loss / count + return loss + + +def rank_consistency_v3(x, metric = None): + """loss = rank_consistency_v3(x, metric) + + input + ----- + x: tensor, (batch, D1, D2 ...) + metric: a function or None + + This function must be f(x1, x2) -> scalar + where x1 and x2 are in shape (D1, D2 ...) + + if None, negative cosine similarity for + x1 and x2 of shape (D1, ) + + output + ------ + loss: scalar, loss value + + Example + ------- + >> x = torch.randn(4, 10) + >> x[1] = x[0] + 1.0 + >> x[2] = x[0] + 2.0 + >> x[3] = x[0] + 3.0 + >> metric = lambda x1, x2: \ + torch_nn_func.margin_ranking_loss(x1, x2, torch.ones_like(x1), 0.1) + >> rank_consistency_v3(x, metric) + tensor(.0) + """ + # batch size + bs = x.shape[0] + + # loss to be accumulated + loss = 0.0 + + # metric + if metric is None: + # default use margin_ranking_loss + # x1 should be ranked higher + metric = lambda x1, x2: torch_nn_func.margin_ranking_loss( + x1, x2, torch.ones_like(x1), 0.1) + + # + if bs < 2: + return loss + + count = 0.0 + # loop over the rows + for row_idx1 in np.arange(1, bs): + for row_idx2 in np.arange(0, row_idx1): + loss += metric(x[row_idx1], + x[row_idx2]) + count += 1 + + loss = loss / count + return loss + +################ +# Custom loss by Phucdt +# For Vocosig +################ + +def loss_custom(output, feats, emb, labels, config): + + real_bzs = output.shape[0] + n_views = 1.0 + loss_CE = torch_nn.CrossEntropyLoss() + + sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + + # print("output.shape", output.shape) + # print("labels.shape", labels.shape) + L_CE = 1/real_bzs *loss_CE(output, labels) + + # reshape the feats to match the supcon loss format + feats = feats.unsqueeze(1) + # print("feats.shape", feats.shape) + L_CF1 = 1/real_bzs * supcon_loss(feats, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + + # reshape the emb to match the supcon loss format + emb = emb.unsqueeze(1) + emb = emb.unsqueeze(-1) + # print("emb.shape", emb.shape) + L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + + if config['model']['loss_type'] == 1: + return {'L_CE':L_CE, 'L_CF1':L_CF1, 
'L_CF2':L_CF2} + elif config['model']['loss_type'] == 2: + return {'L_CE':L_CE, 'L_CF1':L_CF1} + elif config['model']['loss_type'] == 3: + return {'L_CE':L_CE, 'L_CF2':L_CF2} + # ablation study + elif config['model']['loss_type'] == 4: + return {'L_CE':L_CE} + elif config['model']['loss_type'] == 5: + return {'L_CF1':L_CF1, 'L_CF2':L_CF2} + + +################ +# AAM loss +################ +import time, pdb, numpy, math +import torch.nn as nn +import torch.nn.functional as F +class AAMLoss(nn.Module): + def __init__(self, nOut, nClasses, margin=0.3, scale=15, easy_margin=False, **kwargs): + super(AAMLoss, self).__init__() + + self.test_normalize = True + + self.m = margin + self.s = scale + self.in_feats = nOut + self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True) + self.ce = nn.CrossEntropyLoss() + nn.init.xavier_normal_(self.weight, gain=1) + + self.easy_margin = easy_margin + self.cos_m = math.cos(self.m) + self.sin_m = math.sin(self.m) + + # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°] + self.th = math.cos(math.pi - self.m) + self.mm = math.sin(math.pi - self.m) * self.m + + print('Initialised AAMSoftmax margin %.3f scale %.3f'%(self.m,self.s)) + + def forward(self, x, label=None): + + assert x.size()[0] == label.size()[0] + assert x.size()[1] == self.in_feats + + # cos(theta) + cosine = F.linear(F.normalize(x), F.normalize(self.weight)) + # cos(theta + m) + sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1)) + phi = cosine * self.cos_m - sine * self.sin_m + + if self.easy_margin: + phi = torch.where(cosine > 0, phi, cosine) + else: + phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm) + + #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu') + one_hot = torch.zeros_like(cosine) + one_hot.scatter_(1, label.view(-1, 1), 1) + output = (one_hot * phi) + ((1.0 - one_hot) * cosine) + output = output * self.s + + loss = self.ce(output, label) + return loss + +if __name__ == "__main__": + print("loss and metric") + + diff --git a/src/models/components/wavlmbase_conformertcm.py b/src/models/components/wavlmbase_conformertcm.py new file mode 100644 index 0000000..7539bba --- /dev/null +++ b/src/models/components/wavlmbase_conformertcm.py @@ -0,0 +1,147 @@ +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import os +from torch.nn.modules.transformer import _get_clones +from .WavLM.fe import WavLMFe +try: + from model.loss_metrics import supcon_loss + from model.RawNet3.model import RawNet3 + from model.RawNet3.RawNetBasicBlock import Bottle2neck + from model.conformer_tcm.model import MyConformer +except: + from .loss_metrics import supcon_loss + from .RawNet3.model import RawNet3 + from .RawNet3.RawNetBasicBlock import Bottle2neck + from .conformer_tcm.model import MyConformer + + +class Model(nn.Module): + def __init__(self, args, device, is_train = True): + super().__init__() + self.device = device + self.is_train = is_train + self.flag_fix_ssl = args['flag_fix_ssl'] + self.contra_mode = args['contra_mode'] + self.loss_type = args['loss_type'] + + self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])]).to(device)) + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + # front-end kwargs + fe_kwargs = args.get('wavlm_kwargs', 
{}) + self.front_end = WavLMFe(**fe_kwargs) + + self.LL = nn.Linear(self.front_end.out_dim, args['conformer']['emb_size']) + self.first_bn = nn.BatchNorm2d(num_features=1) + self.drop = nn.Dropout(0.5, inplace=True) + self.selu = nn.SELU(inplace=True) + + self.backend=MyConformer(**args['conformer']) + self.loss_CE = nn.CrossEntropyLoss() + + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + # Post-processing + + def _forward(self, x): + # Set gradient to True for x + x.requires_grad = True + #-----------------RawNet3-----------------# + x = self.front_end(x) #(bs,frame_number,frontend_out_dim) + x = self.LL(x) #(bs,frame_number,feat_out_dim) + + feats = x + x = x.unsqueeze(dim=1) # add channel #(bs, 1, frame_number, 256) + x = self.first_bn(x) + x = self.selu(x) + x = x.squeeze(dim=1) + + # output [batch, 2] + # emb [batch, emb_size] + output, emb = self.backend(x) + output = F.log_softmax(output, dim=1) + if (self.is_train): + return output, feats, emb + return output + + def forward(self, x_big): + # make labels to be a tensor of [bz] + # labels = labels.squeeze(0) + if (x_big.dim() == 3): + x_big = x_big.transpose(0,1) + batch, length, sample_per_batch = x_big.shape + # x_big is a tensor of [length, batch, sample per batch] + # transform to [length, batch*sample per batch] by concat last dim + x_big = x_big.transpose(1,2) + x_big = x_big.reshape(batch * sample_per_batch, length) + if (self.is_train): + # x_big is a tensor of [1, length, bz] + # convert to [bz, length] + # x_big = x_big.squeeze(0).transpose(0,1) + output, feats, emb = self._forward(x_big) + # calculate the loss + return output, feats, emb + else: + # in inference mode, we don't need the emb + # the x_big now is a tensor of [bz, length] + return self._forward(x_big) + + + def loss(self, output, feats, emb, labels, config, info=None): + real_bzs = output.shape[0] + # print("real_bzs", real_bzs) + # print("labels", labels) + L_CE = 1/real_bzs *self.loss_CE(output, labels) + + # reshape the feats to match the supcon loss format + feats = feats.unsqueeze(1) + # print("feats.shape", feats.shape) + L_CF1 = 1/real_bzs *supcon_loss(feats, labels=labels, contra_mode=self.contra_mode, sim_metric=self.sim_metric_seq) + # reshape the emb to match the supcon loss format + emb = emb.unsqueeze(1) + emb = emb.unsqueeze(-1) + # print("emb.shape", emb.shape) + L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=self.contra_mode, sim_metric=self.sim_metric_seq) + if self.loss_type == 1: + return {'L_CE':L_CE, 'L_CF1':L_CF1, 'L_CF2':L_CF2} + elif self.loss_type == 2: + return {'L_CE':L_CE, 'L_CF1':L_CF1} + elif self.loss_type == 3: + return {'L_CE':L_CE, 'L_CF2':L_CF2} + # ablation study + elif self.loss_type == 4: + return {'L_CE':L_CE} + elif self.loss_type == 5: + return {'L_CF1':L_CF1, 'L_CF2':L_CF2} + +if __name__ == '__main__': + import yaml + # LoRA + from peft import LoraConfig, TaskType, PeftModel, get_peft_model + + config = yaml.load(open("/data/hungdx/asvspoof5/configs/3_augall_wavlm_conformertcm_res2net_seblock_sclnormal.yaml", 'r'), Loader=yaml.FullLoader) + device = "cpu" + model = Model(config['model'], device) + # print("Hello") + # print(model) + + # Trying apply lora + # Using the best config from this paper + # https://arxiv.org/pdf/2306.05617 + target_modules = ["q_proj", "v_proj"] + r = 4 + + lora_config = LoraConfig( + r=r, + target_modules=target_modules, + #modules_to_save=["qkv"] + ) + peft_model = get_peft_model(model, lora_config) + 
print(peft_model) + peft_model.print_trainable_parameters() diff --git a/src/models/components/wavlmbase_vib.py b/src/models/components/wavlmbase_vib.py new file mode 100644 index 0000000..b57bfa8 --- /dev/null +++ b/src/models/components/wavlmbase_vib.py @@ -0,0 +1,281 @@ +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import os +from torch.nn.modules.transformer import _get_clones +from .WavLM.fe import WavLMFe +try: + from model.loss_metrics import supcon_loss + from model.RawNet3.model import RawNet3 + from model.RawNet3.RawNetBasicBlock import Bottle2neck + from model.conformer_tcm.model import MyConformer + +except: + from .loss_metrics import supcon_loss + from .RawNet3.model import RawNet3 + from .RawNet3.RawNetBasicBlock import Bottle2neck + from .conformer_tcm.model import MyConformer + + +___author__ = "Hemlata Tak" +__email__ = "tak@eurecom.fr" +class DropoutForMC(nn.Module): + """Dropout layer for Bayesian model + THe difference is that we do dropout even in eval stage + """ + def __init__(self, p, dropout_flag=True): + super(DropoutForMC, self).__init__() + self.p = p + self.flag = dropout_flag + return + + def forward(self, x): + return torch.nn.functional.dropout(x, self.p, training=self.flag) + +class BackEnd(nn.Module): + """Back End Wrapper + """ + def __init__(self, input_dim, out_dim, num_classes, + dropout_rate, dropout_flag=True): + super(BackEnd, self).__init__() + + # input feature dimension + self.in_dim = input_dim + # output embedding dimension + self.out_dim = out_dim + # number of output classes + self.num_class = num_classes + + # dropout rate + self.m_mcdp_rate = dropout_rate + self.m_mcdp_flag = dropout_flag + + # a simple full-connected network for frame-level feature processing + self.m_frame_level = nn.Sequential( + nn.Linear(self.in_dim, self.in_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag), + + nn.Linear(self.in_dim, self.in_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag), + + nn.Linear(self.in_dim, self.out_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate) + ) + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag)) + + # linear layer to produce output logits + self.m_utt_level = nn.Linear(self.out_dim, self.num_class) + + return + + def forward(self, feat): + """ logits, emb_vec = back_end_emb(feat) + + input: + ------ + feat: tensor, (batch, frame_num, feat_feat_dim) + + output: + ------- + logits: tensor, (batch, num_output_class) + emb_vec: tensor, (batch, emb_dim) + + """ + # through the frame-level network + # (batch, frame_num, self.out_dim) + feat_ = self.m_frame_level(feat) + + # average pooling -> (batch, self.out_dim) + feat_utt = feat_.mean(1) + + # output linear + logits = self.m_utt_level(feat_utt) + return logits, feat_utt + + +class VIB(nn.Module): + def __init__(self, input_dim, hidden_dim, latent_dim, m_mcdp_rate=0.5, mcdp_flag=True): + super(VIB, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.latent_dim = latent_dim + self.m_mcdp_rate = m_mcdp_rate + self.m_mcdp_flag = mcdp_flag + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(self.input_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + ) + + # Latent space + self.fc_mu = 
nn.Linear(self.hidden_dim, latent_dim) + self.fc_var = nn.Linear(self.hidden_dim, latent_dim) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(latent_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + nn.Linear(self.hidden_dim, self.input_dim), + ) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + encoded = self.encoder(x) + mu, logvar = self.fc_mu(encoded), self.fc_var(encoded) + z = self.reparameterize(mu, logvar) + decoded = self.decoder(z) + return z, decoded, mu, logvar + + +class Model(nn.Module): + def __init__(self, args, device, is_train = True): + super().__init__() + self.device = device + self.is_train = is_train + self.flag_fix_ssl = args['flag_fix_ssl'] + self.contra_mode = args['contra_mode'] + self.loss_type = args['loss_type'] + + self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])]).to(device)) + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + # front-end kwargs + fe_kwargs = args.get('wavlm_kwargs', {}) + self.front_end = WavLMFe(**fe_kwargs) + + self.LL = nn.Linear(self.front_end.out_dim, 128) + self.first_bn = nn.BatchNorm2d(num_features=1) + self.drop = nn.Dropout(0.5, inplace=True) + self.gelu = nn.GELU() + self.VIB = VIB(128, 128, 64) + self.backend = BackEnd(64, 64, 2, 0.5, False) + self.loss_CE = nn.CrossEntropyLoss() + + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + # Post-processing + + def _forward(self, x): + # Set gradient to True for x + #x.requires_grad = True + #-----------------RawNet3-----------------# + x = self.front_end(x) #(bs,frame_number,frontend_out_dim) + x = self.LL(x) #(bs,frame_number,feat_out_dim) + feats = x + + x = self.gelu(x) + + # VIB + # x [batch, frame_number, 64] + x, decoded, mu, logvar = self.VIB(x) + output, emb = self.backend(x) + output = F.log_softmax(output, dim=1) + + if (self.is_train): + return output, (decoded, mu, logvar, feats), emb + return output + + def forward(self, x_big): + # make labels to be a tensor of [bz] + # labels = labels.squeeze(0) + if (x_big.dim() == 3): + x_big = x_big.transpose(0,1) + batch, length, sample_per_batch = x_big.shape + # x_big is a tensor of [length, batch, sample per batch] + # transform to [length, batch*sample per batch] by concat last dim + x_big = x_big.transpose(1,2) + x_big = x_big.reshape(batch * sample_per_batch, length) + if (self.is_train): + # x_big is a tensor of [1, length, bz] + # convert to [bz, length] + # x_big = x_big.squeeze(0).transpose(0,1) + output, feats, emb = self._forward(x_big) + # calculate the loss + return output, feats, emb + else: + # in inference mode, we don't need the emb + # the x_big now is a tensor of [bz, length] + + return self._forward(x_big) + + + def loss(self, output, feats, emb, labels, config, info=None): + ''' + output: tensor, (batch, num_output_class) + feats: tuple: feats[0] decoded (batch, frame_num, feat_feat_dim) + feats[1] mu (batch, frame_num, feat_feat_dim) + feats[2] logvar (batch, frame_num, feat_feat_dim) + feats[3] wav2vec feats (batch, frame_num, feat_feat_dim) + emb: tensor, (batch, emb_dim) + ''' + + real_bzs = output.shape[0] + n_views = 1.0 + loss_CE = torch.nn.CrossEntropyLoss() + + decoded, mu, logvar, feats_w2v = feats + + sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 
2), mat2.permute(1, 2, 0)).mean(0) + + # print("output.shape", output.shape) + # print("labels.shape", labels.shape) + L_CE = 1/real_bzs *loss_CE(output, labels) + + # Recon loss + # print("decoded: ", decoded.shape) + # print("feats_w2v: ", feats_w2v.shape) + # print("mu: ", mu.shape) + # print("logvar: ", logvar.shape) + + BCE = F.binary_cross_entropy(torch.sigmoid(decoded), torch.sigmoid(feats_w2v), reduction='sum') + # print("BCE: ", BCE) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + # print("KLD: ", KLD) + # Recon_loss = 0.000001*(BCE + 0.05*KLD) / real_bzs + Recon_loss = config['model']['recon_weight_l']*(BCE + config['model']['recon_weight_b']*KLD) / real_bzs + # reshape the feats_w2v to match the supcon loss format + feats_w2v = feats_w2v.unsqueeze(1) + # print("feats_w2v.shape", feats_w2v.shape) + L_CF1 = 1/real_bzs * supcon_loss(feats_w2v, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + + # reshape the emb to match the supcon loss format + emb = emb.unsqueeze(1) + emb = emb.unsqueeze(-1) + # print("emb.shape", emb.shape) + L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + + if config['model']['loss_type'] == 1: + return {'L_CE':L_CE, 'L_CF1':L_CF1, 'L_CF2':L_CF2, 'Recon_loss':Recon_loss} + elif config['model']['loss_type'] == 2: + return {'L_CE':L_CE, 'L_CF1':L_CF1} + elif config['model']['loss_type'] == 3: + return {'L_CE':L_CE, 'L_CF2':L_CF2} + # ablation study + elif config['model']['loss_type'] == 4: + return {'L_CE':L_CE} + elif config['model']['loss_type'] == 5: + return {'L_CF1':L_CF1, 'L_CF2':L_CF2} + elif config['model']['loss_type'] == 6: + return {'L_CE':L_CE, 'Recon_loss':Recon_loss} + \ No newline at end of file diff --git a/src/models/components/xlsr_conformertcm.py b/src/models/components/xlsr_conformertcm.py new file mode 100644 index 0000000..6a5942d --- /dev/null +++ b/src/models/components/xlsr_conformertcm.py @@ -0,0 +1,206 @@ +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn.modules.transformer import _get_clones +import fairseq +try: + from model.loss_metrics import supcon_loss + from model.RawNet3.model import RawNet3 + from model.RawNet3.RawNetBasicBlock import Bottle2neck + from model.conformer_tcm.model import MyConformer +except: + from .loss_metrics import supcon_loss + from .RawNet3.model import RawNet3 + from .RawNet3.RawNetBasicBlock import Bottle2neck + from .conformer_tcm.model import MyConformer + + +############################ +## FOR fine-tuned SSL MODEL +############################ + + +class SSLModel(nn.Module): + def __init__(self): + super(SSLModel, self).__init__() + + cp_path = '/data/hungdx/asvspoof5/model/pretrained/xlsr2_300m.pt' # Change the pre-trained XLSR model path. 
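+        # load_model_ensemble_and_task returns (models, cfg, task); only the first
+        # (and only) model of the ensemble is kept below.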
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) + self.model = model[0] + + self.out_dim = 1024 + return + + def extract_feat(self, input_data, is_train=True): + + # put the model to GPU if it not there + # if next(self.model.parameters()).device != input_data.device \ + # or next(self.model.parameters()).dtype != input_data.dtype: + # self.model.to(input_data.device, dtype=input_data.dtype) + # self.model.train() + + # input should be in shape (batch, length) + if input_data.ndim == 3: + input_tmp = input_data[:, :, 0] + else: + input_tmp = input_data + + # [batch, length, dim] + emb = self.model(input_tmp, mask=False, features_only=True)['x'] + # print(emb.shape) + return emb + + def forward(self, x): + # x is a tensor of [batch, length] + return self.extract_feat(x) + + +class Model(nn.Module): + def __init__(self, args, device, is_train = True): + super().__init__() + self.device = device + self.is_train = is_train + self.flag_fix_ssl = args['flag_fix_ssl'] + self.contra_mode = args['contra_mode'] + self.loss_type = args['loss_type'] + + self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])]).to(device)) + + #### + # create network wav2vec 2.0 + #### + self.front_end = SSLModel() + self.LL = nn.Linear(self.front_end.out_dim, args['conformer']['emb_size']) + self.first_bn = nn.BatchNorm2d(num_features=1) + self.drop = nn.Dropout(0.5, inplace=True) + self.selu = nn.SELU(inplace=True) + + self.backend=MyConformer(**args['conformer']) + self.loss_CE = nn.CrossEntropyLoss() + + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + + def _forward(self, x): + x.requires_grad = True + #-----------------RawNet3-----------------# + x = self.front_end(x) #(bs,frame_number,frontend_out_dim) + x = self.LL(x) #(bs,frame_number,feat_out_dim) + + feats = x + x = x.unsqueeze(dim=1) # add channel #(bs, 1, frame_number, 256) + x = self.first_bn(x) + x = self.selu(x) + x = x.squeeze(dim=1) + + # output [batch, 2] + # emb [batch, emb_size] + output, emb = self.backend(x) + output = F.log_softmax(output, dim=1) + if (self.is_train): + return output, feats, emb + return output + + def forward(self, x_big): + # make labels to be a tensor of [bz] + # labels = labels.squeeze(0) + if (x_big.dim() == 3): + x_big = x_big.transpose(0,1) + batch, length, sample_per_batch = x_big.shape + # x_big is a tensor of [length, batch, sample per batch] + # transform to [length, batch*sample per batch] by concat last dim + x_big = x_big.transpose(1,2) + x_big = x_big.reshape(batch * sample_per_batch, length) + if (self.is_train): + # x_big is a tensor of [1, length, bz] + # convert to [bz, length] + # x_big = x_big.squeeze(0).transpose(0,1) + output, feats, emb = self._forward(x_big) + # calculate the loss + return output, feats, emb + else: + # in inference mode, we don't need the emb + # the x_big now is a tensor of [bz, length] + return self._forward(x_big) + + + def loss(self, output, feats, emb, labels, config, info=None): + real_bzs = output.shape[0] + # print("real_bzs", real_bzs) + # print("labels", labels) + L_CE = 1/real_bzs *self.loss_CE(output, labels) + + # reshape the feats to match the supcon loss format + feats = feats.unsqueeze(1) + # print("feats.shape", feats.shape) + L_CF1 = 1/real_bzs *supcon_loss(feats, labels=labels, contra_mode=self.contra_mode, sim_metric=self.sim_metric_seq) + # reshape the emb to match the supcon loss format + emb = 
emb.unsqueeze(1) + emb = emb.unsqueeze(-1) + # print("emb.shape", emb.shape) + L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=self.contra_mode, sim_metric=self.sim_metric_seq) + if self.loss_type == 1: + return {'L_CE':L_CE, 'L_CF1':L_CF1, 'L_CF2':L_CF2} + elif self.loss_type == 2: + return {'L_CE':L_CE, 'L_CF1':L_CF1} + elif self.loss_type == 3: + return {'L_CE':L_CE, 'L_CF2':L_CF2} + # ablation study + elif self.loss_type == 4: + return {'L_CE':L_CE} + elif self.loss_type == 5: + return {'L_CF1':L_CF1, 'L_CF2':L_CF2} + +# if __name__ == '__main__': +# import yaml +# # LoRA +# from peft import LoraConfig, TaskType, PeftModel, get_peft_model + + +# config = yaml.load(open("/data/hungdx/asvspoof5/configs_21/xlsr_aasist_baseline.yaml", 'r'), Loader=yaml.FullLoader) +# device = "cpu" +# model = Model(config['model'], device) +# # Ensure model is moved to the correct device +# model = model.to(device) + +# # print("Hello") +# # print(model) + +# # Trying apply lora +# # Using the best config from this paper +# # https://arxiv.org/pdf/2306.05617 +# target_modules = ["q_proj", "v_proj"] +# r = 4 + +# lora_config = LoraConfig( +# r=r, +# target_modules=target_modules, +# #modules_to_save=config['lora']['modules_to_save'], # Modules to continue fine-tuning +# ) +# peft_model = get_peft_model(model, lora_config) +# #print(peft_model) +# peft_model.print_trainable_parameters() + +# x = torch.randn(1, 64000) +# x_ori = x.detach().clone() + +# x = x.to(device) + +# output, feats, emb = peft_model(x) + +# cross_entropy = nn.CrossEntropyLoss() +# labels = torch.tensor([1]) +# labels = labels.to(device) + +# loss = cross_entropy(output, labels) + +# loss.backward() + +# # Ensure that after backward pass, the x_ori and x still have the same value +# assert torch.allclose(x_ori, x.detach().clone()), "The input x has been changed after backward pass" +# print("Done") \ No newline at end of file diff --git a/src/models/components/xlsr_vib.py b/src/models/components/xlsr_vib.py new file mode 100644 index 0000000..5cb31c1 --- /dev/null +++ b/src/models/components/xlsr_vib.py @@ -0,0 +1,224 @@ +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq +import os +import torch.nn.functional as F + +class SSLModel(nn.Module): + def __init__(self): + super(SSLModel, self).__init__() + + cp_path = '/data/hungdx/asvspoof5/model/pretrained/xlsr2_300m.pt' # Change the pre-trained XLSR model path. 
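+        # out_dim below assumes the 1024-dim hidden size of the XLS-R 300M checkpoint;
+        # it must match whatever model cp_path actually points to.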
+ model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) + self.model = model[0] + + self.out_dim = 1024 + return + + def extract_feat(self, input_data): + + # put the model to GPU if it not there + # if next(self.model.parameters()).device != input_data.device \ + # or next(self.model.parameters()).dtype != input_data.dtype: + # self.model.to(input_data.device, dtype=input_data.dtype) + # self.model.train() + + + if True: + # input should be in shape (batch, length) + if input_data.ndim == 3: + input_tmp = input_data[:, :, 0] + else: + input_tmp = input_data + + # [batch, length, dim] + emb = self.model(input_tmp, mask=False, features_only=True)['x'] + return emb + +class DropoutForMC(nn.Module): + """Dropout layer for Bayesian model + THe difference is that we do dropout even in eval stage + """ + + def __init__(self, p, dropout_flag=True): + super(DropoutForMC, self).__init__() + self.p = p + self.flag = dropout_flag + return + + def forward(self, x): + return torch.nn.functional.dropout(x, self.p, training=self.flag) + + +class BackEnd(nn.Module): + """Back End Wrapper + """ + + def __init__(self, input_dim, out_dim, num_classes, + dropout_rate, dropout_flag=True): + super(BackEnd, self).__init__() + + # input feature dimension + self.in_dim = input_dim + # output embedding dimension + self.out_dim = out_dim + # number of output classes + self.num_class = num_classes + + # dropout rate + self.m_mcdp_rate = dropout_rate + self.m_mcdp_flag = dropout_flag + + # a simple full-connected network for frame-level feature processing + self.m_frame_level = nn.Sequential( + nn.Linear(self.in_dim, self.in_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag), + + nn.Linear(self.in_dim, self.in_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag), + + nn.Linear(self.in_dim, self.out_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate) + ) + # DropoutForMC(self.m_mcdp_rate,self.m_mcdp_flag)) + + # linear layer to produce output logits + self.m_utt_level = nn.Linear(self.out_dim, self.num_class) + + return + + def forward(self, feat): + """ logits, emb_vec = back_end_emb(feat) + + input: + ------ + feat: tensor, (batch, frame_num, feat_feat_dim) + + output: + ------- + logits: tensor, (batch, num_output_class) + emb_vec: tensor, (batch, emb_dim) + + """ + # through the frame-level network + # (batch, frame_num, self.out_dim) + feat_ = self.m_frame_level(feat) + + # average pooling -> (batch, self.out_dim) + feat_utt = feat_.mean(1) + + # output linear + logits = self.m_utt_level(feat_utt) + return logits + + +class VIB(nn.Module): + def __init__(self, input_dim, hidden_dim, latent_dim, m_mcdp_rate=0.5, mcdp_flag=True): + super(VIB, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.latent_dim = latent_dim + self.m_mcdp_rate = m_mcdp_rate + self.m_mcdp_flag = mcdp_flag + + # Encoder + self.encoder = nn.Sequential( + nn.Linear(self.input_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + nn.Linear(self.hidden_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + ) + + # Latent space + self.fc_mu = nn.Linear(self.hidden_dim, latent_dim) + self.fc_var = nn.Linear(self.hidden_dim, latent_dim) + + # Decoder + self.decoder = nn.Sequential( + nn.Linear(latent_dim, self.hidden_dim), + nn.GELU(), + torch.nn.Dropout(self.m_mcdp_rate), + nn.Linear(self.hidden_dim, self.input_dim), + ) + + def 
reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + encoded = self.encoder(x) + mu, logvar = self.fc_mu(encoded), self.fc_var(encoded) + z = self.reparameterize(mu, logvar) + decoded = self.decoder(z) + return z, decoded, mu, logvar + + +class Model(nn.Module): + def __init__(self, args, device, is_train = False): + super().__init__() + self.device = device + self.is_train = is_train + self.flag_fix_ssl = args['flag_fix_ssl'] + self.contra_mode = args['contra_mode'] + self.loss_type = args['loss_type'] + #### + # create network wav2vec 2.0 + #### + self.ssl_model = SSLModel() + self.LL = nn.Linear(self.ssl_model.out_dim, 128) + self.first_bn = nn.BatchNorm2d(num_features=1) + self.first_bn1 = nn.BatchNorm2d(num_features=64) + self.drop = nn.Dropout(0.5, inplace=True) + self.selu = nn.SELU(inplace=True) + + self.loss_CE = nn.CrossEntropyLoss() + self.VIB = VIB(128, 128, 64) + self.backend = BackEnd(64, 64, 2, 0.5, False) + self.gelu = nn.GELU() + + self.sim_metric_seq = lambda mat1, mat2: torch.bmm( + mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) + # Post-processing + + def _forward(self, x): + # -------pre-trained Wav2vec model fine tunning ------------------------## + # if self.flag_fix_ssl: + # with torch.no_grad(): + # x_ssl_feat = self.ssl_model.extract_feat( + # x.squeeze(-1), is_train=False) + # else: + # x_ssl_feat = self.ssl_model.extract_feat( + # x.squeeze(-1), is_train=self.is_train) # (bs,frame_number,feat_dim) + x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1)) + + x = self.LL(x_ssl_feat) # (bs,frame_number,feat_out_dim) + feats = x + x = self.gelu(x) + + # VIB + # x [batch, frame_number, 64] + x, decoded, mu, logvar = self.VIB(x) + + # output [batch, 2] + # emb [batch, 128] + output = self.backend(x) + + return output + + def forward(self, x_big): + return self._forward(x_big) + From 61018ed52be835f5047567c209d6f161228ed335 Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Sun, 29 Sep 2024 00:01:44 +0900 Subject: [PATCH 3/8] Add multiple augmentation, support balance mini-batch --- .../normal_largecorpus_for_asvspoof5.yaml | 37 ++ .../scl_normal_largecorpus_for_asvspoof5.yaml | 39 ++ configs/experiment/test.yaml | 34 +- configs/model/wavlmbase_vib.yaml | 24 ++ requirements.txt | 3 +- .../components/audio_augmentor/RIRS_NOISES | 1 + .../components/audio_augmentor/__init__.py | 61 +++ .../components/audio_augmentor/__version__.py | 7 + .../audio_augmentor/background_noise.py | 60 +++ .../components/audio_augmentor/bandpass.py | 67 ++++ src/data/components/audio_augmentor/base.py | 63 +++ .../components/audio_augmentor/copy_paste.py | 67 ++++ .../audio_augmentor/freq_masking.py | 45 +++ .../components/audio_augmentor/freq_swap.py | 53 +++ .../components/audio_augmentor/gaussian.py | 42 ++ .../audio_augmentor/linear_filter.py | 60 +++ .../components/audio_augmentor/masking.py | 69 ++++ src/data/components/audio_augmentor/musan | 1 + src/data/components/audio_augmentor/pitch.py | 39 ++ src/data/components/audio_augmentor/reverb.py | 46 +++ src/data/components/audio_augmentor/speed.py | 33 ++ .../components/audio_augmentor/swapping.py | 67 ++++ .../components/audio_augmentor/telephone.py | 73 ++++ .../audio_augmentor/time_masking.py | 51 +++ .../components/audio_augmentor/time_swap.py | 54 +++ src/data/components/audio_augmentor/utils.py | 80 ++++ src/data/components/audio_augmentor/volume.py | 32 ++ src/data/components/augwrapper.py | 8 +- 
src/data/components/baseloader.py | 4 +- src/data/cyberaicup_datamodule.py | 96 ++--- src/data/normal_datamodule.py | 359 ++++++++++++++++++ src/data/scl_datamodule.py | 214 ++++++----- src/models/components/loss_metrics.py | 2 +- .../components/wavlmbase_conformertcm.py | 41 +- src/models/components/wavlmbase_vib.py | 48 +-- src/models/components/xlsr_conformertcm.py | 65 +--- src/models/wavlmvib_module.py | 277 ++++++++++++++ src/utils/debug.py | 62 +++ 38 files changed, 2104 insertions(+), 280 deletions(-) create mode 100644 configs/data/normal_largecorpus_for_asvspoof5.yaml create mode 100644 configs/data/scl_normal_largecorpus_for_asvspoof5.yaml create mode 100644 configs/model/wavlmbase_vib.yaml create mode 120000 src/data/components/audio_augmentor/RIRS_NOISES create mode 100644 src/data/components/audio_augmentor/__init__.py create mode 100644 src/data/components/audio_augmentor/__version__.py create mode 100644 src/data/components/audio_augmentor/background_noise.py create mode 100644 src/data/components/audio_augmentor/bandpass.py create mode 100644 src/data/components/audio_augmentor/base.py create mode 100644 src/data/components/audio_augmentor/copy_paste.py create mode 100644 src/data/components/audio_augmentor/freq_masking.py create mode 100644 src/data/components/audio_augmentor/freq_swap.py create mode 100644 src/data/components/audio_augmentor/gaussian.py create mode 100644 src/data/components/audio_augmentor/linear_filter.py create mode 100644 src/data/components/audio_augmentor/masking.py create mode 120000 src/data/components/audio_augmentor/musan create mode 100644 src/data/components/audio_augmentor/pitch.py create mode 100644 src/data/components/audio_augmentor/reverb.py create mode 100644 src/data/components/audio_augmentor/speed.py create mode 100644 src/data/components/audio_augmentor/swapping.py create mode 100644 src/data/components/audio_augmentor/telephone.py create mode 100644 src/data/components/audio_augmentor/time_masking.py create mode 100644 src/data/components/audio_augmentor/time_swap.py create mode 100644 src/data/components/audio_augmentor/utils.py create mode 100644 src/data/components/audio_augmentor/volume.py create mode 100644 src/data/normal_datamodule.py create mode 100644 src/models/wavlmvib_module.py create mode 100644 src/utils/debug.py diff --git a/configs/data/normal_largecorpus_for_asvspoof5.yaml b/configs/data/normal_largecorpus_for_asvspoof5.yaml new file mode 100644 index 0000000..f189724 --- /dev/null +++ b/configs/data/normal_largecorpus_for_asvspoof5.yaml @@ -0,0 +1,37 @@ +_target_: src.data.normal_datamodule.NormalDataModule +data_dir: ${oc.env:LARGE_CORPUS_FOR_ASVSPOOF5} +batch_size: 2 # Because of scl datamodule will be re-organized mini-batch size +num_workers: 4 +pin_memory: False +args: + # The sampling rate of the audio files + portion: 0.2 # 20% of the data + nBands: 5 + minF: 20 + maxF: 8000 + minBW: 100 + maxBW: 1000 + minCoeff: 10 + maxCoeff: 100 + minG: 0 + maxG: 0 + minBiasLinNonLin: 5 + maxBiasLinNonLin: 20 + N_f: 5 + P: 10 + g_sd: 2 + SNRmin: 10 + SNRmax: 40 + + data: + + augmentation_methods: + ["RawBoost12", "background_noise_5_15", "reverb_1", "telephone_g722"] + trim_length: 100000 # 6.25s + wav_samp_rate: 16000 + online_aug: true + aug_dir: ${oc.env:LARGE_CORPUS_FOR_ASVSPOOF5}/aug + noise_path: ${oc.env:NOISE_PATH} + rir_path: ${oc.env:RIR_PATH} + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: false # If true, randomly pick a start point for the audio \ No newline at end of file 
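A minimal usage sketch for a config like the one above, assuming NormalDataModule follows the usual LightningDataModule interface of this template and that the referenced environment variables are exported (the exact datamodule API is not shown in this patch):

    # Assumption: LARGE_CORPUS_FOR_ASVSPOOF5, NOISE_PATH and RIR_PATH are set in the environment.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/data/normal_largecorpus_for_asvspoof5.yaml")
    dm = instantiate(cfg)        # builds src.data.normal_datamodule.NormalDataModule from _target_
    dm.setup(stage="fit")
    batch = next(iter(dm.train_dataloader()))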
diff --git a/configs/data/scl_normal_largecorpus_for_asvspoof5.yaml b/configs/data/scl_normal_largecorpus_for_asvspoof5.yaml new file mode 100644 index 0000000..d203011 --- /dev/null +++ b/configs/data/scl_normal_largecorpus_for_asvspoof5.yaml @@ -0,0 +1,39 @@ +_target_: src.data.scl_datamodule.SclNormalDataModule +data_dir: ${oc.env:LARGE_CORPUS_FOR_ASVSPOOF5} +batch_size: 2 # Because of scl datamodule will be re-organized mini-batch size +num_workers: 4 +pin_memory: False +args: + # The sampling rate of the audio files + portion: 0.2 # 20% of the data + nBands: 5 + minF: 20 + maxF: 8000 + minBW: 100 + maxBW: 1000 + minCoeff: 10 + maxCoeff: 100 + minG: 0 + maxG: 0 + minBiasLinNonLin: 5 + maxBiasLinNonLin: 20 + N_f: 5 + P: 10 + g_sd: 2 + SNRmin: 10 + SNRmax: 40 + + data: + vocoders: ["hifigan", "hn-sinc-nsf-hifi", "waveglow"] + augmentation_methods: + ["RawBoost12", "background_noise_5_15", "reverb_1", "telephone_g722"] + num_additional_real: 2 + num_additional_spoof: 3 + trim_length: 100000 # 6.25s + wav_samp_rate: 16000 + online_aug: true + aug_dir: "/data/Datasets/0_large-corpus/aug" + noise_path: "/data/Datasets/musan/asvspoof5" + rir_path: "/data/Datasets/RIRS_NOISES" + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: false # If true, randomly pick a start point for the audio \ No newline at end of file diff --git a/configs/experiment/test.yaml b/configs/experiment/test.yaml index 728becd..f31e8db 100644 --- a/configs/experiment/test.yaml +++ b/configs/experiment/test.yaml @@ -4,15 +4,15 @@ # python train.py experiment=example defaults: - - override /data: asvspoof - - override /model: aasist + - override /data: scl_normal_largecorpus_for_asvspoof5 + - override /model: wavlmbase_vib - override /callbacks: default - override /trainer: default # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -tags: ["asvspoof", "aasist"] +tags: ["scl_normal_largecorpus_for_asvspoof5", "wavlmbase_vib"] seed: 12345 @@ -25,6 +25,8 @@ trainer: model: optimizer: lr: 0.000001 + args: + loss_type: 1 net: d_args: first_conv: 128 @@ -34,13 +36,31 @@ model: temperatures: [2.0, 2.0, 100.0, 100.0] compile: false + scheduler: + _target_: torch.optim.lr_scheduler.CyclicLR + _partial_: true + cycle_momentum: false + base_lr: 0.000001 + max_lr: 0.00001 + mode: "exp_range" + gamma: 0.85 + + + data: - batch_size: 16 - num_workers: 4 + batch_size: 1 + num_workers: 8 + args: + portion: 0.1 # 10% of the data + data: + trim_length: 100000 # 6s + repeat_pad: true # If true, repeat the audio to the trim_length + random_start: true # If true, randomly pick a start point for the audio + logger: wandb: tags: ${tags} - group: "asvspoof" + group: "scl_normal_largecorpus_for_asvspoof5" aim: - experiment: "asvspoof" + experiment: "scl_normal_largecorpus_for_asvspoof5" diff --git a/configs/model/wavlmbase_vib.yaml b/configs/model/wavlmbase_vib.yaml new file mode 100644 index 0000000..7193bcc --- /dev/null +++ b/configs/model/wavlmbase_vib.yaml @@ -0,0 +1,24 @@ +_target_: src.models.wavlmvib_module.WAVLMVIBLLitModule + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.000001 + weight_decay: 0.00001 + +scheduler: null + +args: + contra_mode: "all" # 'one' or 'all' + loss_type: 1 # 1: 2loss, 2: loss emb only, 3: loss final hidden state only, 4: CE loss + ce_loss_weight: 0.5 # weight of the cross entropy loss for bona fide class + recon_weight_l: 0.000001 + recon_weight_b: 0.05 + 
wavlm_kwargs: + checkpoint_path: ${oc.env:WAVLMBASE_PRETRAINED_MODEL_PATH} + extract_mode: "weighted_sum" + +is_train: true + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/requirements.txt b/requirements.txt index 84a785d..aeb3911 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,5 @@ pytest # tests # sh # for running bash commands in some tests (linux/macos only) # --------- audio --------- # -librosa==0.10.0 \ No newline at end of file +librosa==0.10.0 +pydub==0.25.1 \ No newline at end of file diff --git a/src/data/components/audio_augmentor/RIRS_NOISES b/src/data/components/audio_augmentor/RIRS_NOISES new file mode 120000 index 0000000..d8ebc1c --- /dev/null +++ b/src/data/components/audio_augmentor/RIRS_NOISES @@ -0,0 +1 @@ +/datab/Dataset/Noise/RIRS_NOISES \ No newline at end of file diff --git a/src/data/components/audio_augmentor/__init__.py b/src/data/components/audio_augmentor/__init__.py new file mode 100644 index 0000000..162d812 --- /dev/null +++ b/src/data/components/audio_augmentor/__init__.py @@ -0,0 +1,61 @@ +from .background_noise import BackgroundNoiseAugmentor +from .pitch import PitchAugmentor +from .reverb import ReverbAugmentor +from .speed import SpeedAugmentor +from .volume import VolumeAugmentor +from .telephone import TelephoneEncodingAugmentor +from .gaussian import GaussianAugmentor +from .copy_paste import CopyPasteAugmentor +from .base import BaseAugmentor +from .time_masking import TimeMaskingAugmentor +from .freq_masking import FrequencyMaskingAugmentor +from .masking import MaskingAugmentor +from .time_swap import TimeSwapAugmentor +from .freq_swap import FrequencySwapAugmentor +from .swapping import SwappingAugmentor +from .linear_filter import LinearFilterAugmentor +from .bandpass import BandpassAugmentor + + +# from . 
import utils + +from .__version__ import ( + + __author__, + __author_email__, + __description__, + __license__, + __title__, + __url__, + __version__, +) + +SUPPORTED_AUGMENTORS = ['background_noise', 'pitch', 'speed', 'volume', 'reverb', 'telephone', 'gaussian_noise', 'time_masking', 'freq_masking', 'masking', 'time_swap', 'freq_swap', 'swapping', 'linear_filter', 'bandpass'] + +import logging.config +LOGGING = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "std": { + "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s", + "datefmt": "%Y-%m-%d %H:%M", + } + }, + "handlers": { + "default": { + "class": "logging.NullHandler", + }, + "test": { + "class": "logging.StreamHandler", + "formatter": "std", + "level": logging.INFO, + }, + }, + "loggers": { + "art": {"handlers": ["default"]}, + "tests": {"handlers": ["test"], "level": "INFO", "propagate": True}, + }, +} +logging.config.dictConfig(LOGGING) +logger = logging.getLogger(__name__) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/__version__.py b/src/data/components/audio_augmentor/__version__.py new file mode 100644 index 0000000..4655078 --- /dev/null +++ b/src/data/components/audio_augmentor/__version__.py @@ -0,0 +1,7 @@ +__title__ = "audio_augmentor" +__description__ = "Audio Augmentor for human speech data" +__url__ = "" +__version__ = "0.3.0" +__author__ = "Thien-Phuc Doan" +__author_email__ = "josebeo2016@gmail.com" +__license__ = "MIT" \ No newline at end of file diff --git a/src/data/components/audio_augmentor/background_noise.py b/src/data/components/audio_augmentor/background_noise.py new file mode 100644 index 0000000..ba4cc81 --- /dev/null +++ b/src/data/components/audio_augmentor/background_noise.py @@ -0,0 +1,60 @@ +from .base import BaseAugmentor +from .utils import recursive_list_files +from pydub import AudioSegment +import random +import numpy as np +from .utils import librosa_to_pydub + +import logging +logger = logging.getLogger(__name__) +class BackgroundNoiseAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Background noise augmentation method. 
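+        A noise file is picked at random from `noise_path` and overlaid on the
+        input at a gain derived from a random SNR in [min_SNR_dB, max_SNR_dB].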
+ + Config: + noise_path: str, path to the folder containing noise files + min_SNR_dB: int, min SNR in dB + max_SNR_dB: int, max SNR in dB + """ + super().__init__(config) + self.noise_path = config["noise_path"] + self.noise_list = self.select_noise(self.noise_path) + self.min_SNR_dB = config["min_SNR_dB"] + self.max_SNR_dB = config["max_SNR_dB"] + + def select_noise(self, noise_path: str) -> list: + noise_list = recursive_list_files(noise_path) + return noise_list + + def load(self, input_path: str): + """ + :param input_path: path to the input audio file + """ + # load with librosa + super().load(input_path) + + # Convert to pydub audio segment + self.audio_data = librosa_to_pydub(self.data, sr=self.sr) + + def transform(self): + # Load audio files + noise_file = AudioSegment.from_file(random.choice(self.noise_list)) + + # Set the desired SNR (signal-to-noise ratio) level in decibels + SNR_dB = random.randint(self.min_SNR_dB, self.max_SNR_dB) + + # Calculate the power of the signal and noise + signal_power = self.audio_data.dBFS + noise_power = noise_file.dBFS + + # Calculate the scaling factor for the noise + scaling_factor = SNR_dB * noise_power / signal_power + + # Apply the noise to the audio file + scaled_audio = self.audio_data.apply_gain(scaling_factor) + self.augmented_audio = scaled_audio.overlay(noise_file) + + + + \ No newline at end of file diff --git a/src/data/components/audio_augmentor/bandpass.py b/src/data/components/audio_augmentor/bandpass.py new file mode 100644 index 0000000..516c53f --- /dev/null +++ b/src/data/components/audio_augmentor/bandpass.py @@ -0,0 +1,67 @@ +from .base import BaseAugmentor +from .utils import recursive_list_files, librosa_to_pydub +import scipy.signal as ss +import librosa +import random +import numpy as np +import torchaudio.functional as F +from torchaudio.io import AudioEffector +import logging +import torch +from torchaudio.sox_effects import apply_effects_tensor as sox_fx + +logger = logging.getLogger(__name__) + +def apply_codec(waveform, sample_rate, format, encoder=None): + encoder = AudioEffector(format=format, encoder=encoder) + return encoder.apply(waveform, sample_rate) + +def apply_effect(waveform, sample_rate, effect): + effector = AudioEffector(effect=effect) + return effector.apply(waveform, sample_rate) +class BandpassAugmentor(BaseAugmentor): + """ + About + ----- + + This class makes audio sound like it's over telephone, by apply one or two things: + * codec: choose between ALAW, ULAW, g722 + * bandpass: Optional, can be None. This limits the frequency within common telephone range (300, 3400) be default. + Note: lowpass value should be higher than that of highpass. 
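+    Unlike TelephoneEncodingAugmentor, no codec is applied here: transform() only runs the
+    lowpass/highpass filter chain, and __init__ reads the cut-off frequencies from the
+    top-level "lowpass"/"highpass" config keys (see the example below).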
+
+    Example
+    -------
+
+    CONFIG = {
+        "aug_type": "bandpass",
+        "output_path": os.path.join(BASE_DIR,"data/augmented"),
+        "out_format": "wav",
+        "lowpass": 4000,
+        "highpass": 0
+    }
+    """
+    def __init__(self, config: dict):
+        super().__init__(config)
+        self.bandpass = {
+            "lowpass": config.get("lowpass", 4000),
+            "highpass": config.get("highpass", 0)
+        }
+
+        self.effects = ",".join(
+            [
+                "lowpass=frequency={}:poles=1".format(self.bandpass['lowpass']),
+                "highpass=frequency={}:poles=1".format(self.bandpass['highpass']),
+            ]
+        )
+
+    def transform(self):
+        """
+        Apply the lowpass/highpass filter chain to the audio via torchaudio's AudioEffector.
+        """
+        torch_audio = torch.tensor(self.data).reshape(1, -1)
+        aug_audio = apply_effect(torch_audio.T, self.sr, self.effects)
+        # convert to numpy array
+        aug_audio = aug_audio.numpy().flatten()
+        self.augmented_audio = librosa_to_pydub(aug_audio)
\ No newline at end of file
diff --git a/src/data/components/audio_augmentor/base.py b/src/data/components/audio_augmentor/base.py
new file mode 100644
index 0000000..0685ef1
--- /dev/null
+++ b/src/data/components/audio_augmentor/base.py
@@ -0,0 +1,63 @@
+import librosa
+import os
+import typing
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class BaseAugmentor:
+    """
+    Basic augmentor class. Requires these config keys:
+    aug_type: str, augmentation type
+    output_path: str, output path
+    out_format: str, output format
+    """
+
+    def __init__(self, config: dict):
+        """
+        This method initializes the `BaseAugmentor` object.
+        """
+        self.config = config
+        self.aug_type = config["aug_type"]
+
+        self.output_path = config["output_path"]
+        self.out_format = config["out_format"]
+        self.augmented_audio = None
+        self.data = None
+        self.sr = 16000
+
+    def load(self, input_path: str):
+        """
+        Load the audio file and normalize the data (librosa handles both steps).
+        self.data: audio data as a numpy array (librosa load)
+        :param input_path: path to the input audio file
+        """
+        self.input_path = input_path
+        self.file_name = self.input_path.split("/")[-1].split(".")[0]
+        # load with librosa and auto resample to 16kHz
+        self.data, self.sr = librosa.load(self.input_path, sr=self.sr)
+
+        # Convert to mono channel
+        self.data = librosa.to_mono(self.data)
+
+    def transform(self):
+        """
+        Transform audio data (librosa load) into augmented audio data (pydub audio segment)
+        Note that self.augmented_audio is a pydub audio segment
+        """
+        raise NotImplementedError
+
+    def save(self):
+        """
+        Save the augmented audio data (pydub audio segment) to file.
+        self.out_format: output format
+        pydub handles the codec/format conversion here.
+        """
+        self.augmented_audio.export(
+            os.path.join(self.output_path, self.file_name + "." + self.out_format),
+            format=self.out_format,
+        )
+
diff --git a/src/data/components/audio_augmentor/copy_paste.py b/src/data/components/audio_augmentor/copy_paste.py
new file mode 100644
index 0000000..ae27c2f
--- /dev/null
+++ b/src/data/components/audio_augmentor/copy_paste.py
@@ -0,0 +1,67 @@
+from .base import BaseAugmentor
+from .utils import librosa_to_pydub
+import random
+import soundfile as sf
+import numpy as np
+import logging
+
+logger = logging.getLogger(__name__)
+class CopyPasteAugmentor(BaseAugmentor):
+    """
+    Copy and Paste augmentation
+    Config:
+    shuffle_ratio (float): The ratio of frames to shuffle (between 0 and 1).
+    frame_size (int): The size of each frame. If zero, the frame size is randomly selected from [800:3200] samples.
+
+    """
+    def __init__(self, config: dict):
+        """
+        This method initializes the `CopyPasteAugmentor` object.
+
+        :param config: dict, configuration dictionary
+        """
+
+        super().__init__(config)
+        self.shuffle_ratio = config['shuffle_ratio']
+        self.frame_size = config['frame_size']
+        assert self.shuffle_ratio > 0.0 and self.shuffle_ratio < 1.0
+        assert self.frame_size >= 0
+        # Determine frame size
+        if self.frame_size == 0:
+            self.frame_size = np.random.randint(800, 3201) # Randomly select frame size from [800:3200]
+
+    def transform(self):
+        """
+        Transform the audio by shuffling a ratio of its frames (copy-and-paste style).
+        """
+
+        # Check if the frame_size is greater than the audio length
+        if self.frame_size > len(self.data):
+            logger.warning(f"Frame size {self.frame_size} is greater than audio length {len(self.data)}. Skipping augmentation.")
+            # no transformation, just return the original audio
+            self.augmented_audio = librosa_to_pydub(self.data, sr=self.sr)
+            return
+
+        # Split audio into frames
+        num_frames = len(self.data) // self.frame_size
+        frames = [self.data[i*self.frame_size:(i+1)*self.frame_size] for i in range(num_frames)]
+
+        # Handle the last chunk of audio if it's not a full frame
+        if len(self.data) % self.frame_size != 0:
+            frames.append(self.data[num_frames*self.frame_size:])
+
+        # Determine the number of frames to shuffle
+        num_frames_to_shuffle = int(len(frames) * self.shuffle_ratio)
+
+        # Randomly shuffle a ratio of frames
+        if num_frames_to_shuffle > 0:
+            shuffle_indices = np.random.choice(len(frames), num_frames_to_shuffle, replace=False)
+            shuffled_frames = np.array(frames, dtype=object)[shuffle_indices]
+            np.random.shuffle(shuffled_frames)
+            for i, idx in enumerate(shuffle_indices):
+                frames[idx] = shuffled_frames[i]
+
+        # Concatenate frames back together
+        transformed_audio = np.concatenate(frames)
+        # transform to pydub audio segment
+        self.augmented_audio = librosa_to_pydub(transformed_audio, sr=self.sr)
diff --git a/src/data/components/audio_augmentor/freq_masking.py b/src/data/components/audio_augmentor/freq_masking.py
new file mode 100644
index 0000000..0141eb5
--- /dev/null
+++ b/src/data/components/audio_augmentor/freq_masking.py
@@ -0,0 +1,45 @@
+from .base import BaseAugmentor
+from .utils import librosa_to_pydub
+import numpy as np
+import random
+import librosa
+import soundfile as sf
+import os
+
+import logging
+logger = logging.getLogger(__name__)
+
+class FrequencyMaskingAugmentor(BaseAugmentor):
+    def __init__(self, config: dict):
+        """
+        Frequency Masking augmentation
+        Config:
+        F: int, maximum frequency mask parameter
+        num_masks: int, number of frequency masks to apply
+        """
+        super().__init__(config)
+        self.F = config.get('F', 27)
+        self.num_masks = config.get('num_masks', 2)
+
+    def apply_frequency_masking(self, mel_spectrogram):
+        """
+        Apply frequency masking to the mel spectrogram.
+        """
+        num_mel_channels = mel_spectrogram.shape[0]
+
+        for _ in range(self.num_masks):
+            f = random.randint(0, self.F)
+            f0 = random.randint(0, num_mel_channels - f)
+            mel_spectrogram[f0:f0 + f, :] = 0
+
+        return mel_spectrogram
+
+    def transform(self):
+        """
+        Transform the audio by applying frequency masking.
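+        The masked mel spectrogram is converted back to a waveform with
+        librosa.feature.inverse.mel_to_audio (Griffin-Lim), so phase is
+        re-estimated rather than preserved.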
+ """ + data = self.data.copy() + mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr) + masked_mel_spectrogram = self.apply_frequency_masking(mel_spectrogram) + augmented_data = librosa.feature.inverse.mel_to_audio(masked_mel_spectrogram, sr=self.sr) + self.augmented_audio = librosa_to_pydub(augmented_data, sr=self.sr) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/freq_swap.py b/src/data/components/audio_augmentor/freq_swap.py new file mode 100644 index 0000000..b08fac3 --- /dev/null +++ b/src/data/components/audio_augmentor/freq_swap.py @@ -0,0 +1,53 @@ +import torch +import torch.nn.functional as F +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import numpy as np +import random +import librosa +import soundfile as sf +import os + +import logging +logger = logging.getLogger(__name__) + +class FrequencySwapAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Frequency Swap augmentation + Config: + F: int, maximum frequency swap parameter + num_swaps: int, number of frequency swaps to apply + """ + super().__init__(config) + self.F = config.get('F', 7) + self.num_swaps = config.get('num_swaps', 1) + + def apply_frequency_swapping(self, mel_spectrogram): + """ + Apply frequency swapping to the mel spectrogram. + """ + num_mel_channels = mel_spectrogram.shape[0] + + for _ in range(self.num_swaps): + f = random.randint(0, self.F) + f0 = random.randint(0, num_mel_channels - 2 * f) + f1 = random.randint(f0 + f, num_mel_channels - f) + + mel_spectrogram[f0:f0 + f], mel_spectrogram[f1:f1 + f] = \ + mel_spectrogram[f1:f1 + f].clone(), mel_spectrogram[f0:f0 + f].clone() + + return mel_spectrogram + + def transform(self): + """ + Transform the audio by applying frequency swapping. + """ + data = self.data.copy() + mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr) + + swapped_mel_spectrogram = self.apply_frequency_swapping(torch.tensor(mel_spectrogram)) + swapped_mel_spectrogram = swapped_mel_spectrogram.numpy() + + augmented_data = librosa.feature.inverse.mel_to_audio(swapped_mel_spectrogram, sr=self.sr) + self.augmented_audio = librosa_to_pydub(augmented_data, sr=self.sr) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/gaussian.py b/src/data/components/audio_augmentor/gaussian.py new file mode 100644 index 0000000..df0ea31 --- /dev/null +++ b/src/data/components/audio_augmentor/gaussian.py @@ -0,0 +1,42 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import random +import librosa +import soundfile as sf +import numpy as np +from pydub import AudioSegment +import wave +import logging + +logger = logging.getLogger(__name__) +class GaussianAugmentor(BaseAugmentor): + """ + Gaussian augmentation + Config: + mean: int, gaussian mean factor + std_dev: int, gaussian standard deviation factor + """ + def __init__(self, config: dict): + """ + This method initialize the `GaussianAugmentor` object. + + :param config: dict, configuration dictionary + """ + + super().__init__(config) + self.min_amplitude = config['min_amplitude'] + self.max_amplitude = config['max_amplitude'] + assert self.min_amplitude > 0.0 + assert self.max_amplitude > 0.0 + assert self.max_amplitude >= self.min_amplitude + self.amplitude = random.uniform(self.min_amplitude, self.max_amplitude) + + + def transform(self): + """ + Transform the audio by adding gausian noise. 
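+        Roughly y = x + amplitude * N(0, 1), with amplitude drawn once from
+        [min_amplitude, max_amplitude] at construction time.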
+ """ + noise = np.random.randn(self.data.shape[0]).astype(np.float32) + data = self.data + self.amplitude * noise + # transform to pydub audio segment + self.augmented_audio = librosa_to_pydub(data, sr=self.sr) diff --git a/src/data/components/audio_augmentor/linear_filter.py b/src/data/components/audio_augmentor/linear_filter.py new file mode 100644 index 0000000..380564b --- /dev/null +++ b/src/data/components/audio_augmentor/linear_filter.py @@ -0,0 +1,60 @@ +import torch +import numpy as np +import random +import librosa +from .base import BaseAugmentor +from .utils import librosa_to_pydub + +import logging +logger = logging.getLogger(__name__) + +class LinearFilterAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Linear Filter augmentation + Config: + db_range: list, range of dB values for filter weights + n_band: list, range of number of frequency bands + min_bw: int, minimum bandwidth + """ + super().__init__(config) + self.db_range = config.get('db_range', [-6, 6]) + self.n_band = config.get('n_band', [3, 6]) + self.min_bw = config.get('min_bw', 6) + + def apply_linear_filter(self, mel_spectrogram): + """ + Apply linear filter to the mel spectrogram. + """ + mel_spectrogram = torch.tensor(mel_spectrogram) # Convert numpy array to torch tensor + n_freq_bin, time_steps = mel_spectrogram.shape + n_freq_band = random.randint(self.n_band[0], self.n_band[1]) + + if n_freq_band > 1: + while n_freq_bin - n_freq_band * self.min_bw + 1 < 0: + self.min_bw -= 1 + band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * self.min_bw + 1, + (n_freq_band - 1,)))[0] + \ + torch.arange(1, n_freq_band) * self.min_bw + band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin]))) + + band_factors = torch.rand(n_freq_band + 1) * (self.db_range[1] - self.db_range[0]) + self.db_range[0] + freq_filt = torch.ones((n_freq_bin, 1)) + for i in range(n_freq_band): + freq_filt[band_bndry_freqs[i]:band_bndry_freqs[i+1], :] = \ + torch.linspace(band_factors[i], band_factors[i+1], + band_bndry_freqs[i+1] - band_bndry_freqs[i]).unsqueeze(-1) + freq_filt = 10 ** (freq_filt / 20) + return mel_spectrogram * freq_filt + else: + return mel_spectrogram + + def transform(self): + """ + Transform the audio by applying linear filter. 
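+        The filter is a piecewise-linear gain curve (in dB, within db_range) over a
+        random number of frequency bands, applied to the mel spectrogram before
+        Griffin-Lim inversion back to a waveform.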
+ """ + data = self.data.copy() + mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr) + filtered_mel_spectrogram = self.apply_linear_filter(mel_spectrogram) + filtered_audio = librosa.feature.inverse.mel_to_audio(filtered_mel_spectrogram.numpy(), sr=self.sr) # Convert back to numpy array + self.augmented_audio = librosa_to_pydub(filtered_audio, sr=self.sr) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/masking.py b/src/data/components/audio_augmentor/masking.py new file mode 100644 index 0000000..c8f982c --- /dev/null +++ b/src/data/components/audio_augmentor/masking.py @@ -0,0 +1,69 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import numpy as np +import random +import librosa +import soundfile as sf +import os + +import logging +logger = logging.getLogger(__name__) + +class MaskingAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Masking augmentation + Config: + F: int, maximum frequency mask parameter + num_frequency_masks: int, number of frequency masks to apply + T: int, maximum time mask parameter (number of frames to mask) + p: float, upper bound parameter for time mask as a fraction of total time steps + num_time_masks: int, number of time masks to apply + """ + super().__init__(config) + self.F = config.get('F', 27) + self.num_frequency_masks = config.get('num_frequency_masks', 2) + self.T = config.get('T', 100) + self.p = config.get('p', 1.0) + self.num_time_masks = config.get('num_time_masks', 2) + + def apply_frequency_masking(self, mel_spectrogram): + """ + Apply frequency masking to the mel spectrogram. + """ + num_mel_channels = mel_spectrogram.shape[0] + + for _ in range(self.num_frequency_masks): + f = random.randint(0, self.F) + f0 = random.randint(0, num_mel_channels - f) + mel_spectrogram[f0:f0 + f, :] = 0 + + return mel_spectrogram + + def apply_time_masking(self, mel_spectrogram): + """ + Apply time masking to the mel spectrogram. + """ + num_time_steps = mel_spectrogram.shape[1] + max_mask_length = int(self.p * num_time_steps) + + for _ in range(self.num_time_masks): + if num_time_steps > 1: + t = random.randint(1, min(self.T, max_mask_length)) + t0 = random.randint(0, num_time_steps - t) + mel_spectrogram[:, t0:t0 + t] = 0 + + return mel_spectrogram + + def transform(self): + """ + Transform the audio by applying frequency and time masking. 
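+        Frequency and time masks are applied to the mel spectrogram in a single
+        SpecAugment-style pass before inverting back to a waveform.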
+ """ + data = self.data.copy() + mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr) + + masked_mel_spectrogram = self.apply_frequency_masking(mel_spectrogram) + masked_mel_spectrogram = self.apply_time_masking(masked_mel_spectrogram) + + masked_audio = librosa.feature.inverse.mel_to_audio(masked_mel_spectrogram, sr=self.sr) + self.augmented_audio = librosa_to_pydub(masked_audio, sr=self.sr) diff --git a/src/data/components/audio_augmentor/musan b/src/data/components/audio_augmentor/musan new file mode 120000 index 0000000..71d0fc0 --- /dev/null +++ b/src/data/components/audio_augmentor/musan @@ -0,0 +1 @@ +/datab/Dataset/Noise/musan \ No newline at end of file diff --git a/src/data/components/audio_augmentor/pitch.py b/src/data/components/audio_augmentor/pitch.py new file mode 100644 index 0000000..628bb9b --- /dev/null +++ b/src/data/components/audio_augmentor/pitch.py @@ -0,0 +1,39 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import random +import librosa +import soundfile as sf +import numpy as np +from pydub import AudioSegment + +import logging +logger = logging.getLogger(__name__) +class PitchAugmentor(BaseAugmentor): + """ + Pitch augmentation + Config: + min_pitch_shift: int, min pitch shift factor + max_pitch_shift: int, max pitch shift factor + """ + def __init__(self, config: dict): + """ + This method initialize the `PitchAugmentor` object. + + :param config: dict, configuration dictionary + """ + + super().__init__(config) + self.min_pitch_shift = config["min_pitch_shift"] + self.max_pitch_shift = config["max_pitch_shift"] + self.pitch_shift = random.randint(self.min_pitch_shift, self.max_pitch_shift) + + + def transform(self): + """ + Transform the audio by pitch shifting based on `librosa.effects.pitch_shift` + The pitch shift factor is randomly selected between min_pitch_shift and max_pitch_shift + """ + data = librosa.effects.pitch_shift(self.data, sr=self.sr, n_steps=self.pitch_shift) + # transform to pydub audio segment + self.augmented_audio = librosa_to_pydub(data, sr=self.sr) + \ No newline at end of file diff --git a/src/data/components/audio_augmentor/reverb.py b/src/data/components/audio_augmentor/reverb.py new file mode 100644 index 0000000..5a7283c --- /dev/null +++ b/src/data/components/audio_augmentor/reverb.py @@ -0,0 +1,46 @@ +from .base import BaseAugmentor +from .utils import recursive_list_files, librosa_to_pydub +import scipy.signal as ss +import librosa +import random +import numpy as np + +import logging +logger = logging.getLogger(__name__) +class ReverbAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Reverb augmentation + Config: + rir_path: str, path to the folder containing RIR files + (RIR dataset example https://www.openslr.org/28/) + """ + super().__init__(config) + self.rir_path = config["rir_path"] + self.rir_file = self.select_rir(self.rir_path) + + def select_rir(self, rir_path): + """ + Randomly select the RIR file from the `rir_path` + + :param rir_path: path to the folder containing RIR files + + :return: path to the selected RIR file + """ + rir_list = recursive_list_files(rir_path) + return random.choice(rir_list) + + def transform(self): + """ + Reverb the audio by convolving with the RIR file selected from `rir_path` + """ + rir_data, _ = librosa.load(self.rir_file, sr=self.sr) + # Compute convolution + reverberate = np.convolve(self.data, rir_data) + # Normalize output signal to avoid clipping + # reverberate /= (np.max(np.abs(reverberate))) + + # transform to pydub 
audio segment + self.augmented_audio = librosa_to_pydub(reverberate, sr=self.sr) + + \ No newline at end of file diff --git a/src/data/components/audio_augmentor/speed.py b/src/data/components/audio_augmentor/speed.py new file mode 100644 index 0000000..57244b9 --- /dev/null +++ b/src/data/components/audio_augmentor/speed.py @@ -0,0 +1,33 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +from pydub import AudioSegment +import random + +import logging +logger = logging.getLogger(__name__) +class SpeedAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Speed augmentor class requires these config: + min_speed_factor: float, min speed factor + max_speed_factor: float, max speed factor + """ + super().__init__(config) + self.speed_factor = random.uniform(config["min_speed_factor"], config["max_speed_factor"]) + + self.audio_data = None + + def load(self, input_path: str): + """ + :param input_path: path to the input audio file + """ + # load with librosa + super().load(input_path) + # transform to pydub audio segment + self.audio_data = librosa_to_pydub(self.data, sr=self.sr) + + def transform(self): + """ + Speed up or down the audio using pydub `AudioSegment.speedup` method + """ + self.augmented_audio = self.audio_data.speedup(self.speed_factor) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/swapping.py b/src/data/components/audio_augmentor/swapping.py new file mode 100644 index 0000000..142e572 --- /dev/null +++ b/src/data/components/audio_augmentor/swapping.py @@ -0,0 +1,67 @@ +import torch +import torch.nn.functional as F +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import numpy as np +import random +import librosa +import soundfile as sf +import os + +import logging +logger = logging.getLogger(__name__) + +class SwappingAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Swapping augmentation + Config: + T: int, maximum time swap parameter + F: int, maximum frequency swap parameter + """ + super().__init__(config) + self.T = config.get('T', 40) + self.F = config.get('F', 7) + + def apply_time_swapping(self, mel_spectrogram): + """ + Apply time swapping to the mel spectrogram. + """ + num_time_steps = mel_spectrogram.shape[1] + + t = random.randint(0, self.T) + t0 = random.randint(0, num_time_steps - 2 * t) + t1 = random.randint(t0 + t, num_time_steps - t) + + mel_spectrogram[:, t0:t0 + t], mel_spectrogram[:, t1:t1 + t] = \ + mel_spectrogram[:, t1:t1 + t].copy(), mel_spectrogram[:, t0:t0 + t].copy() + + return mel_spectrogram + + def apply_frequency_swapping(self, mel_spectrogram): + """ + Apply frequency swapping to the mel spectrogram. + """ + num_mel_channels = mel_spectrogram.shape[0] + + f = random.randint(0, self.F) + f0 = random.randint(0, num_mel_channels - 2 * f) + f1 = random.randint(f0 + f, num_mel_channels - f) + + mel_spectrogram[f0:f0 + f, :], mel_spectrogram[f1:f1 + f, :] = \ + mel_spectrogram[f1:f1 + f, :].copy(), mel_spectrogram[f0:f0 + f, :].copy() + + return mel_spectrogram + + def transform(self): + """ + Transform the audio by applying time and frequency swapping. 
+ """ + data = self.data.copy() + mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr) + + mel_spectrogram = self.apply_time_swapping(mel_spectrogram) + swapped_mel_spectrogram = self.apply_frequency_swapping(mel_spectrogram) + + augmented_data = librosa.feature.inverse.mel_to_audio(swapped_mel_spectrogram, sr=self.sr) + self.augmented_audio = librosa_to_pydub(augmented_data, sr=self.sr) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/telephone.py b/src/data/components/audio_augmentor/telephone.py new file mode 100644 index 0000000..5341b1e --- /dev/null +++ b/src/data/components/audio_augmentor/telephone.py @@ -0,0 +1,73 @@ +from .base import BaseAugmentor +from .utils import recursive_list_files, librosa_to_pydub +import scipy.signal as ss +import librosa +import random +import numpy as np +import torchaudio.functional as F +from torchaudio.io import AudioEffector +import logging +import torch +from torchaudio.sox_effects import apply_effects_tensor as sox_fx + +logger = logging.getLogger(__name__) + +def apply_codec(waveform, sample_rate, format, encoder=None): + encoder = AudioEffector(format=format, encoder=encoder) + return encoder.apply(waveform, sample_rate) + +def apply_effect(waveform, sample_rate, effect): + effector = AudioEffector(effect=effect) + return effector.apply(waveform, sample_rate) +class TelephoneEncodingAugmentor(BaseAugmentor): + """ + About + ----- + + This class makes audio sound like it's over telephone, by apply one or two things: + * codec: choose between ALAW, ULAW, g722 + * bandpass: Optional, can be None. This limits the frequency within common telephone range (300, 3400) be default. + Note: lowpass value should be higher than that of highpass. + + Example + ------- + + CONFIG = { + "aug_type": "telephone", + "output_path": os.path.join(BASE_DIR,"data/augmented"), + "out_format": "wav", + "encoding": "ALAW", + "bandpass": { + "lowpass": "3400", + "highpass": "400" + } + } + """ + def __init__(self, config: dict): + super().__init__(config) + self.encoding = config.get("encoding", "g722") + self.bandpass = config.get("bandpass", None) + if self.bandpass: + self.effects = ",".join( + [ + "lowpass=frequency={}:poles=1".format(self.bandpass.get("lowpass", 3400)), + "highpass=frequency={}:poles=1".format(self.bandpass.get("highpass", 300)), + "compand=attacks=0.02:decays=0.05:points=-60/-60|-30/-10|-20/-8|-5/-8|-2/-8:gain=-8:volume=-7:delay=0.05", + ] + ) + + def transform(self): + """ + """ + torch_audio = torch.tensor(self.data).reshape(1, -1) + + if self.effects: + # aug_audio, _ = sox_fx(torch.tensor(self.data).reshape(1, -1), self.sr, self.effects) + aug_audio = apply_effect(torch_audio.T, self.sr, self.effects) + else: + aug_audio = torch_audio + codec_applied = apply_codec(aug_audio, self.sr, self.encoding).T + # convert to numpy array + aug_audio = codec_applied.numpy().flatten() + + self.augmented_audio = librosa_to_pydub(aug_audio) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/time_masking.py b/src/data/components/audio_augmentor/time_masking.py new file mode 100644 index 0000000..d740cc1 --- /dev/null +++ b/src/data/components/audio_augmentor/time_masking.py @@ -0,0 +1,51 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import numpy as np +import random +import librosa +import soundfile as sf +import os + +import logging +logger = logging.getLogger(__name__) + +class TimeMaskingAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Time 
+        Config:
+        T: int, maximum time mask parameter (number of frames to mask)
+        p: float, upper bound parameter for time mask as a fraction of total time steps
+        num_masks: int, number of time masks to apply
+        """
+        super().__init__(config)
+        self.T = config.get('T', 100)
+        self.p = config.get('p', 1.0)
+        self.num_masks = config.get('num_masks', 2)
+
+    def apply_time_masking(self, mel_spectrogram):
+        """
+        Apply time masking to the mel spectrogram.
+        """
+        num_time_steps = mel_spectrogram.shape[1]
+        max_mask_length = int(self.p * num_time_steps)
+
+        for _ in range(self.num_masks):
+            if num_time_steps > 1:
+                t = random.randint(1, min(self.T, max_mask_length))
+                t0 = random.randint(0, num_time_steps - t)
+                mel_spectrogram[:, t0:t0 + t] = 0
+
+        return mel_spectrogram
+
+    def transform(self):
+        """
+        Transform the audio by applying time masking.
+        """
+        data = self.data.copy()
+        mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr)
+
+        # Apply time masking
+        masked_mel_spectrogram = self.apply_time_masking(mel_spectrogram)
+        masked_audio = librosa.feature.inverse.mel_to_audio(masked_mel_spectrogram, sr=self.sr)
+        self.augmented_audio = librosa_to_pydub(masked_audio, sr=self.sr)
diff --git a/src/data/components/audio_augmentor/time_swap.py b/src/data/components/audio_augmentor/time_swap.py
new file mode 100644
index 0000000..0604dba
--- /dev/null
+++ b/src/data/components/audio_augmentor/time_swap.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn.functional as F
+from .base import BaseAugmentor
+from .utils import librosa_to_pydub
+import numpy as np
+import random
+import librosa
+import soundfile as sf
+import os
+
+import logging
+logger = logging.getLogger(__name__)
+
+class TimeSwapAugmentor(BaseAugmentor):
+    def __init__(self, config: dict):
+        """
+        Time Swap augmentation
+        Config:
+        T: int, maximum time swap parameter
+        num_swaps: int, number of time swaps to apply
+        """
+        super().__init__(config)
+        self.T = config.get('T', 40)
+        self.num_swaps = config.get('num_swaps', 1)
+
+    def apply_time_swapping(self, mel_spectrogram):
+        """
+        Apply time swapping to the mel spectrogram.
+        """
+        num_time_steps = mel_spectrogram.shape[1]
+
+        for _ in range(self.num_swaps):
+            t = random.randint(1, self.T)
+            t0 = random.randint(0, num_time_steps - 2 * t)
+            t1 = random.randint(t0 + t, num_time_steps - t)
+
+            # Swap the segments
+            segment1 = mel_spectrogram[:, t0:t0 + t].copy()
+            segment2 = mel_spectrogram[:, t1:t1 + t].copy()
+            mel_spectrogram[:, t0:t0 + t] = segment2
+            mel_spectrogram[:, t1:t1 + t] = segment1
+
+        return mel_spectrogram
+
+    def transform(self):
+        """
+        Transform the audio by applying time swapping.
+        """
+        data = self.data.copy()
+        mel_spectrogram = librosa.feature.melspectrogram(y=data, sr=self.sr)
+
+        swapped_mel_spectrogram = self.apply_time_swapping(mel_spectrogram)
+        augmented_data = librosa.feature.inverse.mel_to_audio(swapped_mel_spectrogram, sr=self.sr)
+        self.augmented_audio = librosa_to_pydub(augmented_data, sr=self.sr)
\ No newline at end of file
diff --git a/src/data/components/audio_augmentor/utils.py b/src/data/components/audio_augmentor/utils.py
new file mode 100644
index 0000000..8e2d516
--- /dev/null
+++ b/src/data/components/audio_augmentor/utils.py
@@ -0,0 +1,80 @@
+import os
+import numpy as np
+import librosa
+from pydub import AudioSegment
+import subprocess
+
+import logging
+logger = logging.getLogger(__name__)
+
+def recursive_list_files(path: str, file_type: list = ["wav", "mp3", "flac"]) -> list:
+    """Recursively lists all files in a directory and its subdirectories"""
+    files = []
+    for dirpath, dirnames, filenames in os.walk(path):
+        for filename in filenames:
+            real_file_type = filename.split(".")[-1]
+            if (real_file_type in file_type):
+                files.append(os.path.join(dirpath, filename))
+    return files
+
+def pydub_to_librosa(audio_segment: AudioSegment) -> np.ndarray:
+    """Convert pydub audio segment to librosa audio data"""
+    return np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
+
+def librosa_to_pydub(audio_data: np.ndarray, sr: int =16000) -> AudioSegment:
+    """Convert librosa audio data to pydub audio segment"""
+    audio_data = np.array(audio_data * (1<<15), dtype=np.int16)
+    return AudioSegment(audio_data.tobytes(),
+                        frame_rate=sr,
+                        sample_width=audio_data.dtype.itemsize,
+                        channels=1)
+
+def run_cmd(cmd: str) -> str:
+    try:
+        output = subprocess.call(cmd, shell=True, stderr=subprocess.STDOUT, stdout=subprocess.DEVNULL)
+    except subprocess.CalledProcessError as e:
+        output = e.output
+    # return output.decode()
+
+def down_load_model(model_name: str, save_path: str):
+    """
+    This function helps to download the pretrained models of well-known ASV spoofing models
+
+    :param model_name: the name of the model.
Support ["rawnet2", "aasistssl", "xlsr2_300m", "lcnn", "btse"] + :param save_path: the path to save the pretrained model + """ + if model_name == "rawnet2": + if os.path.exists(os.path.join(save_path,"pre_trained_DF_RawNet2.pth")): + return + logger.info("Downloading pretrained model for Rawnet2") + run_cmd("wget https://www.asvspoof.org/asvspoof2021/pre_trained_DF_RawNet2.zip") + run_cmd("unzip pre_trained_DF_RawNet2.zip") + run_cmd("rm pre_trained_DF_RawNet2.zip") + run_cmd("mv pre_trained_DF_RawNet2.pth {}".format(os.path.join(save_path,"pre_trained_DF_RawNet2.pth"))) + if model_name == "aasistssl": + if os.path.exists(os.path.join(save_path,"LA_model.pth")): + return + logger.info("Downloading pretrained model for AASIST-SSL") + run_cmd("gdown 11vFBNKzYUtWqdK358_JEygawFzmmaZjO") + run_cmd("mv LA_model.pth {}".format(os.path.join(save_path,"LA_model.pth"))) + + if model_name == "xlsr2_300m": + if os.path.exists(os.path.join(save_path,"xlsr2_300m.pth")): + return + logger.info("Downloading pretrained model for Wav2Vec2.0 (XLSR-2.0-300M)") + run_cmd("wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr2_300m.pt") + run_cmd("mv xlsr2_300m.pt {}".format(os.path.join(save_path,"xlsr2_300m.pth"))) + + if model_name == "lcnn": + if os.path.exists(os.path.join(save_path,"lcnn_full_230209.pth")): + return + logger.info("Downloading pretrained model for LCNN") + run_cmd("gdown 1T2xWKuaFbtJjXdwUVPL7_hDPoKBE-MC8") + run_cmd("mv lcnn_full_230209.pth {}".format(os.path.join(save_path,"lcnn_full_230209.pth"))) + + if model_name == "btse": + if os.path.exists(os.path.join(save_path,"tts_vc_trans_64_concat.pth")): + return + logger.info("Downloading pretrained model for BTSE") + run_cmd("gdown 174L262CCMnvp4YQKG2t07Mrf_mDIf0ul") + run_cmd("mv tts_vc_trans_64_concat.pth {}".format(os.path.join(save_path,"tts_vc_trans_64_concat.pth"))) \ No newline at end of file diff --git a/src/data/components/audio_augmentor/volume.py b/src/data/components/audio_augmentor/volume.py new file mode 100644 index 0000000..b644a7e --- /dev/null +++ b/src/data/components/audio_augmentor/volume.py @@ -0,0 +1,32 @@ +from .base import BaseAugmentor +from .utils import librosa_to_pydub +import random + +import logging +logger = logging.getLogger(__name__) +class VolumeAugmentor(BaseAugmentor): + def __init__(self, config: dict): + """ + Volume augmentor class requires these config: + min_volume_dBFS: float, min volume dBFS + max_volume_dBFS: float, max volume dBFS + """ + super().__init__(config) + self.volume_dBFS = random.uniform(config["min_volume_dBFS"], config["max_volume_dBFS"]) + + self.audio_data = None + + def load(self, input_path: str): + """ + :param input_path: path to the input audio file + """ + # load with librosa + super().load(input_path) + # transform to pydub audio segment + self.audio_data = librosa_to_pydub(self.data, sr=self.sr) + + def transform(self): + """ + Volume up or down the audio using pydub `AudioSegment` method + """ + self.augmented_audio = self.audio_data + self.volume_dBFS \ No newline at end of file diff --git a/src/data/components/augwrapper.py b/src/data/components/augwrapper.py index 0110014..42916ef 100644 --- a/src/data/components/augwrapper.py +++ b/src/data/components/augwrapper.py @@ -4,9 +4,11 @@ import torchaudio from torch import Tensor import librosa -from datautils.RawBoost import ISD_additive_noise, LnL_convolutive_noise, SSI_additive_noise, normWav -from datautils.audio_augmentor import BackgroundNoiseAugmentor, PitchAugmentor, ReverbAugmentor, SpeedAugmentor, 
VolumeAugmentor, TelephoneEncodingAugmentor, GaussianAugmentor, CopyPasteAugmentor, BaseAugmentor, TimeMaskingAugmentor, FrequencyMaskingAugmentor, MaskingAugmentor, TimeSwapAugmentor, FrequencySwapAugmentor, SwappingAugmentor, LinearFilterAugmentor, BandpassAugmentor -from datautils.audio_augmentor.utils import pydub_to_librosa, librosa_to_pydub + +from src.data.components.RawBoost import ISD_additive_noise, LnL_convolutive_noise, SSI_additive_noise, normWav +from src.data.components.audio_augmentor import BackgroundNoiseAugmentor, PitchAugmentor, ReverbAugmentor, SpeedAugmentor, VolumeAugmentor, TelephoneEncodingAugmentor, GaussianAugmentor, CopyPasteAugmentor, BaseAugmentor, TimeMaskingAugmentor, FrequencyMaskingAugmentor, MaskingAugmentor, TimeSwapAugmentor, FrequencySwapAugmentor, SwappingAugmentor, LinearFilterAugmentor, BandpassAugmentor +from src.data.components.audio_augmentor.utils import pydub_to_librosa, librosa_to_pydub + import soundfile as sf import random diff --git a/src/data/components/baseloader.py b/src/data/components/baseloader.py index de96c38..d8809fd 100644 --- a/src/data/components/baseloader.py +++ b/src/data/components/baseloader.py @@ -6,7 +6,8 @@ class Dataset_base(Dataset): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, + random_start=False): """ Args: list_IDs (string): Path to the .lst file with real audio filenames. @@ -39,6 +40,7 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], self.args.aug_dir = aug_dir self.args.online_aug = online_aug self.repeat_pad = repeat_pad + self.random_start = random_start self.is_train = is_train def __len__(self): diff --git a/src/data/cyberaicup_datamodule.py b/src/data/cyberaicup_datamodule.py index a1b817f..24f8c32 100644 --- a/src/data/cyberaicup_datamodule.py +++ b/src/data/cyberaicup_datamodule.py @@ -10,7 +10,8 @@ from scipy import signal import copy import random -from src.data.components.RawBoost import process_Rawboost_feature +from src.core_scripts.data_io import wav_augmentation as nii_wav_aug +from src.core_scripts.data_io import wav_tools as nii_wav_tools from src.data.components.dataio import load_audio, pad from src.data.components.baseloader import Dataset_base @@ -25,13 +26,12 @@ class Dataset_for(Dataset_base): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False): super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, trim_length, wav_samp_rate, noise_path, rir_path, - aug_dir, online_aug, repeat_pad, is_train) - print("train repeat_pad:", repeat_pad) - #self.repeat_pad = False # temporary fix for the padding issue + aug_dir, online_aug, repeat_pad, is_train, random_start) + def __getitem__(self, idx): utt_id = self.list_IDs[idx] filepath = os.path.join(self.base_dir, utt_id) @@ -48,20 +48,21 @@ def __getitem__(self, idx): 
X_pad= pad(X,padding_type="repeat" if self.repeat_pad else "zero", max_len=self.trim_length, random_start=True) x_inp= Tensor(X_pad) target = self.labels[utt_id] - return idx, x_inp, target + return idx, x_inp, target + class Dataset_for_dev(Dataset_base): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=False, is_train=True): + aug_dir=None, online_aug=False, repeat_pad=False, is_train=True, + random_start=False + ): super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, trim_length, wav_samp_rate, noise_path, rir_path, - aug_dir, online_aug, repeat_pad, is_train) - #repeat_pad = False - print("dev repeat_pad:", repeat_pad) - + aug_dir, online_aug, repeat_pad, is_train, random_start) + if repeat_pad: self.padding_type = "repeat" else: @@ -70,7 +71,7 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], def __getitem__(self, index): utt_id = self.list_IDs[index] X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) - X_pad = pad(X,self.padding_type,self.trim_length, random_start=True) + X_pad = pad(X,self.padding_type,self.trim_length, random_start=self.random_start) x_inp = Tensor(X_pad) target = self.labels[utt_id] return index, x_inp, target @@ -79,12 +80,12 @@ class Dataset_for_eval(Dataset_base): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False ): super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, trim_length, wav_samp_rate, noise_path, rir_path, - aug_dir, online_aug, repeat_pad, is_train) + aug_dir, online_aug, repeat_pad, is_train, random_start) self.enable_chunking = enable_chunking if repeat_pad: self.padding_type = "repeat" @@ -100,11 +101,11 @@ def __getitem__(self, idx): # print("eval_augment:", self.eval_augment) X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath) if not self.enable_chunking: - X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=True) + X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=self.random_start) x_inp = Tensor(X) return x_inp, utt_id -class CyberAiCupDataModule(LightningDataModule): +class CyberpcupNormalDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. The ASVspoof 2019 database for logical access is based upon a standard multi-speaker speech synthesis database called VCTK2. 
@@ -155,10 +156,10 @@ def __init__( num_workers: int = 0, pin_memory: bool = False, args: Optional[Dict[str, Any]] = None, - list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], - augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, - trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, - aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False + # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -179,6 +180,7 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir + self.args = args @property @@ -218,35 +220,36 @@ def setup(self, stage: Optional[str] = None) -> None: d_label_trn, file_train = self.genList(dir_meta=os.path.join( self.data_dir), is_train=True, is_eval=False, is_dev=False) - # if 'portion' in config['data']: - # idx = range(len(file_train)) - # idx = np.random.choice( - # idx, int(len(file_train)*config['data']['portion']), replace=False) - # file_train = [file_train[i] for i in idx] - # if len(d_label_trn) > 0: # the case of train without label - # d_label_trn = {k: d_label_trn[k] for k in file_train} - + if 'portion' in self.args: + idx = range(len(file_train)) + idx = np.random.choice( + idx, int(len(file_train)*self.args['portion']), replace=False) + file_train = [file_train[i] for i in idx] + if len(d_label_trn) > 0: # the case of train without label + d_label_trn = {k: d_label_trn[k] for k in file_train} + print('no. of training trials', len(file_train)) + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) - # if 'portion' in config['data']: - # idx = range(len(file_dev)) - # idx = np.random.choice( - # idx, int(len(file_dev)*config['data']['portion']), replace=False) - # file_dev = [file_dev[i] for i in idx] - # if len(d_label_dev) > 0: # the case of train without label - # d_label_dev = {k: d_label_dev[k] for k in file_dev} + if 'portion' in self.args: + idx = range(len(file_dev)) + idx = np.random.choice( + idx, int(len(file_dev)*self.args['portion']), replace=False) + file_dev = [file_dev[i] for i in idx] + if len(d_label_dev) > 0: # the case of train without label + d_label_dev = {k: d_label_dev[k] for k in file_dev} + + print('no. 
of validation trials', len(file_dev)) + file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False) - self.data_train = Dataset_for(args, list_IDs=file_train, labels=d_label_trn, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs']) + self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn, + base_dir=self.data_dir+'/', **self.args['data']) - self.data_val = Dataset_for_dev(args, list_IDs=file_dev, labels=d_label_dev, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, is_train=False, **config['data']['kwargs']) + self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev, + base_dir=self.data_dir+'/', is_train=False, **self.args['data']) - self.data_test = Dataset_for_eval(args, list_IDs=file_eval, labels=None, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs'], - enable_chunking=config.get( - 'eval_chunking', False) - ) + self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None, + base_dir=self.data_dir+'/', **self.args['data']) def train_dataloader(self) -> DataLoader[Any]: @@ -260,6 +263,7 @@ def train_dataloader(self) -> DataLoader[Any]: num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True, + drop_last=True ) def val_dataloader(self) -> DataLoader[Any]: @@ -344,7 +348,7 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): l_meta = f.readlines() for line in l_meta: utt, subset, label = line.strip().split() - if subset == 'dev': + if subset == 'eval': file_list.append(utt) file_list.append(utt) d_meta[utt] = 1 if label == 'bonafide' else 0 @@ -352,4 +356,4 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): return d_meta, file_list if __name__ == "__main__": - _ = CyberAiCupDataModule() + _ = CyberpcupNormalDataModule() \ No newline at end of file diff --git a/src/data/normal_datamodule.py b/src/data/normal_datamodule.py new file mode 100644 index 0000000..dec2c5b --- /dev/null +++ b/src/data/normal_datamodule.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import Any, Dict, Optional, Tuple +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset +from torch import Tensor +import librosa +import numpy as np +import os +from scipy import signal +import copy +import random +from src.core_scripts.data_io import wav_augmentation as nii_wav_aug +from src.core_scripts.data_io import wav_tools as nii_wav_tools +from src.data.components.dataio import load_audio, pad +from src.data.components.baseloader import Dataset_base + +# augwrapper +from src.data.components.augwrapper import SUPPORTED_AUGMENTATION + +# dynamic import of augmentation methods +for aug in SUPPORTED_AUGMENTATION: + exec(f"from src.data.components.augwrapper import {aug}") + +class Dataset_for(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False): + super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, 
rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X = load_audio(filepath, self.sample_rate) + + # apply augmentation + # randomly choose an augmentation method + if self.is_train: + augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(self.augmentation_methods) > 0 else -1 + if augmethod_index >= 0: + X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate, + audio_path = filepath) + + X_pad= pad(X,padding_type="repeat" if self.repeat_pad else "zero", max_len=self.trim_length, random_start=True) + x_inp= Tensor(X_pad) + target = self.labels[utt_id] + return idx, x_inp, target + + +class Dataset_for_dev(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=False, is_train=True, + random_start=False + ): + super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, index): + utt_id = self.list_IDs[index] + X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) + X_pad = pad(X,self.padding_type,self.trim_length, random_start=self.random_start) + x_inp = Tensor(X_pad) + target = self.labels[utt_id] + return index, x_inp, target + +class Dataset_for_eval(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False + ): + super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + self.enable_chunking = enable_chunking + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X, _ = librosa.load(filepath, sr=16000) + # apply augmentation at inference time + if self.eval_augment is not None: + # print("eval_augment:", self.eval_augment) + X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath) + if not self.enable_chunking: + X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=self.random_start) + x_inp = Tensor(X) + return x_inp, utt_id + +class NormalDataModule(LightningDataModule): + """`LightningDataModule` for the ASVSpoof dataset. + + The ASVspoof 2019 database for logical access is based upon a standard multi-speaker speech synthesis database called VCTK2. + Genuine speech is collected from 107 speakers (46 male, 61 female) and with no significant channel or background noise effects. 
+ Spoofed speech is generated from the genuine data using a number of different spoofing algorithms. + The full dataset is partitioned into three subsets, the first for training, the second for development and the third for evaluation. + There is no speaker overlap across the three subsets regarding target speakers used in voice conversion or TTS adaptation. + + A `LightningDataModule` implements 7 key methods: + + ```python + def prepare_data(self): + # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP). + # Download data, pre-process, split, save to disk, etc... + + def setup(self, stage): + # Things to do on every process in DDP. + # Load data, set variables, etc... + + def train_dataloader(self): + # return train dataloader + + def val_dataloader(self): + # return validation dataloader + + def test_dataloader(self): + # return test dataloader + + def predict_dataloader(self): + # return predict dataloader + + def teardown(self, stage): + # Called on every process in DDP. + # Clean up after fit or test. + ``` + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. + + Read the docs: + https://lightning.ai/docs/pytorch/latest/data/datamodule.html + """ + + def __init__( + self, + data_dir: str = "data/", + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + args: Optional[Dict[str, Any]] = None, + # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False + ) -> None: + """Initialize a `ASVSpoofDataModule`. + + :param data_dir: The data directory. Defaults to `"data/"`. + :param batch_size: The batch size. Defaults to `64`. + :param num_workers: The number of workers. Defaults to `0`. + :param pin_memory: Whether to pin memory. Defaults to `False`. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + + self.batch_size_per_device = batch_size + self.data_dir = data_dir + + self.args = args + + @property + def num_classes(self) -> int: + """Get the number of classes. + + :return: The number of ASVSpoof classes (2). + """ + return 2 + + def prepare_data(self) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and + `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after + `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to + `self.setup()` once the data is prepared and available for use. + + :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. + """ + # Divide batch size by the number of devices. 
+ if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." + ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + + # load and split datasets only if not loaded already + if not self.data_train and not self.data_val and not self.data_test: + + # define train dataloader + + d_label_trn, file_train = self.genList(dir_meta=os.path.join( + self.data_dir), is_train=True, is_eval=False, is_dev=False) + + if 'portion' in self.args: + idx = range(len(file_train)) + idx = np.random.choice( + idx, int(len(file_train)*self.args['portion']), replace=False) + file_train = [file_train[i] for i in idx] + if len(d_label_trn) > 0: # the case of train without label + d_label_trn = {k: d_label_trn[k] for k in file_train} + print('no. of training trials', len(file_train)) + + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) + + if 'portion' in self.args: + idx = range(len(file_dev)) + idx = np.random.choice( + idx, int(len(file_dev)*self.args['portion']), replace=False) + file_dev = [file_dev[i] for i in idx] + if len(d_label_dev) > 0: # the case of train without label + d_label_dev = {k: d_label_dev[k] for k in file_dev} + + print('no. of validation trials', len(file_dev)) + file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False) + + self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn, + base_dir=self.data_dir+'/', **self.args['data']) + + self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev, + base_dir=self.data_dir+'/', is_train=False, **self.args['data']) + + self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None, + base_dir=self.data_dir+'/', **self.args['data']) + + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + return DataLoader( + dataset=self.data_train, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + drop_last=True + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + return DataLoader( + dataset=self.data_val, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. + """ + return DataLoader( + dataset=self.data_test, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def teardown(self, stage: Optional[str] = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, + `trainer.test()`, and `trainer.predict()`. + + :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Defaults to ``None``. + """ + pass + + def state_dict(self) -> Dict[Any, Any]: + """Called when saving a checkpoint. Implement to generate and save the datamodule state. + + :return: A dictionary containing the datamodule state that you want to save. 
+ """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint. Implement to reload datamodule state given datamodule + `state_dict()`. + + :param state_dict: The datamodule state returned by `self.state_dict()`. + """ + pass + + def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): + # bonafide: 1, spoof: 0 + d_meta = {} + file_list=[] + protocol = os.path.join(dir_meta, "protocol.txt") + + if (is_train): + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'train': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + + return d_meta, file_list + if (is_dev): + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'dev': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + return d_meta, file_list + + if (is_eval): + # no eval protocol yet + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'eval': + file_list.append(utt) + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + # return d_meta, file_list + return d_meta, file_list + +if __name__ == "__main__": + _ = NormalDataModule() \ No newline at end of file diff --git a/src/data/scl_datamodule.py b/src/data/scl_datamodule.py index 76e1137..8351d2a 100644 --- a/src/data/scl_datamodule.py +++ b/src/data/scl_datamodule.py @@ -15,18 +15,23 @@ from src.data.components.dataio import load_audio, pad from src.data.components.baseloader import Dataset_base +# augwrapper +from src.data.components.augwrapper import SUPPORTED_AUGMENTATION + +# dynamic import of augmentation methods +for aug in SUPPORTED_AUGMENTATION: + exec(f"from src.data.components.augwrapper import {aug}") + class Dataset_for(Dataset_base): def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, - aug_dir=None, online_aug=False, repeat_pad=True, is_train=True): + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False): super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, trim_length, wav_samp_rate, noise_path, rir_path, - aug_dir, online_aug, repeat_pad, is_train) - - self.vocoded_dir = os.path.join(base_dir, 'vocoded') + aug_dir, online_aug, repeat_pad, is_train, random_start) # read spoof_train and spoof_dev list from scp if self.is_train: self.spoof_list = [] @@ -93,44 +98,68 @@ def __getitem__(self, idx): # label is 1 for anchor and positive, 0 for vocoded label = [1] * (len(augmented_audios) +len(additional_audios) + len(augmented_additional_audios) + 1) + [0] * (len(additional_spoofs) + len(augmented_additional_spoofs)) # print("label", label) + #print("label", len(label)) return self.list_IDs[idx], batch_data, Tensor(label) -class Args: - def __init__(self): - self.database_path = '/your/path/to/data/ASVspoof_database/DF/' - self.protocols_path = '/data/hungdx/Datasets/protocols/database/' - self.batch_size = 14 - self.num_epochs = 100 - self.lr = 0.000001 - self.weight_decay = 0.0001 - self.loss = 'weighted_CCE' - self.seed = 1234 - self.model_path = None - 
self.comment = None - self.track = 'DF' - self.eval_output = None - self.eval = False - self.is_eval = False - self.eval_part = 0 - self.cudnn_deterministic_toggle = True - self.cudnn_benchmark_toggle = False - self.algo = 3 - self.nBands = 5 - self.minF = 20 - self.maxF = 8000 - self.minBW = 100 - self.maxBW = 1000 - self.minCoeff = 10 - self.maxCoeff = 100 - self.minG = 0 - self.maxG = 0 - self.minBiasLinNonLin = 5 - self.maxBiasLinNonLin = 20 - self.N_f = 5 - self.P = 10 - self.g_sd = 2 - self.SNRmin = 10 - self.SNRmax = 40 +# def collate_fn(batch): +# return { +# 'pixel_values': torch.stack([x['pixel_values'] for x in batch]), +# 'labels': torch.tensor([x['labels'] for x in batch]) +# } + +class Dataset_for_dev(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=False, is_train=True, + random_start=False + ): + super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, index): + utt_id = self.list_IDs[index] + X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) + X_pad = pad(X,self.padding_type,self.trim_length, random_start=self.random_start) + x_inp = Tensor(X_pad) + target = self.labels[utt_id] + return index, x_inp, target + +class Dataset_for_eval(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False + ): + super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + self.enable_chunking = enable_chunking + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X, _ = librosa.load(filepath, sr=16000) + # apply augmentation at inference time + if self.eval_augment is not None: + # print("eval_augment:", self.eval_augment) + X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath) + if not self.enable_chunking: + X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=self.random_start) + x_inp = Tensor(X) + return x_inp, utt_id class SclNormalDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. 
@@ -183,10 +212,10 @@ def __init__( num_workers: int = 0, pin_memory: bool = False, args: Optional[Dict[str, Any]] = None, - list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], - augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, - trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, - aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False + # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -208,6 +237,8 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir + self.args = args + @property def num_classes(self) -> int: """Get the number of classes. @@ -245,34 +276,36 @@ def setup(self, stage: Optional[str] = None) -> None: d_label_trn, file_train = self.genList(dir_meta=os.path.join( self.data_dir), is_train=True, is_eval=False, is_dev=False) - # if 'portion' in config['data']: - # idx = range(len(file_train)) - # idx = np.random.choice( - # idx, int(len(file_train)*config['data']['portion']), replace=False) - # file_train = [file_train[i] for i in idx] - # if len(d_label_trn) > 0: # the case of train without label - # d_label_trn = {k: d_label_trn[k] for k in file_train} - + if 'portion' in self.args: + idx = range(len(file_train)) + idx = np.random.choice( + idx, int(len(file_train)*self.args['portion']), replace=False) + file_train = [file_train[i] for i in idx] + if len(d_label_trn) > 0: # the case of train without label + d_label_trn = {k: d_label_trn[k] for k in file_train} + print('no. of training trials', len(file_train)) + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) - # if 'portion' in config['data']: - # idx = range(len(file_dev)) - # idx = np.random.choice( - # idx, int(len(file_dev)*config['data']['portion']), replace=False) - # file_dev = [file_dev[i] for i in idx] - # if len(d_label_dev) > 0: # the case of train without label - # d_label_dev = {k: d_label_dev[k] for k in file_dev} + if 'portion' in self.args: + idx = range(len(file_dev)) + idx = np.random.choice( + idx, int(len(file_dev)*self.args['portion']), replace=False) + file_dev = [file_dev[i] for i in idx] + if len(d_label_dev) > 0: # the case of train without label + d_label_dev = {k: d_label_dev[k] for k in file_dev} + + print('no. 
of validation trials', len(file_dev)) + file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False) - self.data_train = Dataset_for(args, list_IDs=file_train, labels=d_label_trn, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs']) + self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn, + base_dir=self.data_dir+'/', **self.args['data']) - self.data_val = Dataset_for_dev(args, list_IDs=file_dev, labels=d_label_dev, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, is_train=False, **config['data']['kwargs']) + self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev, + base_dir=self.data_dir+'/', is_train=False, **self.args['data']) - self.data_test = Dataset_for_eval(args, list_IDs=file_eval, labels=None, - base_dir=args.database_path+'/', algo=args.algo, repeat_pad=is_repeat_pad, **config['data']['kwargs'], - enable_chunking=config.get( - 'eval_chunking', False)) + self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None, + base_dir=self.data_dir+'/', **self.args['data']) def train_dataloader(self) -> DataLoader[Any]: @@ -286,6 +319,7 @@ def train_dataloader(self) -> DataLoader[Any]: num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True, + drop_last=True ) def val_dataloader(self) -> DataLoader[Any]: @@ -338,36 +372,40 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ pass - def genSpoof_list(self, dir_meta, is_train=False, is_eval=False): - """ - This function is from the following source: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17 - Official source: https://arxiv.org/abs/2202.12233 - Automatic speaker verification spoofing and deepfake detection using wav2vec 2.0 and data augmentation - """ + def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): + # bonafide: 1, spoof: 0 d_meta = {} file_list=[] - with open(dir_meta, 'r') as f: + + if is_train: + metafile = os.path.join(dir_meta, 'scp/bonafide_train.lst') + elif is_dev: + metafile = os.path.join(dir_meta, 'scp/bonafide_dev.lst') + elif is_eval: + metafile = os.path.join(dir_meta, 'scp/eval.lst') + + with open(metafile, 'r') as f: l_meta = f.readlines() - + if (is_train): for line in l_meta: - _,key,_,_,label = line.strip().split() - - file_list.append(key) - d_meta[key] = 1 if label == 'bonafide' else 0 + utt = line.strip().split() + file_list.append(utt[0]) + d_meta[utt[0]] = 1 return d_meta,file_list - elif(is_eval): + if (is_dev): for line in l_meta: - key= line.strip() - file_list.append(key) - return file_list - else: + utt = line.strip().split() + file_list.append(utt[0]) + d_meta[utt[0]] = 1 + return d_meta,file_list + + elif(is_eval): for line in l_meta: - _,key,_,_,label = line.strip().split() - - file_list.append(key) - d_meta[key] = 1 if label == 'bonafide' else 0 + utt = line.strip().split() + file_list.append(utt[0]) + return d_meta,file_list if __name__ == "__main__": diff --git a/src/models/components/loss_metrics.py b/src/models/components/loss_metrics.py index 7e854fb..3fc959f 100644 --- a/src/models/components/loss_metrics.py +++ b/src/models/components/loss_metrics.py @@ -200,7 +200,7 @@ def supcon_loss(input_feat, # print("log_prob.shape", log_prob.shape) # compute mean of log-likelihood over positive # print(mask_ * log_prob) - mean_log_prob_pos = (mask_ * log_prob).sum(1) / mask_.sum(1) + mean_log_prob_pos = 
(mask_ * log_prob).sum(1) / (mask_.sum(1) + 1e-8) # avoid zero division # print("mean_log_prob_pos.shape", mean_log_prob_pos.shape) # print(mean_log_prob_pos) # loss diff --git a/src/models/components/wavlmbase_conformertcm.py b/src/models/components/wavlmbase_conformertcm.py index 7539bba..c6fe635 100644 --- a/src/models/components/wavlmbase_conformertcm.py +++ b/src/models/components/wavlmbase_conformertcm.py @@ -8,17 +8,9 @@ from torch import Tensor import os from torch.nn.modules.transformer import _get_clones -from .WavLM.fe import WavLMFe -try: - from model.loss_metrics import supcon_loss - from model.RawNet3.model import RawNet3 - from model.RawNet3.RawNetBasicBlock import Bottle2neck - from model.conformer_tcm.model import MyConformer -except: - from .loss_metrics import supcon_loss - from .RawNet3.model import RawNet3 - from .RawNet3.RawNetBasicBlock import Bottle2neck - from .conformer_tcm.model import MyConformer +from src.models.components.WavLM.fe import WavLMFe +from src.models.components.loss_metrics import supcon_loss +from src.models.components.conformer_tcm.model import MyConformer class Model(nn.Module): @@ -119,29 +111,4 @@ def loss(self, output, feats, emb, labels, config, info=None): return {'L_CE':L_CE} elif self.loss_type == 5: return {'L_CF1':L_CF1, 'L_CF2':L_CF2} - -if __name__ == '__main__': - import yaml - # LoRA - from peft import LoraConfig, TaskType, PeftModel, get_peft_model - - config = yaml.load(open("/data/hungdx/asvspoof5/configs/3_augall_wavlm_conformertcm_res2net_seblock_sclnormal.yaml", 'r'), Loader=yaml.FullLoader) - device = "cpu" - model = Model(config['model'], device) - # print("Hello") - # print(model) - - # Trying apply lora - # Using the best config from this paper - # https://arxiv.org/pdf/2306.05617 - target_modules = ["q_proj", "v_proj"] - r = 4 - - lora_config = LoraConfig( - r=r, - target_modules=target_modules, - #modules_to_save=["qkv"] - ) - peft_model = get_peft_model(model, lora_config) - print(peft_model) - peft_model.print_trainable_parameters() + \ No newline at end of file diff --git a/src/models/components/wavlmbase_vib.py b/src/models/components/wavlmbase_vib.py index b57bfa8..2bd9d18 100644 --- a/src/models/components/wavlmbase_vib.py +++ b/src/models/components/wavlmbase_vib.py @@ -1,29 +1,12 @@ -import random -from typing import Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor -import os -from torch.nn.modules.transformer import _get_clones -from .WavLM.fe import WavLMFe -try: - from model.loss_metrics import supcon_loss - from model.RawNet3.model import RawNet3 - from model.RawNet3.RawNetBasicBlock import Bottle2neck - from model.conformer_tcm.model import MyConformer - -except: - from .loss_metrics import supcon_loss - from .RawNet3.model import RawNet3 - from .RawNet3.RawNetBasicBlock import Bottle2neck - from .conformer_tcm.model import MyConformer +from src.models.components.WavLM.fe import WavLMFe +from src.models.components.loss_metrics import supcon_loss -___author__ = "Hemlata Tak" -__email__ = "tak@eurecom.fr" + class DropoutForMC(nn.Module): """Dropout layer for Bayesian model THe difference is that we do dropout even in eval stage @@ -148,15 +131,14 @@ def forward(self, x): class Model(nn.Module): - def __init__(self, args, device, is_train = True): + def __init__(self, args, is_train = True): super().__init__() - self.device = device self.is_train = is_train - self.flag_fix_ssl = args['flag_fix_ssl'] + self.contra_mode = args['contra_mode'] 
self.loss_type = args['loss_type'] - self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])]).to(device)) + self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])])) self.sim_metric_seq = lambda mat1, mat2: torch.bmm( mat1.permute(1, 0, 2), mat2.permute(1, 2, 0)).mean(0) # front-end kwargs @@ -253,29 +235,29 @@ def loss(self, output, feats, emb, labels, config, info=None): KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) # print("KLD: ", KLD) # Recon_loss = 0.000001*(BCE + 0.05*KLD) / real_bzs - Recon_loss = config['model']['recon_weight_l']*(BCE + config['model']['recon_weight_b']*KLD) / real_bzs + Recon_loss = config['recon_weight_l']*(BCE + config['recon_weight_b']*KLD) / real_bzs # reshape the feats_w2v to match the supcon loss format feats_w2v = feats_w2v.unsqueeze(1) # print("feats_w2v.shape", feats_w2v.shape) - L_CF1 = 1/real_bzs * supcon_loss(feats_w2v, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + L_CF1 = 1/real_bzs * supcon_loss(feats_w2v, labels=labels, contra_mode=config['contra_mode'], sim_metric=sim_metric_seq) # reshape the emb to match the supcon loss format emb = emb.unsqueeze(1) emb = emb.unsqueeze(-1) # print("emb.shape", emb.shape) - L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=config['model']['contra_mode'], sim_metric=sim_metric_seq) + L_CF2 = 1/real_bzs *supcon_loss(emb, labels=labels, contra_mode=config['contra_mode'], sim_metric=sim_metric_seq) - if config['model']['loss_type'] == 1: + if config['loss_type'] == 1: return {'L_CE':L_CE, 'L_CF1':L_CF1, 'L_CF2':L_CF2, 'Recon_loss':Recon_loss} - elif config['model']['loss_type'] == 2: + elif config['loss_type'] == 2: return {'L_CE':L_CE, 'L_CF1':L_CF1} - elif config['model']['loss_type'] == 3: + elif config['loss_type'] == 3: return {'L_CE':L_CE, 'L_CF2':L_CF2} # ablation study - elif config['model']['loss_type'] == 4: + elif config['loss_type'] == 4: return {'L_CE':L_CE} - elif config['model']['loss_type'] == 5: + elif config['loss_type'] == 5: return {'L_CF1':L_CF1, 'L_CF2':L_CF2} - elif config['model']['loss_type'] == 6: + elif config['loss_type'] == 6: return {'L_CE':L_CE, 'Recon_loss':Recon_loss} \ No newline at end of file diff --git a/src/models/components/xlsr_conformertcm.py b/src/models/components/xlsr_conformertcm.py index 6a5942d..db23bf4 100644 --- a/src/models/components/xlsr_conformertcm.py +++ b/src/models/components/xlsr_conformertcm.py @@ -8,21 +8,9 @@ from torch import Tensor from torch.nn.modules.transformer import _get_clones import fairseq -try: - from model.loss_metrics import supcon_loss - from model.RawNet3.model import RawNet3 - from model.RawNet3.RawNetBasicBlock import Bottle2neck - from model.conformer_tcm.model import MyConformer -except: - from .loss_metrics import supcon_loss - from .RawNet3.model import RawNet3 - from .RawNet3.RawNetBasicBlock import Bottle2neck - from .conformer_tcm.model import MyConformer - -############################ -## FOR fine-tuned SSL MODEL -############################ +from src.models.components.loss_metrics import supcon_loss +from src.models.components.conformer_tcm.model import MyConformer class SSLModel(nn.Module): @@ -155,52 +143,3 @@ def loss(self, output, feats, emb, labels, config, info=None): return {'L_CE':L_CE} elif self.loss_type == 5: return {'L_CF1':L_CF1, 'L_CF2':L_CF2} - -# if __name__ == '__main__': -# import yaml -# # LoRA -# 
from peft import LoraConfig, TaskType, PeftModel, get_peft_model - - -# config = yaml.load(open("/data/hungdx/asvspoof5/configs_21/xlsr_aasist_baseline.yaml", 'r'), Loader=yaml.FullLoader) -# device = "cpu" -# model = Model(config['model'], device) -# # Ensure model is moved to the correct device -# model = model.to(device) - -# # print("Hello") -# # print(model) - -# # Trying apply lora -# # Using the best config from this paper -# # https://arxiv.org/pdf/2306.05617 -# target_modules = ["q_proj", "v_proj"] -# r = 4 - -# lora_config = LoraConfig( -# r=r, -# target_modules=target_modules, -# #modules_to_save=config['lora']['modules_to_save'], # Modules to continue fine-tuning -# ) -# peft_model = get_peft_model(model, lora_config) -# #print(peft_model) -# peft_model.print_trainable_parameters() - -# x = torch.randn(1, 64000) -# x_ori = x.detach().clone() - -# x = x.to(device) - -# output, feats, emb = peft_model(x) - -# cross_entropy = nn.CrossEntropyLoss() -# labels = torch.tensor([1]) -# labels = labels.to(device) - -# loss = cross_entropy(output, labels) - -# loss.backward() - -# # Ensure that after backward pass, the x_ori and x still have the same value -# assert torch.allclose(x_ori, x.detach().clone()), "The input x has been changed after backward pass" -# print("Done") \ No newline at end of file diff --git a/src/models/wavlmvib_module.py b/src/models/wavlmvib_module.py new file mode 100644 index 0000000..776b1b8 --- /dev/null +++ b/src/models/wavlmvib_module.py @@ -0,0 +1,277 @@ +from typing import Any, Dict, Tuple + +import torch +from lightning import LightningModule +from torchmetrics import MaxMetric, MeanMetric +from torchmetrics.classification.accuracy import Accuracy + +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq +from src.models.components.wavlmbase_vib import Model as WavlmBaseVIB +from src.utils.debug import NaNErrorMode + + +class WAVLMVIBLLitModule(LightningModule): + """Example of a `LightningModule` for MNIST classification. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + args: Union[Dict[str, Any], None] = None, + is_train: bool = True + ) -> None: + """Initialize a `MNISTLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. 
+ """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.net = WavlmBaseVIB(args, is_train) + + # metric objects for calculating and averaging accuracy across batches + self.train_acc = Accuracy(task="multiclass", num_classes=2) + self.val_acc = Accuracy(task="multiclass", num_classes=2) + self.test_acc = Accuracy(task="multiclass", num_classes=2) + + # for averaging loss across batches + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + + # for tracking best so far validation accuracy + self.val_acc_best = MaxMetric() + self.args = args + self.is_train = is_train + + # for optimizer and scheduler + print("We are in the WAVLMVIBLLitModule") + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the model `self.net`. + + :param x: A tensor of images. + :return: A tensor of logits. + """ + return self.net(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() + self.val_acc_best.reset() + + def model_step( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + - A dictionary of detailed losses. + """ + + train_loss = 0.0 + OC_info = { + 'previous_C': None, + 'previous_num_bona': None + } + + info, batch_x, batch_y = batch + + if len(batch_x.shape) == 3: + batch_x = batch_x.squeeze(0).transpose(0, 1) + + batch_y = batch_y.view(-1).type(torch.int64) + # print('batch_y', batch_y.shape) + self.num_total += batch_y.shape[0] + + with NaNErrorMode( + enabled=True, raise_error=False, print_stats=True, print_nan_index=False + ): + batch_out, batch_feat, batch_emb = self.net(batch_x) + # losses = loss_custom(batch_out, batch_feat, batch_emb, batch_y, config) + # OC for the OC loss + losses = self.net.loss(batch_out, batch_feat, + batch_emb, batch_y, self.args, OC_info) + + for key, value in losses.items(): + train_loss += value + self.train_loss_detail[key] = self.train_loss_detail.get( + key, 0) + value.item() + + self.running_loss += train_loss.item() + preds = torch.argmax(batch_out, dim=1) + #return loss, preds, y + return train_loss, preds, batch_y, self.train_loss_detail + + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. 
+ """ + loss, preds, targets, train_loss_detail = self.model_step(batch) + + # update and log metrics + self.train_loss(self.running_loss) + self.train_acc(preds, targets) + self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True) + + for loss_name, _loss in train_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + + # Reset metrics at the end of the epoch + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets, val_loss_detail = self.model_step(batch) + + # update and log metrics + self.val_loss(loss) + self.val_acc(preds, targets) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + + for loss_name, _loss in val_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." + acc = self.val_acc.compute() # get current val acc + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets = self.model_step(batch) + + # update and log metrics + self.test_loss(loss) + self.test_acc(preds, targets) + self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + pass + + def setup(self, stage: str) -> None: + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + # if self.hparams.compile and stage == "fit": + # self.net = torch.compile(self.net) + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. 
+ + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +# if __name__ == "__main__": +# _ = WAVLMVIBLLitModule(None, None, None, None) diff --git a/src/utils/debug.py b/src/utils/debug.py new file mode 100644 index 0000000..5879781 --- /dev/null +++ b/src/utils/debug.py @@ -0,0 +1,62 @@ +import torch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map +import itertools +import warnings +import random + +# adapted from https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py + +class Lit: + def __init__(self, s): + self.s = s + + def __repr__(self): + return self.s + +def fmt(t: object, print_stats=False) -> str: + if isinstance(t, torch.Tensor): + s = f"torch.tensor(..., size={tuple(t.shape)}, dtype={t.dtype}, device='{t.device}')" + if print_stats: + s += f" [with stats min={t.min()}, max={t.max()}, mean={t.mean()}]" + return Lit(s) + else: + return t + +class NaNErrorMode(TorchDispatchMode): + def __init__( + self, enabled=True, raise_error=False, print_stats=True, print_nan_index=False + ): + self.enabled = enabled + # warning or error + self.raise_error = raise_error + # print min/max/mean stats + self.print_stats = print_stats + # print indices of invalid values in output + self.print_nan_index = print_nan_index + + def __torch_dispatch__(self, func, types, args, kwargs): + out = func(*args, **kwargs) + if self.enabled: + if isinstance(out, torch.Tensor): + if not torch.isfinite(out).all(): + # fmt_partial = partial(fmt, self.print_stats) + fmt_lambda = lambda t: fmt(t, self.print_stats) + fmt_args = ", ".join( + itertools.chain( + (repr(tree_map(fmt_lambda, a)) for a in args), + ( + f"{k}={tree_map(fmt_lambda, v)}" + for k, v in kwargs.items() + ), + ) + ) + msg = f"NaN outputs in out = {func}({fmt_args})" + if self.print_nan_index: + msg += f"\nInvalid values detected at:\n{(~out.isfinite()).nonzero()}" + if self.raise_error: + raise RuntimeError(msg) + else: + warnings.warn(msg) + + return out \ No newline at end of file From 3070a05d1bc9dce6b56537f5f4e4575744552b1b Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Sun, 29 Sep 2024 01:35:36 +0900 Subject: [PATCH 4/8] Add XLSR+Conformertcm for CyberpAICup Challenge --- configs/data/cyberaicup.yaml | 10 - configs/data/cyberaicup_track2.yaml | 35 +++ ...scl_normal_largecorpus_for_asvspoof5.yaml} | 21 +- .../xlsr_conformertcm_normal_cyberaicup.yaml | 58 ++++ configs/model/xlsr_conformertcm.yaml | 28 ++ configs/paths/default.yaml | 2 +- requirements.txt | 3 +- src/data/scl_datamodule.py | 6 +- src/models/components/loss_metrics.py | 2 +- src/models/components/xlsr_conformertcm.py | 19 +- src/models/wavlmvib_module.py | 19 +- src/models/xlsr_conformertcm_module.py | 278 ++++++++++++++++++ 12 files changed, 431 insertions(+), 50 deletions(-) delete mode 100644 configs/data/cyberaicup.yaml create mode 100644 configs/data/cyberaicup_track2.yaml rename configs/experiment/{test.yaml => wavlmbase_vib_scl_normal_largecorpus_for_asvspoof5.yaml} (79%) 
create mode 100644 configs/experiment/xlsr_conformertcm_normal_cyberaicup.yaml create mode 100644 configs/model/xlsr_conformertcm.yaml create mode 100644 src/models/xlsr_conformertcm_module.py diff --git a/configs/data/cyberaicup.yaml b/configs/data/cyberaicup.yaml deleted file mode 100644 index 28aea53..0000000 --- a/configs/data/cyberaicup.yaml +++ /dev/null @@ -1,10 +0,0 @@ -_target_: src.data.asvspoof_aasistssl_reproduce_datamodule.ASVSpoofDataModule -data_dir: ${oc.env:ASVSPOOF_PATH} -batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) -num_workers: 4 -pin_memory: False -args: - # The number of seconds to consider for each audio file - duration: 3 - # The sampling rate of the audio files - sampling_rate: 16000 \ No newline at end of file diff --git a/configs/data/cyberaicup_track2.yaml b/configs/data/cyberaicup_track2.yaml new file mode 100644 index 0000000..c10c155 --- /dev/null +++ b/configs/data/cyberaicup_track2.yaml @@ -0,0 +1,35 @@ +_target_: src.data.normal_datamodule.NormalDataModule +data_dir: ${oc.env:CYBERCUP2_PATH} +batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +num_workers: 4 +pin_memory: False +args: + # The sampling rate of the audio files + portion: 0.2 # 20% of the data + nBands: 5 + minF: 20 + maxF: 8000 + minBW: 100 + maxBW: 1000 + minCoeff: 10 + maxCoeff: 100 + minG: 0 + maxG: 0 + minBiasLinNonLin: 5 + maxBiasLinNonLin: 20 + N_f: 5 + P: 10 + g_sd: 2 + SNRmin: 10 + SNRmax: 40 + + data: + augmentation_methods: [] + trim_length: 100000 # 6.25s + wav_samp_rate: 16000 + online_aug: true + aug_dir: ${oc.env:LARGE_CORPUS_FOR_ASVSPOOF5}/aug + noise_path: ${oc.env:NOISE_PATH} + rir_path: ${oc.env:RIR_PATH} + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: false # If true, randomly pick a start point for the audio \ No newline at end of file diff --git a/configs/experiment/test.yaml b/configs/experiment/wavlmbase_vib_scl_normal_largecorpus_for_asvspoof5.yaml similarity index 79% rename from configs/experiment/test.yaml rename to configs/experiment/wavlmbase_vib_scl_normal_largecorpus_for_asvspoof5.yaml index f31e8db..572f175 100644 --- a/configs/experiment/test.yaml +++ b/configs/experiment/wavlmbase_vib_scl_normal_largecorpus_for_asvspoof5.yaml @@ -17,23 +17,18 @@ tags: ["scl_normal_largecorpus_for_asvspoof5", "wavlmbase_vib"] seed: 12345 trainer: - min_epochs: 10 - max_epochs: 10 + min_epochs: 30 + max_epochs: 100 gradient_clip_val: 0.5 accelerator: cuda + detect_anomaly: true model: optimizer: lr: 0.000001 args: loss_type: 1 - net: - d_args: - first_conv: 128 - filts: [70, [1, 32], [32, 32], [32, 64], [64, 64]] - gat_dims: [64, 32] - pool_ratios: [0.5, 0.7, 0.5, 0.5] - temperatures: [2.0, 2.0, 100.0, 100.0] + net: null compile: false scheduler: @@ -45,19 +40,17 @@ model: mode: "exp_range" gamma: 0.85 - - data: - batch_size: 1 + batch_size: 2 + batch_size_eval: 32 num_workers: 8 args: - portion: 0.1 # 10% of the data + portion: 1 # 100% of the data data: trim_length: 100000 # 6s repeat_pad: true # If true, repeat the audio to the trim_length random_start: true # If true, randomly pick a start point for the audio - logger: wandb: tags: ${tags} diff --git a/configs/experiment/xlsr_conformertcm_normal_cyberaicup.yaml b/configs/experiment/xlsr_conformertcm_normal_cyberaicup.yaml new file mode 100644 index 0000000..1f22d66 --- /dev/null +++ b/configs/experiment/xlsr_conformertcm_normal_cyberaicup.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +# to 
execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: cyberaicup_track2 + - override /model: xlsr_conformertcm + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["cyberaicup_track2", "xlsr_conformertcm"] + +seed: 12345 + +trainer: + min_epochs: 30 + max_epochs: 100 + gradient_clip_val: 0.5 + accelerator: cuda + detect_anomaly: true + +model: + optimizer: + lr: 0.000001 + args: + loss_type: 4 + net: null + compile: false + + scheduler: + _target_: torch.optim.lr_scheduler.CyclicLR + _partial_: true + cycle_momentum: false + base_lr: 0.000001 + max_lr: 0.00001 + mode: "exp_range" + gamma: 0.85 + +data: + batch_size: 10 + num_workers: 8 + args: + portion: 1 # 100% of the data + data: + trim_length: 160000 # 10s + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: true # If true, randomly pick a start point for the audio + +logger: + wandb: + tags: ${tags} + group: "cyberaicup_track2" + aim: + experiment: "cyberaicup_track2" diff --git a/configs/model/xlsr_conformertcm.yaml b/configs/model/xlsr_conformertcm.yaml new file mode 100644 index 0000000..c18f927 --- /dev/null +++ b/configs/model/xlsr_conformertcm.yaml @@ -0,0 +1,28 @@ +_target_: src.models.xlsr_conformertcm_module.XLSRConformerTCMLitModule + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.000001 + weight_decay: 0.00001 + +scheduler: null + +args: + contra_mode: "all" # 'one' or 'all' + loss_type: 1 # 1: 2loss, 2: loss emb only, 3: loss final hidden state only, 4: CE loss + ce_loss_weight: 0.5 # weight of the cross entropy loss for bona fide class + conformer: + emb_size: 128 + heads: 4 + kernel_size: 32 + n_encoders: 4 + type: "conv_res2net_seblock" # conv_res2net, conv_seblock, conv_res2net_seblock + pooling: "first" + +cp_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH} + +is_train: true + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml index ec81db2..ca534cd 100644 --- a/configs/paths/default.yaml +++ b/configs/paths/default.yaml @@ -12,7 +12,7 @@ log_dir: ${paths.root_dir}/logs/ # path to output directory, created dynamically by hydra # path generation pattern is specified in `configs/hydra/default.yaml` # use it to store all files generated during the run, like ckpts and metrics -output_dir: ${hydra:runtime.output_dir} +output_dir: ${oc.env:OUTPUT_DIR} # path to working directory work_dir: ${hydra:runtime.cwd} diff --git a/requirements.txt b/requirements.txt index aeb3911..14b85a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,4 +25,5 @@ pytest # tests # --------- audio --------- # librosa==0.10.0 -pydub==0.25.1 \ No newline at end of file +pydub==0.25.1 +speechbrain==1.0.1 \ No newline at end of file diff --git a/src/data/scl_datamodule.py b/src/data/scl_datamodule.py index 8351d2a..1994d47 100644 --- a/src/data/scl_datamodule.py +++ b/src/data/scl_datamodule.py @@ -212,6 +212,7 @@ def __init__( num_workers: int = 0, pin_memory: bool = False, args: Optional[Dict[str, Any]] = None, + batch_size_eval: int = 32, # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, # trim_length: int = 64000, 
wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, @@ -235,6 +236,7 @@ def __init__( self.data_test: Optional[Dataset] = None self.batch_size_per_device = batch_size + self.batch_size_eval = batch_size_eval self.data_dir = data_dir self.args = args @@ -329,7 +331,7 @@ def val_dataloader(self) -> DataLoader[Any]: """ return DataLoader( dataset=self.data_val, - batch_size=self.batch_size_per_device, + batch_size=self.batch_size_eval, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=False, @@ -342,7 +344,7 @@ def test_dataloader(self) -> DataLoader[Any]: """ return DataLoader( dataset=self.data_test, - batch_size=self.batch_size_per_device, + batch_size=self.batch_size_eval, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=False, diff --git a/src/models/components/loss_metrics.py b/src/models/components/loss_metrics.py index 3fc959f..7e854fb 100644 --- a/src/models/components/loss_metrics.py +++ b/src/models/components/loss_metrics.py @@ -200,7 +200,7 @@ def supcon_loss(input_feat, # print("log_prob.shape", log_prob.shape) # compute mean of log-likelihood over positive # print(mask_ * log_prob) - mean_log_prob_pos = (mask_ * log_prob).sum(1) / (mask_.sum(1) + 1e-8) # avoid zero division + mean_log_prob_pos = (mask_ * log_prob).sum(1) / mask_.sum(1) # print("mean_log_prob_pos.shape", mean_log_prob_pos.shape) # print(mean_log_prob_pos) # loss diff --git a/src/models/components/xlsr_conformertcm.py b/src/models/components/xlsr_conformertcm.py index db23bf4..d221655 100644 --- a/src/models/components/xlsr_conformertcm.py +++ b/src/models/components/xlsr_conformertcm.py @@ -14,10 +14,10 @@ class SSLModel(nn.Module): - def __init__(self): + def __init__(self, cp_path): super(SSLModel, self).__init__() - cp_path = '/data/hungdx/asvspoof5/model/pretrained/xlsr2_300m.pt' # Change the pre-trained XLSR model path. + #cp_path = '/data/hungdx/asvspoof5/model/pretrained/xlsr2_300m.pt' # Change the pre-trained XLSR model path. 
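        # cp_path is now a constructor argument rather than a hard-coded checkpoint
        # path; configs/model/xlsr_conformertcm.yaml resolves it from the
        # XLSR_PRETRAINED_MODEL_PATH environment variable.
        # load_model_ensemble_and_task returns (models, cfg, task); only the first
        # model of the ensemble is kept as the XLSR front end.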
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) self.model = model[0] @@ -26,11 +26,7 @@ def __init__(self): def extract_feat(self, input_data, is_train=True): - # put the model to GPU if it not there - # if next(self.model.parameters()).device != input_data.device \ - # or next(self.model.parameters()).dtype != input_data.dtype: - # self.model.to(input_data.device, dtype=input_data.dtype) - # self.model.train() + # input should be in shape (batch, length) if input_data.ndim == 3: @@ -49,20 +45,19 @@ def forward(self, x): class Model(nn.Module): - def __init__(self, args, device, is_train = True): + def __init__(self, args, cp_path, is_train = True): super().__init__() - self.device = device + self.is_train = is_train - self.flag_fix_ssl = args['flag_fix_ssl'] self.contra_mode = args['contra_mode'] self.loss_type = args['loss_type'] - self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])]).to(device)) + self.loss_CE = nn.CrossEntropyLoss(weight = torch.FloatTensor([float(1-args['ce_loss_weight']), float(args['ce_loss_weight'])])) #### # create network wav2vec 2.0 #### - self.front_end = SSLModel() + self.front_end = SSLModel(cp_path) self.LL = nn.Linear(self.front_end.out_dim, args['conformer']['emb_size']) self.first_bn = nn.BatchNorm2d(num_features=1) self.drop = nn.Dropout(0.5, inplace=True) diff --git a/src/models/wavlmvib_module.py b/src/models/wavlmvib_module.py index 776b1b8..ea09a60 100644 --- a/src/models/wavlmvib_module.py +++ b/src/models/wavlmvib_module.py @@ -3,7 +3,7 @@ import torch from lightning import LightningModule from torchmetrics import MaxMetric, MeanMetric -from torchmetrics.classification.accuracy import Accuracy +from torchmetrics.classification.accuracy import BinaryAccuracy import random from typing import Union @@ -75,9 +75,9 @@ def __init__( self.net = WavlmBaseVIB(args, is_train) # metric objects for calculating and averaging accuracy across batches - self.train_acc = Accuracy(task="multiclass", num_classes=2) - self.val_acc = Accuracy(task="multiclass", num_classes=2) - self.test_acc = Accuracy(task="multiclass", num_classes=2) + self.train_acc = BinaryAccuracy() + self.val_acc = BinaryAccuracy() + self.test_acc = BinaryAccuracy() # for averaging loss across batches self.train_loss = MeanMetric() @@ -141,7 +141,7 @@ def model_step( self.num_total += batch_y.shape[0] with NaNErrorMode( - enabled=True, raise_error=False, print_stats=True, print_nan_index=False + enabled=False, raise_error=False, print_stats=True, print_nan_index=False ): batch_out, batch_feat, batch_emb = self.net(batch_x) # losses = loss_custom(batch_out, batch_feat, batch_emb, batch_y, config) @@ -155,7 +155,8 @@ def model_step( key, 0) + value.item() self.running_loss += train_loss.item() - preds = torch.argmax(batch_out, dim=1) + _, preds = batch_out.max(dim=1) + #preds = torch.argmax(batch_out, dim=1) #return loss, preds, y return train_loss, preds, batch_y, self.train_loss_detail @@ -174,11 +175,11 @@ def training_step( # update and log metrics self.train_loss(self.running_loss) self.train_acc(preds, targets) - self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True) - self.log("train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) for loss_name, _loss in 
train_loss_detail.items(): - self.log(f"train/{loss_name}", _loss, on_step=True, on_epoch=True, prog_bar=True) + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) return loss diff --git a/src/models/xlsr_conformertcm_module.py b/src/models/xlsr_conformertcm_module.py new file mode 100644 index 0000000..aa0eac3 --- /dev/null +++ b/src/models/xlsr_conformertcm_module.py @@ -0,0 +1,278 @@ +from typing import Any, Dict, Tuple + +import torch +from lightning import LightningModule +from torchmetrics import MaxMetric, MeanMetric +from torchmetrics.classification.accuracy import BinaryAccuracy + +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq +from src.models.components.wavlmbase_vib import Model as WavlmBaseVIB +from src.utils.debug import NaNErrorMode +from src.models.components.xlsr_conformertcm import Model as XLSRConformerTCM + +class XLSRConformerTCMLitModule(LightningModule): + """Example of a `LightningModule` for MNIST classification. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + args: Union[Dict[str, Any], None] = None, + cp_path: str = None, + is_train: bool = True + ) -> None: + """Initialize a `MNISTLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.net = XLSRConformerTCM(args, cp_path, is_train) + + # metric objects for calculating and averaging accuracy across batches + self.train_acc = BinaryAccuracy() + self.val_acc = BinaryAccuracy() + self.test_acc = BinaryAccuracy() + + # for averaging loss across batches + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + + # for tracking best so far validation accuracy + self.val_acc_best = MaxMetric() + self.args = args + self.is_train = is_train + + # for optimizer and scheduler + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the model `self.net`. + + :param x: A tensor of images. + :return: A tensor of logits. 
+ """ + return self.net(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() + self.val_acc_best.reset() + + def model_step( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + - A dictionary of detailed losses. + """ + + train_loss = 0.0 + OC_info = { + 'previous_C': None, + 'previous_num_bona': None + } + + info, batch_x, batch_y = batch + + if len(batch_x.shape) == 3: + batch_x = batch_x.squeeze(0).transpose(0, 1) + + batch_y = batch_y.view(-1).type(torch.int64) + # print('batch_y', batch_y.shape) + self.num_total += batch_y.shape[0] + + with NaNErrorMode( + enabled=False, raise_error=False, print_stats=True, print_nan_index=False + ): + batch_out, batch_feat, batch_emb = self.net(batch_x) + # losses = loss_custom(batch_out, batch_feat, batch_emb, batch_y, config) + # OC for the OC loss + losses = self.net.loss(batch_out, batch_feat, + batch_emb, batch_y, self.args, OC_info) + + for key, value in losses.items(): + train_loss += value + self.train_loss_detail[key] = self.train_loss_detail.get( + key, 0) + value.item() + + self.running_loss += train_loss.item() + _, preds = batch_out.max(dim=1) + #preds = torch.argmax(batch_out, dim=1) + #return loss, preds, y + return train_loss, preds, batch_y, self.train_loss_detail + + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + loss, preds, targets, train_loss_detail = self.model_step(batch) + + # update and log metrics + self.train_loss(self.running_loss) + self.train_acc(preds, targets) + self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True) + + for loss_name, _loss in train_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=True, on_epoch=True, prog_bar=True) + + return loss + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + + # Reset metrics at the end of the epoch + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. 
+ """ + loss, preds, targets, val_loss_detail = self.model_step(batch) + + # update and log metrics + self.val_loss(loss) + self.val_acc(preds, targets) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + + for loss_name, _loss in val_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." + acc = self.val_acc.compute() # get current val acc + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets = self.model_step(batch) + + # update and log metrics + self.test_loss(loss) + self.test_acc(preds, targets) + self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + pass + + def setup(self, stage: str) -> None: + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + # if self.hparams.compile and stage == "fit": + # self.net = torch.compile(self.net) + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. 
+ """ + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +# if __name__ == "__main__": +# _ = WAVLMVIBLLitModule(None, None, None, None) From 4a08956e0a1e7edf727d11b62e459103e5f4b40c Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Mon, 30 Sep 2024 10:17:06 +0900 Subject: [PATCH 5/8] Add code for mixed training datamodule --- configs/data/cyberaicup_track2_mixed.yaml | 35 ++ ...conformertcm_normal_cyberaicup2_mixed.yaml | 58 +++ src/data/cyber2_mixed_normal_datamodule.py | 374 ++++++++++++++++++ 3 files changed, 467 insertions(+) create mode 100644 configs/data/cyberaicup_track2_mixed.yaml create mode 100644 configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml create mode 100644 src/data/cyber2_mixed_normal_datamodule.py diff --git a/configs/data/cyberaicup_track2_mixed.yaml b/configs/data/cyberaicup_track2_mixed.yaml new file mode 100644 index 0000000..d43245e --- /dev/null +++ b/configs/data/cyberaicup_track2_mixed.yaml @@ -0,0 +1,35 @@ +_target_: src.data.cyber2_mixed_normal_datamodule.NormalDataModule +data_dir: ${oc.env:CYBERCUP2_PATH} +batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup) +num_workers: 4 +pin_memory: False +args: + # The sampling rate of the audio files + portion: 0.2 # 20% of the data + nBands: 5 + minF: 20 + maxF: 8000 + minBW: 100 + maxBW: 1000 + minCoeff: 10 + maxCoeff: 100 + minG: 0 + maxG: 0 + minBiasLinNonLin: 5 + maxBiasLinNonLin: 20 + N_f: 5 + P: 10 + g_sd: 2 + SNRmin: 10 + SNRmax: 40 + + data: + augmentation_methods: [] + trim_length: 100000 # 6.25s + wav_samp_rate: 16000 + online_aug: true + aug_dir: ${oc.env:LARGE_CORPUS_FOR_ASVSPOOF5}/aug + noise_path: ${oc.env:NOISE_PATH} + rir_path: ${oc.env:RIR_PATH} + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: false # If true, randomly pick a start point for the audio \ No newline at end of file diff --git a/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml new file mode 100644 index 0000000..28346f3 --- /dev/null +++ b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml @@ -0,0 +1,58 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: cyberaicup_track2_mixed + - override /model: xlsr_conformertcm + - override /callbacks: default + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["cyberaicup_track2_mixed", "xlsr_conformertcm"] + +seed: 12345 + +trainer: + min_epochs: 30 + max_epochs: 100 + gradient_clip_val: 0.5 + accelerator: cuda + detect_anomaly: true + +model: + optimizer: + lr: 0.000001 + args: + loss_type: 4 + net: null + compile: false + + scheduler: + _target_: torch.optim.lr_scheduler.CyclicLR + _partial_: true + cycle_momentum: false + base_lr: 0.000001 + max_lr: 0.00001 + mode: "exp_range" + gamma: 0.85 + +data: + batch_size: 10 + num_workers: 8 + args: + portion: 1 # 100% of the data + data: + trim_length: 160000 # 10s + repeat_pad: false # If true, repeat the audio to the 
trim_length + random_start: true # If true, randomly pick a start point for the audio + +logger: + wandb: + tags: ${tags} + group: "cyberaicup_track2_mixed" + aim: + experiment: "cyberaicup_track2_mixed" diff --git a/src/data/cyber2_mixed_normal_datamodule.py b/src/data/cyber2_mixed_normal_datamodule.py new file mode 100644 index 0000000..30127f6 --- /dev/null +++ b/src/data/cyber2_mixed_normal_datamodule.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from typing import Any, Dict, Optional, Tuple +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Dataset +from torch import Tensor +import librosa +import numpy as np +import os +from scipy import signal +import copy +import random +from src.core_scripts.data_io import wav_augmentation as nii_wav_aug +from src.core_scripts.data_io import wav_tools as nii_wav_tools +from src.data.components.dataio import load_audio, pad +from src.data.components.baseloader import Dataset_base + +# augwrapper +from src.data.components.augwrapper import SUPPORTED_AUGMENTATION + +# dynamic import of augmentation methods +for aug in SUPPORTED_AUGMENTATION: + exec(f"from src.data.components.augwrapper import {aug}") + +class Dataset_for(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False): + super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X = load_audio(filepath, self.sample_rate) + + # apply augmentation + # randomly choose an augmentation method + if self.is_train: + augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(self.augmentation_methods) > 0 else -1 + if augmethod_index >= 0: + X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate, + audio_path = filepath) + + X_pad= pad(X,padding_type="repeat" if self.repeat_pad else "zero", max_len=self.trim_length, random_start=True) + x_inp= Tensor(X_pad) + target = self.labels[utt_id] + return idx, x_inp, target + + +class Dataset_for_dev(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=False, is_train=True, + random_start=False + ): + super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, index): + utt_id = self.list_IDs[index] + X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) + X_pad = pad(X,self.padding_type,self.trim_length, random_start=self.random_start) + x_inp = Tensor(X_pad) + 
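        # labels follow the genList() convention below: 1 = bonafide, 0 = spoof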
target = self.labels[utt_id] + return index, x_inp, target + +class Dataset_for_eval(Dataset_base): + def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], + augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2, + trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None, + aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False + ): + super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders, + augmentation_methods, eval_augment, num_additional_real, num_additional_spoof, + trim_length, wav_samp_rate, noise_path, rir_path, + aug_dir, online_aug, repeat_pad, is_train, random_start) + self.enable_chunking = enable_chunking + if repeat_pad: + self.padding_type = "repeat" + else: + self.padding_type = "zero" + + def __getitem__(self, idx): + utt_id = self.list_IDs[idx] + filepath = os.path.join(self.base_dir, utt_id) + X, _ = librosa.load(filepath, sr=16000) + # apply augmentation at inference time + if self.eval_augment is not None: + # print("eval_augment:", self.eval_augment) + X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath) + if not self.enable_chunking: + X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=self.random_start) + x_inp = Tensor(X) + return x_inp, utt_id + +class NormalDataModule(LightningDataModule): + """`LightningDataModule` for the ASVSpoof dataset. + + The ASVspoof 2019 database for logical access is based upon a standard multi-speaker speech synthesis database called VCTK2. + Genuine speech is collected from 107 speakers (46 male, 61 female) and with no significant channel or background noise effects. + Spoofed speech is generated from the genuine data using a number of different spoofing algorithms. + The full dataset is partitioned into three subsets, the first for training, the second for development and the third for evaluation. + There is no speaker overlap across the three subsets regarding target speakers used in voice conversion or TTS adaptation. + + A `LightningDataModule` implements 7 key methods: + + ```python + def prepare_data(self): + # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP). + # Download data, pre-process, split, save to disk, etc... + + def setup(self, stage): + # Things to do on every process in DDP. + # Load data, set variables, etc... + + def train_dataloader(self): + # return train dataloader + + def val_dataloader(self): + # return validation dataloader + + def test_dataloader(self): + # return test dataloader + + def predict_dataloader(self): + # return predict dataloader + + def teardown(self, stage): + # Called on every process in DDP. + # Clean up after fit or test. + ``` + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data. 
+ + Read the docs: + https://lightning.ai/docs/pytorch/latest/data/datamodule.html + """ + + def __init__( + self, + data_dir: str = "data/", + batch_size: int = 64, + num_workers: int = 0, + pin_memory: bool = False, + args: Optional[Dict[str, Any]] = None, + # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], + # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, + # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, + # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False + ) -> None: + """Initialize a `ASVSpoofDataModule`. + + :param data_dir: The data directory. Defaults to `"data/"`. + :param batch_size: The batch size. Defaults to `64`. + :param num_workers: The number of workers. Defaults to `0`. + :param pin_memory: Whether to pin memory. Defaults to `False`. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.data_train: Optional[Dataset] = None + self.data_val: Optional[Dataset] = None + self.data_test: Optional[Dataset] = None + + self.batch_size_per_device = batch_size + self.data_dir = data_dir + + self.args = args + + @property + def num_classes(self) -> int: + """Get the number of classes. + + :return: The number of ASVSpoof classes (2). + """ + return 2 + + def prepare_data(self) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. + + This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and + `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after + `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to + `self.setup()` once the data is prepared and available for use. + + :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. + """ + # Divide batch size by the number of devices. + if self.trainer is not None: + if self.hparams.batch_size % self.trainer.world_size != 0: + raise RuntimeError( + f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})." + ) + self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size + + # load and split datasets only if not loaded already + if not self.data_train and not self.data_val and not self.data_test: + + # define train dataloader + + d_label_trn, file_train = self.genList(dir_meta=os.path.join( + self.data_dir), is_train=True, is_eval=False, is_dev=False) + + if 'portion' in self.args: + idx = range(len(file_train)) + idx = np.random.choice( + idx, int(len(file_train)*self.args['portion']), replace=False) + file_train = [file_train[i] for i in idx] + if len(d_label_trn) > 0: # the case of train without label + d_label_trn = {k: d_label_trn[k] for k in file_train} + print('no. 
of training trials', len(file_train)) + + d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) + + if 'portion' in self.args: + idx = range(len(file_dev)) + idx = np.random.choice( + idx, int(len(file_dev)*self.args['portion']), replace=False) + file_dev = [file_dev[i] for i in idx] + if len(d_label_dev) > 0: # the case of train without label + d_label_dev = {k: d_label_dev[k] for k in file_dev} + + print('no. of validation trials', len(file_dev)) + file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False) + + self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn, + base_dir=self.data_dir+'/', **self.args['data']) + + self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev, + base_dir=self.data_dir+'/', is_train=False, **self.args['data']) + + self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None, + base_dir=self.data_dir+'/', **self.args['data']) + + + def train_dataloader(self) -> DataLoader[Any]: + """Create and return the train dataloader. + + :return: The train dataloader. + """ + return DataLoader( + dataset=self.data_train, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + drop_last=True + ) + + def val_dataloader(self) -> DataLoader[Any]: + """Create and return the validation dataloader. + + :return: The validation dataloader. + """ + return DataLoader( + dataset=self.data_val, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def test_dataloader(self) -> DataLoader[Any]: + """Create and return the test dataloader. + + :return: The test dataloader. + """ + return DataLoader( + dataset=self.data_test, + batch_size=self.batch_size_per_device, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + ) + + def teardown(self, stage: Optional[str] = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, + `trainer.test()`, and `trainer.predict()`. + + :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Defaults to ``None``. + """ + pass + + def state_dict(self) -> Dict[Any, Any]: + """Called when saving a checkpoint. Implement to generate and save the datamodule state. + + :return: A dictionary containing the datamodule state that you want to save. + """ + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint. Implement to reload datamodule state given datamodule + `state_dict()`. + + :param state_dict: The datamodule state returned by `self.state_dict()`. 
+ """ + pass + + def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): + # bonafide: 1, spoof: 0 + d_meta = {} + file_list=[] + protocol = os.path.join(dir_meta, "protocol.txt") + + # Mixed protocol + mixed_protocol = "/home/hungdx/code/Lightning-hydra/data/CyberAICup2024/mix_rTTSD_vocoded.txt" + + if (is_train): + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'train': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + + # Add all the vocoded files + print("Adding all vocoded files to training set") + with open(mixed_protocol, 'r') as f: + l_meta = f.readlines() + + print("Number of vocoded files: ", len(l_meta)) + + for line in l_meta: + utt = line.strip() + d_meta[utt] = 0 # all vocoded files are spoof + file_list.append(utt) + + return d_meta, file_list + if (is_dev): + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'dev': + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + return d_meta, file_list + + if (is_eval): + # no eval protocol yet + with open(protocol, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + utt, subset, label = line.strip().split() + if subset == 'eval': + file_list.append(utt) + file_list.append(utt) + d_meta[utt] = 1 if label == 'bonafide' else 0 + # return d_meta, file_list + return d_meta, file_list + +if __name__ == "__main__": + _ = NormalDataModule() \ No newline at end of file From 366879755f0caf136353b4b43e3aefd38c3b6b27 Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Mon, 30 Sep 2024 13:55:56 +0900 Subject: [PATCH 6/8] Log epoch only --- src/models/xlsr_conformertcm_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/models/xlsr_conformertcm_module.py b/src/models/xlsr_conformertcm_module.py index aa0eac3..de14841 100644 --- a/src/models/xlsr_conformertcm_module.py +++ b/src/models/xlsr_conformertcm_module.py @@ -175,11 +175,11 @@ def training_step( # update and log metrics self.train_loss(self.running_loss) self.train_acc(preds, targets) - self.log("train/loss", self.train_loss, on_step=True, on_epoch=True, prog_bar=True) - self.log("train/acc", self.train_acc, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) for loss_name, _loss in train_loss_detail.items(): - self.log(f"train/{loss_name}", _loss, on_step=True, on_epoch=True, prog_bar=True) + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) return loss From 865c9f8afb7b258a4dc8a2f2df443b979630bab7 Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Thu, 3 Oct 2024 01:45:42 +0900 Subject: [PATCH 7/8] Add EER metrics and update default configs --- configs/callbacks/default.yaml | 8 +- configs/callbacks/learning_rate_logger.yaml | 3 + configs/eval_cyberaicup2.yaml | 77 +++++ ...conformertcm_normal_cyberaicup2_mixed.yaml | 29 +- configs/paths/default.yaml | 2 +- src/data/cyber2_mixed_normal_datamodule.py | 8 +- src/metrics/__init__.py | 0 src/metrics/eer.py | 49 +++ src/models/xlsr_conformertcm_module.py | 38 ++- src/models/xlsr_vib_module.py | 298 ++++++++++++++++++ src/train.py | 4 +- 11 files changed, 488 insertions(+), 28 deletions(-) create mode 100644 configs/callbacks/learning_rate_logger.yaml create mode 100644 
configs/eval_cyberaicup2.yaml create mode 100644 src/metrics/__init__.py create mode 100644 src/metrics/eer.py create mode 100644 src/models/xlsr_vib_module.py diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index c9bf2fb..e561c62 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -1,10 +1,14 @@ defaults: - model_checkpoint - early_stopping - - model_summary +# - model_summary - rich_progress_bar + - learning_rate_logger - _self_ +learning_rate_logger: + logging_interval: 'epoch' + model_checkpoint: dirpath: ${paths.output_dir}/checkpoints filename: "epoch_{epoch:03d}" @@ -15,7 +19,7 @@ model_checkpoint: early_stopping: monitor: "val/acc" - patience: 100 + patience: 20 mode: "max" model_summary: diff --git a/configs/callbacks/learning_rate_logger.yaml b/configs/callbacks/learning_rate_logger.yaml new file mode 100644 index 0000000..2b0e302 --- /dev/null +++ b/configs/callbacks/learning_rate_logger.yaml @@ -0,0 +1,3 @@ +learning_rate_logger: + _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: 'step' \ No newline at end of file diff --git a/configs/eval_cyberaicup2.yaml b/configs/eval_cyberaicup2.yaml new file mode 100644 index 0000000..a6f9864 --- /dev/null +++ b/configs/eval_cyberaicup2.yaml @@ -0,0 +1,77 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: cyberaicup_track2_mixed + - override /model: xlsr_conformertcm + - override /callbacks: none + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +#tags: ["cyberaicup_track2_mixed", "xlsr_conformertcm"] + +seed: 12345 + +trainer: + min_epochs: 30 + max_epochs: 100 + gradient_clip_val: 0.0 # 0.0 means don't clip + accelerator: cuda + +model: + optimizer: + lr: 0.00001 + args: + loss_type: 4 + net: null + compile: false + + scheduler: + _target_: torch.optim.lr_scheduler.CyclicLR + _partial_: true + cycle_momentum: false + base_lr: 0.000001 + max_lr: 0.00001 + mode: "exp_range" + gamma: 0.85 + + score_save_path: logs/eval/cyberaicup_track2_mixed_xlsr_conformertcm_epoch_25.txt + + +data: + batch_size: 10 + num_workers: 8 + args: + portion: 1 # 100% of the data + data: + trim_length: 160000 # 10s + repeat_pad: false # If true, repeat the audio to the trim_length + random_start: true # If true, randomly pick a start point for the audio + augmentation_methods: [] + +logger: null + + +# task name, determines output directory path +task_name: "eval" + +# tags to help you identify your experiments +# you can overwrite this in experiment configs +# overwrite from command line with `python train.py tags="[first_tag, second_tag]"` +# appending lists from command line is currently not supported :( +# https://github.com/facebookresearch/hydra/issues/1547 +tags: ["dev"] + +# set False to skip model training +train: False + +# evaluate on test set, using best model weights achieved during training +# lightning chooses best weights based on the metric specified in checkpoint callback +test: True + +# simply provide checkpoint path to resume training +ckpt_path: logs/train/checkpoints/last-v3.ckpt diff --git a/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml index 28346f3..20ae772 100644 --- a/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml +++ 
b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml @@ -19,26 +19,34 @@ seed: 12345 trainer: min_epochs: 30 max_epochs: 100 - gradient_clip_val: 0.5 + gradient_clip_val: 0.0 # 0.0 means don't clip accelerator: cuda - detect_anomaly: true + #check_val_every_n_epoch: 0.5 # 0.5 means every 0.5 epoch model: optimizer: - lr: 0.000001 + lr: 0.000005 args: loss_type: 4 net: null - compile: false + compile: true + # scheduler: + # _target_: torch.optim.lr_scheduler.CyclicLR + # _partial_: true + # cycle_momentum: false + # base_lr: 0.000005 + # max_lr: 0.00001 + # mode: "exp_range" + # gamma: 0.85 + # verbose: true scheduler: - _target_: torch.optim.lr_scheduler.CyclicLR + _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true - cycle_momentum: false - base_lr: 0.000001 - max_lr: 0.00001 - mode: "exp_range" - gamma: 0.85 + mode: min + factor: 0.5 + patience: 5 + verbose: true data: batch_size: 10 @@ -49,6 +57,7 @@ data: trim_length: 160000 # 10s repeat_pad: false # If true, repeat the audio to the trim_length random_start: true # If true, randomly pick a start point for the audio + augmentation_methods: [] logger: wandb: diff --git a/configs/paths/default.yaml b/configs/paths/default.yaml index ca534cd..ec81db2 100644 --- a/configs/paths/default.yaml +++ b/configs/paths/default.yaml @@ -12,7 +12,7 @@ log_dir: ${paths.root_dir}/logs/ # path to output directory, created dynamically by hydra # path generation pattern is specified in `configs/hydra/default.yaml` # use it to store all files generated during the run, like ckpts and metrics -output_dir: ${oc.env:OUTPUT_DIR} +output_dir: ${hydra:runtime.output_dir} # path to working directory work_dir: ${hydra:runtime.cwd} diff --git a/src/data/cyber2_mixed_normal_datamodule.py b/src/data/cyber2_mixed_normal_datamodule.py index 30127f6..71fe99e 100644 --- a/src/data/cyber2_mixed_normal_datamodule.py +++ b/src/data/cyber2_mixed_normal_datamodule.py @@ -363,12 +363,12 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): l_meta = f.readlines() for line in l_meta: utt, subset, label = line.strip().split() - if subset == 'eval': - file_list.append(utt) + #if subset == 'dev': + # file_list.append(utt) file_list.append(utt) - d_meta[utt] = 1 if label == 'bonafide' else 0 + #d_meta[utt] = 1 if label == 'bonafide' else 0 # return d_meta, file_list - return d_meta, file_list + return file_list if __name__ == "__main__": _ = NormalDataModule() \ No newline at end of file diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/metrics/eer.py b/src/metrics/eer.py new file mode 100644 index 0000000..5eed473 --- /dev/null +++ b/src/metrics/eer.py @@ -0,0 +1,49 @@ +import torch +import numpy as np +from torch import Tensor +from torchmetrics import Metric + +def compute_det_curve(target_scores, nontarget_scores): + n_scores = target_scores.size + nontarget_scores.size + all_scores = np.concatenate((target_scores, nontarget_scores)) + labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size))) + + indices = np.argsort(all_scores, kind='mergesort') + labels = labels[indices] + + tar_trial_sums = np.cumsum(labels) + nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums) + + frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size)) + far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size)) + thresholds = 
np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) + + return frr, far, thresholds + +def compute_eer(pred, label): + if isinstance(pred, torch.Tensor): + pred, label = pred.cpu().numpy(), label.cpu().numpy() + pred, label = pred.flatten(), label.flatten() + target_scores = pred[label == 1.] + nontarget_scores = pred[label == 0.] + frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores) + abs_diffs = np.abs(frr - far) + min_index = np.argmin(abs_diffs) + eer = np.mean((frr[min_index], far[min_index])) + return eer, thresholds[min_index] + +class EERMetric(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("targets", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + self.preds.append(preds) + self.targets.append(target) + + def compute(self) -> Tensor: + preds = torch.cat(self.preds, dim=0) + targets = torch.cat(self.targets, dim=0) + eer, _ = compute_eer(preds, targets) + return torch.tensor(eer) diff --git a/src/models/xlsr_conformertcm_module.py b/src/models/xlsr_conformertcm_module.py index de14841..aba5083 100644 --- a/src/models/xlsr_conformertcm_module.py +++ b/src/models/xlsr_conformertcm_module.py @@ -17,6 +17,7 @@ from src.models.components.wavlmbase_vib import Model as WavlmBaseVIB from src.utils.debug import NaNErrorMode from src.models.components.xlsr_conformertcm import Model as XLSRConformerTCM +from src.metrics.eer import EERMetric class XLSRConformerTCMLitModule(LightningModule): """Example of a `LightningModule` for MNIST classification. @@ -59,7 +60,8 @@ def __init__( compile: bool, args: Union[Dict[str, Any], None] = None, cp_path: str = None, - is_train: bool = True + is_train: bool = True, + score_save_path: str = None, ) -> None: """Initialize a `MNISTLitModule`. @@ -79,6 +81,8 @@ def __init__( self.train_acc = BinaryAccuracy() self.val_acc = BinaryAccuracy() self.test_acc = BinaryAccuracy() + self.score_save_path = score_save_path + #self.test_eer = EERMetric() # for averaging loss across batches self.train_loss = MeanMetric() @@ -204,11 +208,11 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i # update and log metrics self.val_loss(loss) self.val_acc(preds, targets) - self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) for loss_name, _loss in val_loss_detail.items(): - self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) def on_validation_epoch_end(self) -> None: "Lightning hook that is called when a validation epoch ends." @@ -225,16 +229,31 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> labels. :param batch_idx: The index of the current batch. 
""" - loss, preds, targets = self.model_step(batch) + batch_x, utt_id = batch + batch_out = self.net(batch_x) # update and log metrics - self.test_loss(loss) - self.test_acc(preds, targets) - self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + # self.test_loss(loss) + # self.test_acc(preds, targets) + # self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + # Write scores to file + fname_list = list(utt_id) + score_list = batch_out.data.cpu().numpy().tolist() + + with open(self.score_save_path, 'a+') as fh: + for f, cm in zip(fname_list, score_list): + fh.write('{} {} {}\n'.format(f, cm[0], cm[1])) + + def on_test_epoch_start(self) -> None: + """Lightning hook that is called when a test epoch starts.""" + self.net.is_train = False def on_test_epoch_end(self) -> None: """Lightning hook that is called when a test epoch ends.""" + # eer = self.test_eer.compute() + # self.log("test/eer", eer, sync_dist=True, prog_bar=True) + # self.test_eer.reset() pass def setup(self, stage: str) -> None: @@ -261,6 +280,7 @@ def configure_optimizers(self) -> Dict[str, Any]: """ optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) if self.hparams.scheduler is not None: + print("Using LR Scheduler") scheduler = self.hparams.scheduler(optimizer=optimizer) return { "optimizer": optimizer, diff --git a/src/models/xlsr_vib_module.py b/src/models/xlsr_vib_module.py new file mode 100644 index 0000000..aba5083 --- /dev/null +++ b/src/models/xlsr_vib_module.py @@ -0,0 +1,298 @@ +from typing import Any, Dict, Tuple + +import torch +from lightning import LightningModule +from torchmetrics import MaxMetric, MeanMetric +from torchmetrics.classification.accuracy import BinaryAccuracy + +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq +from src.models.components.wavlmbase_vib import Model as WavlmBaseVIB +from src.utils.debug import NaNErrorMode +from src.models.components.xlsr_conformertcm import Model as XLSRConformerTCM +from src.metrics.eer import EERMetric + +class XLSRConformerTCMLitModule(LightningModule): + """Example of a `LightningModule` for MNIST classification. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + args: Union[Dict[str, Any], None] = None, + cp_path: str = None, + is_train: bool = True, + score_save_path: str = None, + ) -> None: + """Initialize a `MNISTLitModule`. 
+ + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.net = XLSRConformerTCM(args, cp_path, is_train) + + # metric objects for calculating and averaging accuracy across batches + self.train_acc = BinaryAccuracy() + self.val_acc = BinaryAccuracy() + self.test_acc = BinaryAccuracy() + self.score_save_path = score_save_path + #self.test_eer = EERMetric() + + # for averaging loss across batches + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + + # for tracking best so far validation accuracy + self.val_acc_best = MaxMetric() + self.args = args + self.is_train = is_train + + # for optimizer and scheduler + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the model `self.net`. + + :param x: A tensor of images. + :return: A tensor of logits. + """ + return self.net(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() + self.val_acc_best.reset() + + def model_step( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + - A dictionary of detailed losses. + """ + + train_loss = 0.0 + OC_info = { + 'previous_C': None, + 'previous_num_bona': None + } + + info, batch_x, batch_y = batch + + if len(batch_x.shape) == 3: + batch_x = batch_x.squeeze(0).transpose(0, 1) + + batch_y = batch_y.view(-1).type(torch.int64) + # print('batch_y', batch_y.shape) + self.num_total += batch_y.shape[0] + + with NaNErrorMode( + enabled=False, raise_error=False, print_stats=True, print_nan_index=False + ): + batch_out, batch_feat, batch_emb = self.net(batch_x) + # losses = loss_custom(batch_out, batch_feat, batch_emb, batch_y, config) + # OC for the OC loss + losses = self.net.loss(batch_out, batch_feat, + batch_emb, batch_y, self.args, OC_info) + + for key, value in losses.items(): + train_loss += value + self.train_loss_detail[key] = self.train_loss_detail.get( + key, 0) + value.item() + + self.running_loss += train_loss.item() + _, preds = batch_out.max(dim=1) + #preds = torch.argmax(batch_out, dim=1) + #return loss, preds, y + return train_loss, preds, batch_y, self.train_loss_detail + + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. 
+ """ + loss, preds, targets, train_loss_detail = self.model_step(batch) + + # update and log metrics + self.train_loss(self.running_loss) + self.train_acc(preds, targets) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) + + for loss_name, _loss in train_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True) + + return loss + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + + # Reset metrics at the end of the epoch + self.running_loss = 0.0 + self.num_total = 0.0 + self.train_loss_detail = {} + + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets, val_loss_detail = self.model_step(batch) + + # update and log metrics + self.val_loss(loss) + self.val_acc(preds, targets) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + + for loss_name, _loss in val_loss_detail.items(): + self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." + acc = self.val_acc.compute() # get current val acc + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + batch_x, utt_id = batch + batch_out = self.net(batch_x) + + # update and log metrics + # self.test_loss(loss) + # self.test_acc(preds, targets) + # self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + # Write scores to file + fname_list = list(utt_id) + score_list = batch_out.data.cpu().numpy().tolist() + + with open(self.score_save_path, 'a+') as fh: + for f, cm in zip(fname_list, score_list): + fh.write('{} {} {}\n'.format(f, cm[0], cm[1])) + + def on_test_epoch_start(self) -> None: + """Lightning hook that is called when a test epoch starts.""" + self.net.is_train = False + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + # eer = self.test_eer.compute() + # self.log("test/eer", eer, sync_dist=True, prog_bar=True) + # self.test_eer.reset() + pass + + def setup(self, stage: str) -> None: + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. 
This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + # if self.hparams.compile and stage == "fit": + # self.net = torch.compile(self.net) + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + print("Using LR Scheduler") + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +# if __name__ == "__main__": +# _ = WAVLMVIBLLitModule(None, None, None, None) diff --git a/src/train.py b/src/train.py index 4adbcf4..cec00ca 100644 --- a/src/train.py +++ b/src/train.py @@ -92,8 +92,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "": - log.warning("Best ckpt not found! Using current weights for testing...") - ckpt_path = None + ckpt_path = cfg.get("ckpt_path") + log.warning(f"Using weights {ckpt_path} for testing...") trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) log.info(f"Best ckpt path: {ckpt_path}") From 0ed2cf231deed928786f3fc260731b1f03435a94 Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Thu, 3 Oct 2024 21:42:24 +0900 Subject: [PATCH 8/8] Finish code for cyberaicup challenge, --- configs/callbacks/default.yaml | 1 + ...conformertcm_normal_cyberaicup2_mixed.yaml | 12 +-- src/data/cyber2_mixed_normal_datamodule.py | 90 +++++++++---------- src/models/xlsr_conformertcm_module.py | 6 +- 4 files changed, 55 insertions(+), 54 deletions(-) diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index e561c62..4c78cc0 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -16,6 +16,7 @@ model_checkpoint: mode: "max" save_last: True auto_insert_metric_name: False + save_top_k: 5 # save k best models (determined by above metric) early_stopping: monitor: "val/acc" diff --git a/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml index 20ae772..7f2a2a0 100644 --- a/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml +++ b/configs/experiment/xlsr_conformertcm_normal_cyberaicup2_mixed.yaml @@ -4,7 +4,7 @@ # python train.py experiment=example defaults: - - override /data: cyberaicup_track2_mixed + - override /data: cyberaicup_track2 - override /model: xlsr_conformertcm - override /callbacks: default - override /trainer: default @@ -12,7 +12,7 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters -tags: ["cyberaicup_track2_mixed", "xlsr_conformertcm"] +tags: ["cyberaicup_track2", "xlsr_conformertcm"] seed: 12345 @@ -22,10 +22,11 @@ trainer: gradient_clip_val: 0.0 # 0.0 means don't clip accelerator: cuda 
#check_val_every_n_epoch: 0.5 # 0.5 means every 0.5 epoch + #strategy: ddp_find_unused_parameters_false model: optimizer: - lr: 0.000005 + lr: 0.00001 args: loss_type: 4 net: null @@ -51,6 +52,7 @@ model: data: batch_size: 10 num_workers: 8 + pin_memory: true args: portion: 1 # 100% of the data data: @@ -62,6 +64,6 @@ data: logger: wandb: tags: ${tags} - group: "cyberaicup_track2_mixed" + group: "cyberaicup_track2" aim: - experiment: "cyberaicup_track2_mixed" + experiment: "cyberaicup_track2" diff --git a/src/data/cyber2_mixed_normal_datamodule.py b/src/data/cyber2_mixed_normal_datamodule.py index 71fe99e..8598303 100644 --- a/src/data/cyber2_mixed_normal_datamodule.py +++ b/src/data/cyber2_mixed_normal_datamodule.py @@ -156,10 +156,7 @@ def __init__( num_workers: int = 0, pin_memory: bool = False, args: Optional[Dict[str, Any]] = None, - # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], - # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, - # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, - # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False + train_mode: Optional[bool] = True, # Optional[bool] = True ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -180,7 +177,7 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir - + self.train_mode = train_mode self.args = args @property @@ -216,40 +213,40 @@ def setup(self, stage: Optional[str] = None) -> None: if not self.data_train and not self.data_val and not self.data_test: # define train dataloader - - d_label_trn, file_train = self.genList(dir_meta=os.path.join( - self.data_dir), is_train=True, is_eval=False, is_dev=False) - - if 'portion' in self.args: - idx = range(len(file_train)) - idx = np.random.choice( - idx, int(len(file_train)*self.args['portion']), replace=False) - file_train = [file_train[i] for i in idx] - if len(d_label_trn) > 0: # the case of train without label - d_label_trn = {k: d_label_trn[k] for k in file_train} - print('no. of training trials', len(file_train)) - - d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True) - - if 'portion' in self.args: - idx = range(len(file_dev)) - idx = np.random.choice( - idx, int(len(file_dev)*self.args['portion']), replace=False) - file_dev = [file_dev[i] for i in idx] - if len(d_label_dev) > 0: # the case of train without label - d_label_dev = {k: d_label_dev[k] for k in file_dev} - - print('no. 
of validation trials', len(file_dev))
-            file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False)
-
-            self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn,
-                                          base_dir=self.data_dir+'/', **self.args['data'])
-
-            self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev,
-                                            base_dir=self.data_dir+'/', is_train=False, **self.args['data'])
-
-            self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None,
-                                              base_dir=self.data_dir+'/', **self.args['data'])
+            if self.train_mode:
+                d_label_trn, file_train = self.genList(dir_meta=os.path.join(
+                    self.data_dir), is_train=True, is_eval=False, is_dev=False)
+
+                if 'portion' in self.args:
+                    idx = range(len(file_train))
+                    idx = np.random.choice(
+                        idx, int(len(file_train)*self.args['portion']), replace=False)
+                    file_train = [file_train[i] for i in idx]
+                    if len(d_label_trn) > 0:  # the case of train without label
+                        d_label_trn = {k: d_label_trn[k] for k in file_train}
+                print('no. of training trials', len(file_train))
+
+                d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True)
+
+                if 'portion' in self.args:
+                    idx = range(len(file_dev))
+                    idx = np.random.choice(
+                        idx, int(len(file_dev)*self.args['portion']), replace=False)
+                    file_dev = [file_dev[i] for i in idx]
+                    if len(d_label_dev) > 0:  # the case of train without label
+                        d_label_dev = {k: d_label_dev[k] for k in file_dev}
+
+                print('no. of validation trials', len(file_dev))
+
+                self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn,
+                                              base_dir=self.data_dir+'/', **self.args['data'])
+
+                self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev,
+                                                base_dir=self.data_dir+'/', is_train=False, **self.args['data'])
+            else:
+                file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False)
+                self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None,
+                                                  base_dir=self.data_dir+'/', **self.args['data'])

     def train_dataloader(self) -> DataLoader[Any]:
@@ -359,15 +356,16 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False):

         if (is_eval):
             # no eval protocol yet
-            with open(protocol, 'r') as f:
+            # with open(protocol, 'r') as f:
+            #     l_meta = f.readlines()
+            #     for line in l_meta:
+            #         utt, subset, label = line.strip().split()
+            #         file_list.append(utt)
+            with open(os.path.join(dir_meta, "eval.txt"), 'r') as f:
                 l_meta = f.readlines()
                 for line in l_meta:
-                    utt, subset, label = line.strip().split()
-                    #if subset == 'dev':
-                    #    file_list.append(utt)
-                    file_list.append(utt)
-                    #d_meta[utt] = 1 if label == 'bonafide' else 0
-            # return d_meta, file_list
+                    file_list.append(line.strip())
+            print("Number of eval files: ", len(file_list))
             return file_list

 if __name__ == "__main__":
diff --git a/src/models/xlsr_conformertcm_module.py b/src/models/xlsr_conformertcm_module.py
index aba5083..ab661f6 100644
--- a/src/models/xlsr_conformertcm_module.py
+++ b/src/models/xlsr_conformertcm_module.py
@@ -179,11 +179,11 @@ def training_step(
         # update and log metrics
         self.train_loss(self.running_loss)
         self.train_acc(preds, targets)
-        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
-        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
+        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

         for loss_name, _loss in train_loss_detail.items():
-            self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True)
+            self.log(f"train/{loss_name}", _loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

         return loss
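A minimal offline scoring sketch, not part of this patch series: test_step above appends one line per utterance to score_save_path in the form "<utt_id> <score_0> <score_1>", and the new src/metrics/eer.py exposes compute_eer. Assuming a key file with lines of the form "<utt_id> bonafide|spoof", and assuming the second score column is the bona fide score (neither assumption is confirmed by the patch), the EER of a finished test run could be computed roughly as follows.

import numpy as np
from src.metrics.eer import compute_eer

def score_file_eer(score_path: str, key_path: str) -> float:
    # hypothetical helper: map utt_id -> 1.0 (bonafide) / 0.0 (spoof)
    labels = {}
    with open(key_path) as f:
        for line in f:
            utt, label = line.strip().split()
            labels[utt] = 1.0 if label == 'bonafide' else 0.0

    preds, targets = [], []
    with open(score_path) as f:
        for line in f:
            # each line was written as "<utt_id> <score_0> <score_1>"
            utt, _, score_bona = line.strip().split()
            if utt in labels:
                preds.append(float(score_bona))
                targets.append(labels[utt])

    eer, _ = compute_eer(np.array(preds), np.array(targets))
    return eer

# e.g. score_file_eer('scores.txt', 'eval_keys.txt')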
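For context on the scheduler change in the experiment config: ReduceLROnPlateau only adjusts the learning rate when it is fed the monitored value, which is why configure_optimizers returns the lr_scheduler dict with monitor set to "val/loss". A small illustrative sketch of how the Hydra _partial_ config resolves at runtime; the parameter and optimizer below are placeholders, not the module's real ones.

from functools import partial
import torch

# stand-in parameter/optimizer; in the module these come from self.trainer.model.parameters()
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-5)

# _partial_: true yields a callable over the scheduler class, completed with the optimizer
scheduler_fn = partial(torch.optim.lr_scheduler.ReduceLROnPlateau,
                       mode="min", factor=0.5, patience=5)
scheduler = scheduler_fn(optimizer=optimizer)

# Lightning passes the monitored "val/loss" value to step() once per epoch
scheduler.step(0.42)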