Skip to content

Commit

Permalink
add conf for beta-2 and beta-3, update dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
hungdinhxuan committed Jan 3, 2025
1 parent 3b5399a commit 1ca86ab
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 81 deletions.
14 changes: 14 additions & 0 deletions configs/data/normal_largecorpus_multiview_cnsl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ args:
g_sd: 2
SNRmin: 10
SNRmax: 40

view_padding_configs:
'1':
padding_type: repeat
random_start: False
'2':
padding_type: repeat
random_start: False
'3':
padding_type: repeat
random_start: False
'4':
padding_type: repeat
random_start: False

data:
augmentation_methods:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
defaults:
- override /data: normal_largecorpus_multiview_cnsl
- override /model: xlsr_aasist_multiview
- override /callbacks: default_loss_earlystop
- override /callbacks: default_loss_earlystop_lts
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
Expand Down Expand Up @@ -34,10 +34,14 @@ data:
batch_size: 14
num_workers: 8
pin_memory: true

args:
protocol_path: ${oc.env:MY_EXTENDED_ASVSPOOF_PROTOCOLS}
chunk_size: 16000 # 1 sec
overlap_size: 8000 # 0.5 sec
random_start: False
trim_length: 64600
padding_type: repeat
view_padding_configs:
'1':
padding_type: repeat
Expand All @@ -51,10 +55,11 @@ data:
'4':
padding_type: repeat
random_start: False

data:
data:
augmentation_methods: ["RawBoostdf"]
random_start: False
aug_dir: ${oc.env:MY_EXTENDED_ASVSPOOF_PATH}/aug
repeat_pad: True
online_aug: True

logger:
wandb:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
defaults:
- override /data: normal_largecorpus_multiview_cnsl
- override /model: xlsr_aasist_multiview
- override /callbacks: default_loss_earlystop
- override /callbacks: default_loss_earlystop_lts
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
Expand Down Expand Up @@ -34,10 +34,14 @@ data:
batch_size: 14
num_workers: 8
pin_memory: true

args:
protocol_path: ${oc.env:MY_EXTENDED_ASVSPOOF_PROTOCOLS}
chunk_size: 16000 # 1 sec
overlap_size: 8000 # 0.5 sec
random_start: False
trim_length: 64600
padding_type: repeat
view_padding_configs:
'1':
padding_type: repeat
Expand All @@ -51,10 +55,11 @@ data:
'4':
padding_type: repeat
random_start: False

data:
data:
augmentation_methods: ["RawBoostdf"]
random_start: False
aug_dir: ${oc.env:MY_EXTENDED_ASVSPOOF_PATH}/aug
repeat_pad: True
online_aug: True

logger:
wandb:
Expand Down
3 changes: 2 additions & 1 deletion src/data/components/augwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,11 +1075,12 @@ def RawBoostFull(x, args, sr=16000, audio_path=None):
def RawBoostdf(x, args, sr=16000, audio_path=None):
aug_dir = args.aug_dir
utt_id = os.path.basename(audio_path).split('.')[0]
aug_audio_path = os.path.join(aug_dir, 'RawBoostdf', utt_id + '.wav')

if args.online_aug:
return process_Rawboost_feature(x, sr, args, algo=3)
else:
# check if the augmented file exists
aug_audio_path = os.path.join(aug_dir, 'RawBoostdf', utt_id + '.wav')
if (os.path.exists(aug_audio_path)):
waveform, _ = librosa.load(aug_audio_path, sr=sr, mono=True)
return waveform
Expand Down
15 changes: 7 additions & 8 deletions src/data/components/baseloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@


