Add env example, adaptive weighted view, improve training performance
hungdinhxuan committed Oct 9, 2024
1 parent fce3f22 commit 9265dd7
Showing 9 changed files with 161 additions and 17 deletions.
18 changes: 17 additions & 1 deletion .env.example
@@ -3,4 +3,20 @@
# .env is loaded by train.py automatically
# hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}

MY_VAR="/home/user/my/system/path"
ASVSPOOF_PATH=""
ASVSPOOF_PATH_PROTOCOLS=""
CYBERCUP2_PATH=""
CYBERCUP2_TEST_PATH=""
CYBERCUP1_PATH=""
CYBERCUP1_TEST_PATH=""
NEPTUNE_PROJECT=""
NEPTUNE_API_TOKEN=""
WANDB_PROJECT=""
COMET_API_TOKEN=""
COMET_PROJECT_NAME=""
LARGE_CORPUS_FOR_ASVSPOOF5=""
WAVLMBASE_PRETRAINED_MODEL_PATH=""
NOISE_PATH=""
RIR_PATH=""
OUTPUT_DIR=""
XLSR_PRETRAINED_MODEL_PATH=""
22 changes: 22 additions & 0 deletions configs/callbacks/default_loss.yaml
@@ -0,0 +1,22 @@
defaults:
  - model_checkpoint
  - early_stopping
  - rich_progress_bar
  - _self_

model_checkpoint:
  dirpath: ${paths.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  monitor: "val/loss"
  mode: "min"
  save_last: True
  auto_insert_metric_name: False
  save_top_k: 5 # save k best models (determined by above metric)

early_stopping:
  monitor: "val/loss"
  patience: 20
  mode: "min"

model_summary:
  max_depth: -1
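
Note: Hydra instantiates these entries as standard Lightning callbacks. A rough Python equivalent of what this config produces, assuming the `lightning.pytorch` import path (not shown in this diff):

    from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

    checkpoint = ModelCheckpoint(
        dirpath="outputs/checkpoints",  # stands in for ${paths.output_dir}/checkpoints
        filename="epoch_{epoch:03d}",
        monitor="val/loss",             # lower validation loss is better
        mode="min",
        save_last=True,
        save_top_k=5,                   # keep the 5 best checkpoints
        auto_insert_metric_name=False,
    )
    early_stop = EarlyStopping(monitor="val/loss", patience=20, mode="min")

One caveat: `model_summary` is configured at the bottom but not listed under `defaults`; the template's callback instantiation typically skips entries without a `_target_`, so `max_depth` may have no effect unless `- model_summary` is added to the defaults list.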
1 change: 1 addition & 0 deletions configs/data/asvspoof_multiview.yaml
@@ -6,6 +6,7 @@ pin_memory: True

args:
  # The sampling rate of the audio files
+ protocols_path: ${oc.env:ASVSPOOF_PATH_PROTOCOLS}
  sample_rate: 16000
  cut: 64000
  padding_type: zero
42 changes: 42 additions & 0 deletions configs/experiment/aasistssl_multiview.yaml
@@ -0,0 +1,42 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=aasistssl_multiview

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default_loss
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: true

data:
  batch_size: 14
  num_workers: 8
  pin_memory: true

logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
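
Note: in the lightning-hydra-template convention this repo appears to follow, `compile: true` gates `torch.compile` over the inner network at the start of fit. A minimal sketch of that pattern (the hook placement and flag name are assumed, since the module code is not shown here):

    import torch
    from lightning import LightningModule

    class LitSketch(LightningModule):
        def __init__(self, net: torch.nn.Module, compile: bool = False):
            super().__init__()
            self.net = net
            self.compile_net = compile

        def setup(self, stage: str) -> None:
            # Compile once, only for training; later steps reuse the compiled graph.
            if self.compile_net and stage == "fit":
                self.net = torch.compile(self.net)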
42 changes: 42 additions & 0 deletions configs/experiment/aasistssl_multiview_adaptive.yaml
@@ -0,0 +1,42 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=aasistssl_multiview_adaptive

defaults:
  - override /data: asvspoof_multiview
  - override /model: xlsr_aasist_multiview
  - override /callbacks: default
  - override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]

seed: 1234

trainer:
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda

model:
  optimizer:
    lr: 0.000001
    weight_decay: 0.0001
  net: null
  scheduler: null
  compile: false
  adaptive_weights: true

data:
  batch_size: 8
  num_workers: 8

logger:
  wandb:
    tags: ${tags}
    group: "asvspoof_multiview"
  aim:
    experiment: "asvspoof_multiview"
2 changes: 1 addition & 1 deletion configs/experiment/test_aasist_multiview.yaml
@@ -17,7 +17,7 @@ tags: ["asvspoof_multiview", "xlsr_aasist_multiview"]
seed: 1234

trainer:
-  min_epochs: 50
+  #min_epochs: 50
  max_epochs: 100
  gradient_clip_val: 0.0
  accelerator: cuda
1 change: 1 addition & 0 deletions src/data/asvspoof_aasistssl_reproduce_datamodule.py
@@ -232,6 +232,7 @@ def train_dataloader(self) -> DataLoader[Any]:
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
+           drop_last=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
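
Note: `drop_last=True` discards the last incomplete batch of each epoch, so every optimizer step sees a full batch; this keeps batch statistics and step timing uniform. A standalone illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(100, 8))
    loader = DataLoader(ds, batch_size=16, shuffle=True, drop_last=True)
    print(len(loader))  # 6 -- the trailing batch of 4 samples is dropped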
6 changes: 4 additions & 2 deletions src/data/asvspoof_multiview_datamodule.py
@@ -136,6 +136,7 @@ def __init__(
        self.batch_size_per_device = batch_size
        self.data_dir = data_dir
        self.args = args
+       self.protocols_path = self.args.get('protocols_path', '/data/hungdx/Datasets/protocols/database/') if self.args is not None else '/data/hungdx/Datasets/protocols/database/'

    @property
    def num_classes(self) -> int:

@@ -175,7 +176,6 @@ def setup(self, stage: Optional[str] = None) -> None:
        track = 'DF'

        prefix_2021 = 'ASVspoof2021.{}'.format(track)
-       self.protocols_path = '/data/hungdx/Datasets/protocols/database/'
        self.algo = self.args.get('algo', -1) if self.args is not None else -1

        d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False)
@@ -199,7 +199,9 @@ def train_dataloader(self) -> DataLoader[Any]:
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=lambda x: multi_view_collate_fn(x, self.args.views, self.args.sample_rate, self.args.padding_type, self.args.random_start),
-       )
+           drop_last=True,
+           persistent_workers=True
+       )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.
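
Note: `persistent_workers=True` keeps DataLoader worker processes alive between epochs instead of respawning them, removing a per-epoch startup stall; it requires `num_workers > 0`. Together with `drop_last=True`, this is the "improve training performance" part of the commit. Illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.randn(1024, 8))
    loader = DataLoader(
        ds,
        batch_size=32,
        num_workers=8,            # must be > 0 for persistent_workers
        persistent_workers=True,  # workers survive across epochs
        pin_memory=True,
    )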
44 changes: 31 additions & 13 deletions src/models/aasistssl_multiview_module.py
@@ -66,8 +66,8 @@ def __init__(
        cross_entropy_weight: list[float] = [0.5, 0.5],
        score_save_path: str = None,
        #views: list[float] = [1, 2, 3, 4],
-       weighted_views: Dict[str, float] = {'1': 1, '2': 1, '3': 1, '4': 1},
-       adaptive_weight: bool = False,
+       weighted_views: Dict[str, float] = {'1': 1.0, '2': 1.0, '3': 1.0, '4': 1.0},
+       adaptive_weights: bool = False,
    ) -> None:
        """Initialize a `MNISTLitModule`.

@@ -122,9 +122,17 @@ def __init__(
        self.running_loss = 0.0

        self.weighted_views = weighted_views
-       self.adaptive_weight = adaptive_weight
+       self.adaptive_weights = adaptive_weights
        print("We are in the AASISTSSLLitModule")

+       if self.adaptive_weights:
+           self.weighted_views = {}
+           for k, v in weighted_views.items():
+               param = torch.nn.Parameter(torch.tensor(float(v)), requires_grad=True)
+               self.register_parameter(f"adaptive_weight_{k}", param)
+               self.weighted_views[k] = param
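
Note: registering each view weight as an `nn.Parameter` makes the weights trainable: they appear in `parameters()`, are picked up by the optimizer, and receive gradients whenever they multiply a loss term. A self-contained sketch of the mechanism; the per-view loss combination below is illustrative, since the module's `model_step` is not part of this diff:

    import torch
    import torch.nn as nn

    class AdaptiveViewWeights(nn.Module):
        def __init__(self, views=('1', '2', '3', '4')):
            super().__init__()
            self.weighted_views = {}
            for k in views:
                param = nn.Parameter(torch.tensor(1.0))
                self.register_parameter(f"adaptive_weight_{k}", param)
                self.weighted_views[k] = param

        def combine(self, view_losses):
            # Weighted sum of per-view losses; gradients flow into the weights too.
            return sum(self.weighted_views[k] * v for k, v in view_losses.items())

    m = AdaptiveViewWeights()
    losses = {k: torch.ones(()) * i for i, k in enumerate(m.weighted_views, 1)}
    m.combine(losses).backward()  # each adaptive_weight_k now has a .grad

One caveat: unconstrained learnable loss weights can drift toward zero (shrinking every weighted term shrinks the loss), so they are usually normalized or regularized; the per-epoch logging added below at least makes such drift visible.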



    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

@@ -153,6 +161,10 @@ def on_train_start(self) -> None:
        for k, v in self.val_view_acc_best.items():
            self.val_view_acc_best[k].reset()

+       # Log current adaptive_weights
+       if self.adaptive_weights:
+           print("Adaptive weights are enabled")
    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

@@ -258,6 +270,7 @@ def training_step(
self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size)
self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size)

self.log_dict( {f"adaptive_weight_{k}": v for k, v in self.weighted_views.items()}, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
# Return loss for backpropagation
return loss

@@ -270,16 +283,19 @@ def on_train_epoch_end(self) -> None:
self.log(f"train/view_{k}_loss", self.train_loss_detail[k].compute(), prog_bar=True, sync_dist=True)

# Check if adaptive weight is enabled
if self.adaptive_weight:
# Get the accuracy for each view
view_acc = {k: v.compute() for k, v in self.train_view_acc.items()}
# Adjust the weights based on the accuracy
# weighted_views is a dictionary of the views and their weights
# adjust_weights is a function that adjusts the weights based on the accuracy of the views and returns a list of normalized weights
self.weighted_views = {k: v for k, v in zip(view_acc.keys(), adjust_weights(list(view_acc.values())))}



# if self.adaptive_weights:
# # Get the accuracy for each view
# view_acc = {k: v.compute() for k, v in self.train_view_acc.items()}
# # Adjust the weights based on the accuracy
# # weighted_views is a dictionary of the views and their weights
# # adjust_weights is a function that adjusts the weights based on the accuracy of the views and returns a list of normalized weights
# self.weighted_views = {k: v for k, v in zip(view_acc.keys(), adjust_weights(list(view_acc.values())))}

# Log current adaptive_weights
if self.adaptive_weights:
print(self.weighted_views)
self.log_dict( {f"adaptive_weight_{k}": v for k, v in self.weighted_views.items()}, on_epoch=True, prog_bar=True, sync_dist=True)

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

@@ -397,6 +413,8 @@ def configure_optimizers(self) -> Dict[str, Any]:
        }
        return {"optimizer": optimizer}

+   def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
+       optimizer.zero_grad(set_to_none=True)
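
Note: overriding Lightning's `optimizer_zero_grad` hook to call `zero_grad(set_to_none=True)` releases gradient tensors instead of zero-filling them, saving a memset and a little memory per step (this has been the PyTorch default since 2.0, so the override mainly documents intent). Quick illustration:

    import torch

    layer = torch.nn.Linear(4, 4)
    layer(torch.randn(2, 4)).sum().backward()
    print(layer.weight.grad is None)   # False: grads were allocated
    layer.zero_grad(set_to_none=True)  # grads are freed, not zero-filled
    print(layer.weight.grad is None)   # True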

if __name__ == "__main__":
    _ = AASISTSSLLitModule(None, None, None, None)
