Commit 8194db6: Add multiview-conf

hungdinhxuan committed Oct 21, 2024
1 parent 293f057 commit 8194db6
Showing 20 changed files with 995 additions and 49 deletions.
3 changes: 2 additions & 1 deletion configs/callbacks/default_loss.yaml
@@ -1,6 +1,6 @@
 defaults:
   - model_checkpoint
-  - early_stopping
+  #- early_stopping
   - rich_progress_bar
   - _self_

@@ -12,6 +12,7 @@ model_checkpoint:
   save_last: True
   auto_insert_metric_name: False
   save_top_k: 5 # save k best models (determined by above metric)
+  #save_weights_only: True

 early_stopping:
   monitor: "val/loss"
2 changes: 1 addition & 1 deletion configs/experiment/aasistssl_multiview_adaptive.yaml
@@ -6,7 +6,7 @@
 defaults:
   - override /data: asvspoof_multiview
   - override /model: xlsr_aasist_multiview
-  - override /callbacks: default
+  - override /callbacks: default_loss
   - override /trainer: default

 # all parameters below will be merged with parameters from default configurations set above
46 changes: 46 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-2.yaml
@@ -0,0 +1,46 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
51 changes: 51 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-3-2.yaml
@@ -0,0 +1,51 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true
  weighted_views:
    '1': 0.4
    '2': 0.3
    '3': 0.2
    '4': 0.1

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
51 changes: 51 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-3.yaml
@@ -0,0 +1,51 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true
  weighted_views:
    '1': 0.1
    '2': 0.2
    '3': 0.3
    '4': 0.4

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
47 changes: 47 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-4.yaml
@@ -0,0 +1,47 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true
  adaptive_weights: true

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
53 changes: 53 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-5.yaml
@@ -0,0 +1,53 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true
  weighted_views:
    '1': 1
    '2': 1
    '3': 1
    '4': 1
    '5': 1

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False
    views: [1, 2, 3, 4, 5]


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
54 changes: 54 additions & 0 deletions configs/experiment/aasistssl_multiview_conf-6.yaml
@@ -0,0 +1,54 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true
  weighted_views:
    '0.5': 1
    '1': 1
    '2': 1
    '3': 1
    '4': 1


data:
  batch_size: 14
  num_workers: 8
  pin_memory: true
  args:
    padding_type: repeat
    random_start: False
    views: [0.5, 1, 2, 3, 4]


logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
9 changes: 1 addition & 8 deletions configs/model/xlsr_aasist_multiview.yaml
@@ -8,14 +8,7 @@ optimizer:

 scheduler: null

-net:
-  _target_: src.models.components.aasist.AASIST
-  d_args:
-    first_conv: 128
-    filts: [70, [1, 32], [32, 32], [32, 64], [64, 64]]
-    gat_dims: [64, 32]
-    pool_ratios: [0.5, 0.7, 0.5, 0.5]
-    temperatures: [2.0, 2.0, 100.0, 100.0]
+net: null

 ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH}
 cross_entropy_weight: [0.1, 0.9] # weight for cross entropy loss 0.1 for spoof and 0.9 for bonafide
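
Note: with `net: null` in the model default and in every new experiment config, the AASIST backbone is presumably now constructed inside the model code rather than injected by Hydra. A hedged sketch of that pattern, reusing the `d_args` removed above (the fallback logic is an assumption, not this repo's code):

from hydra.utils import instantiate

DEFAULT_D_ARGS = {  # values from the removed config block above
    "first_conv": 128,
    "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
    "gat_dims": [64, 32],
    "pool_ratios": [0.5, 0.7, 0.5, 0.5],
    "temperatures": [2.0, 2.0, 100.0, 100.0],
}

def build_net(net_cfg):
    if net_cfg is None:
        # Assumed fallback: build the default AASIST backbone in code.
        from src.models.components.aasist import AASIST
        return AASIST(d_args=DEFAULT_D_ARGS)
    return instantiate(net_cfg)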
143 changes: 143 additions & 0 deletions notebooks/visualize.ipynb

(Large diff not rendered.)

25 changes: 25 additions & 0 deletions src/callbacks/finetuning_callback.py
@@ -0,0 +1,25 @@
from lightning.pytorch.callbacks import BaseFinetuning

class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    def __init__(self, unfreeze_at_epoch=10, module_names=["feature_extractor"]):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch
        self._module_names = module_names

    def freeze_before_training(self, pl_module):
        # Freeze all modules in module_names
        for module_name in self._module_names:
            self.freeze(getattr(pl_module, module_name))

    def finetune_function(self, pl_module, current_epoch, optimizer):
        # Once current_epoch reaches unfreeze_at_epoch, the frozen modules start training.
        if current_epoch == self._unfreeze_at_epoch:
            # Unfreeze all modules in module_names
            for module_name in self._module_names:
                self.unfreeze_and_add_param_group(
                    modules=getattr(pl_module, module_name),
                    optimizer=optimizer,
                    train_bn=True,
                )
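
Note: a usage sketch for the new callback, assuming the LightningModule exposes a `feature_extractor` attribute; the trainer arguments are illustrative:

from lightning.pytorch import Trainer

# Train the head for 10 epochs with a frozen feature extractor, then unfreeze it.
callback = FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10,
                                          module_names=["feature_extractor"])
trainer = Trainer(max_epochs=100, callbacks=[callback])
# trainer.fit(model, datamodule=datamodule)  # model must define `feature_extractor`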
3 changes: 3 additions & 0 deletions src/data/asvspoof_multiview_datamodule.py
@@ -52,6 +52,9 @@ def __init__(self, args, list_IDs, base_dir):
         self.cut = args.get('cut', 64600) if args is not None else 64600
         self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero'
         self.random_start = args.get('random_start', False) if args is not None else False
+        print('padding_type:',self.padding_type)
+        print('cut:',self.cut)
+        print('random_start:',self.random_start)

     def __len__(self):
         return len(self.list_IDs)
2 changes: 2 additions & 0 deletions src/data/components/dataio.py
@@ -56,6 +56,8 @@ def pad(x:np.ndarray, padding_type:str='zero', max_len=64000, random_start=False
     max_len: max length of the audio, default 64000
     random_start: if True, randomly choose the start point of the audio
     '''
+    # Ensure that max_len is an integer
+    max_len = int(max_len)
     x_len = x.shape[0]
     padded_x = None
     if max_len == 0:
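
Note: the `pad` diff is truncated here. For context, a hedged sketch of what `repeat` padding with `random_start` typically looks like; this is an illustration, not the repo's exact implementation:

import numpy as np

def pad(x: np.ndarray, padding_type: str = 'zero', max_len: int = 64000,
        random_start: bool = False) -> np.ndarray:
    # Pad or trim a 1-D waveform to exactly max_len samples.
    max_len = int(max_len)
    x_len = x.shape[0]
    if x_len >= max_len:
        # Trim, optionally from a random offset instead of the start.
        start = np.random.randint(0, x_len - max_len + 1) if random_start else 0
        return x[start:start + max_len]
    if padding_type == 'repeat':
        # Tile the waveform until it covers max_len, then cut.
        num_repeats = int(np.ceil(max_len / x_len))
        return np.tile(x, num_repeats)[:max_len]
    # Default: zero-pad at the end.
    return np.concatenate([x, np.zeros(max_len - x_len, dtype=x.dtype)])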