class Dataset_base(Dataset):
def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=True, is_train=True,
def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=True, repeat_pad=True, is_train=True,
random_start=False):
"""
Args:
Expand All @@ -23,8 +23,9 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
self.algo = algo
self.vocoders = vocoders
print("vocoders:", vocoders)

self.augmentation_methods = augmentation_methods
print("augmentation_methods:", augmentation_methods)
if len(augmentation_methods) < 1:
# using default augmentation method RawBoostWrapper12
# self.augmentation_methods = ["RawBoost12"]
Expand All @@ -34,8 +35,7 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
self.num_additional_spoof = num_additional_spoof
self.trim_length = trim_length
self.sample_rate = wav_samp_rate

self.args.noise_path = noise_path
self.args.q = noise_path
self.args.rir_path = rir_path
self.args.aug_dir = aug_dir
self.args.online_aug = online_aug
Expand All @@ -49,4 +49,3 @@ def __len__(self):
def __getitem__(self, idx):
# to be implemented in child classes
pass

116 changes: 61 additions & 55 deletions src/data/normal_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,89 +22,97 @@
for aug in SUPPORTED_AUGMENTATION:
exec(f"from src.data.components.augwrapper import {aug}")


class Dataset_for(Dataset_base):
def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False):
super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, random_start=False):
super(Dataset_for, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)

def __getitem__(self, idx):
utt_id = self.list_IDs[idx]
filepath = os.path.join(self.base_dir, utt_id)
X = load_audio(filepath, self.sample_rate)

# apply augmentation
# randomly choose an augmentation method
if self.is_train:
augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(self.augmentation_methods) > 0 else -1
augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(
self.augmentation_methods) > 0 else -1
if augmethod_index >= 0:
X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate,
audio_path = filepath)
X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate,
audio_path=filepath)

X_pad= pad(X,padding_type="repeat" if self.repeat_pad else "zero", max_len=self.trim_length, random_start=True)
x_inp= Tensor(X_pad)
X_pad = pad(X, padding_type="repeat" if self.repeat_pad else "zero",
max_len=self.trim_length, random_start=True)
x_inp = Tensor(X_pad)
target = self.labels[utt_id]
return idx, x_inp, target
return idx, x_inp, target


class Dataset_for_dev(Dataset_base):
def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=False, is_train=True,
random_start=False
):
super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=False, is_train=True,
random_start=False
):
super(Dataset_for_dev, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)

if repeat_pad:
self.padding_type = "repeat"
else:
self.padding_type = "zero"

def __getitem__(self, index):
utt_id = self.list_IDs[index]
X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000)
X_pad = pad(X,self.padding_type,self.trim_length, random_start=self.random_start)
X_pad = pad(X, self.padding_type, self.trim_length,
random_start=self.random_start)
x_inp = Tensor(X_pad)
target = self.labels[utt_id]
return index, x_inp, target


class Dataset_for_eval(Dataset_base):
def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False
):
super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)
augmentation_methods=[], eval_augment=None, num_additional_real=2, num_additional_spoof=2,
trim_length=64000, wav_samp_rate=16000, noise_path=None, rir_path=None,
aug_dir=None, online_aug=False, repeat_pad=True, is_train=True, enable_chunking=False, random_start=False
):
super(Dataset_for_eval, self).__init__(args, list_IDs, labels, base_dir, algo, vocoders,
augmentation_methods, eval_augment, num_additional_real, num_additional_spoof,
trim_length, wav_samp_rate, noise_path, rir_path,
aug_dir, online_aug, repeat_pad, is_train, random_start)
self.enable_chunking = enable_chunking
if repeat_pad:
self.padding_type = "repeat"
else:
self.padding_type = "zero"

def __getitem__(self, idx):
utt_id = self.list_IDs[idx]
filepath = os.path.join(self.base_dir, utt_id)
X, _ = librosa.load(filepath, sr=16000)
# apply augmentation at inference time
if self.eval_augment is not None:
# print("eval_augment:", self.eval_augment)
X = globals()[self.eval_augment](X, self.args, self.sample_rate, audio_path = filepath)
X = globals()[self.eval_augment](
X, self.args, self.sample_rate, audio_path=filepath)
if not self.enable_chunking:
X= pad(X,padding_type=self.padding_type,max_len=self.trim_length, random_start=self.random_start)
X = pad(X, padding_type=self.padding_type,
max_len=self.trim_length, random_start=self.random_start)
x_inp = Tensor(X)
return x_inp, utt_id


class NormalDataModule(LightningDataModule):
"""`LightningDataModule` for the ASVSpoof dataset.
Expand Down Expand Up @@ -156,10 +164,6 @@ def __init__(
num_workers: int = 0,
pin_memory: bool = False,
args: Optional[Dict[str, Any]] = None,
# list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [],
# augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2,
# trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None,
# aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False
) -> None:
"""Initialize a `ASVSpoofDataModule`.
Expand Down Expand Up @@ -191,7 +195,7 @@ def num_classes(self) -> int:
"""
return 2

def prepare_data(self) -> None:
def prepare_data(self) -> None:
pass

def setup(self, stage: Optional[str] = None) -> None:
Expand All @@ -214,27 +218,28 @@ def setup(self, stage: Optional[str] = None) -> None:

# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:

# define train dataloader

d_label_trn, file_train = self.genList(dir_meta=os.path.join(
self.data_dir), is_train=True, is_eval=False, is_dev=False)
self.data_dir), is_train=True, is_eval=False, is_dev=False)

print('no. of training trials', len(file_train))

d_label_dev, file_dev = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=False, is_dev=True)
d_label_dev, file_dev = self.genList(dir_meta=os.path.join(
self.data_dir), is_train=False, is_eval=False, is_dev=True)
print('no. of validation trials', len(file_dev))
file_eval = self.genList(dir_meta=os.path.join(self.data_dir), is_train=False, is_eval=True, is_dev=False)

file_eval = self.genList(dir_meta=os.path.join(
self.data_dir), is_train=False, is_eval=True, is_dev=False)

self.data_train = Dataset_for(self.args, list_IDs=file_train, labels=d_label_trn,
base_dir=self.data_dir+'/', **self.args['data'])
base_dir=self.data_dir+'/', **self.args['data'])

self.data_val = Dataset_for_dev(self.args, list_IDs=file_dev, labels=d_label_dev,
base_dir=self.data_dir+'/', is_train=False, **self.args['data'])
base_dir=self.data_dir+'/', is_train=False, **self.args['data'])

self.data_test = Dataset_for_eval(self.args, list_IDs=file_eval, labels=None,
base_dir=self.data_dir+'/', **self.args['data'])

base_dir=self.data_dir+'/', **self.args['data'])

def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
Expand Down Expand Up @@ -299,11 +304,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
:param state_dict: The datamodule state returned by `self.state_dict()`.
"""
pass

def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False):
# bonafide: 1, spoof: 0
d_meta = {}
file_list=[]
file_list = []
protocol = os.path.join(dir_meta, "protocol.txt")

if (is_train):
Expand All @@ -325,7 +330,7 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False):
file_list.append(utt)
d_meta[utt] = 1 if label == 'bonafide' else 0
return d_meta, file_list

if (is_eval):
# no eval protocol yet
with open(protocol, 'r') as f:
Expand All @@ -338,5 +343,6 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False):
# return d_meta, file_list
return d_meta, file_list


if __name__ == "__main__":
_ = NormalDataModule()
_ = NormalDataModule()
Loading

0 comments on commit 1ca86ab

Please sign in to comment.