diff --git a/configs/callbacks/default_loss_w_only.yaml b/configs/callbacks/default_loss_w_only.yaml new file mode 100644 index 0000000..45a6dfc --- /dev/null +++ b/configs/callbacks/default_loss_w_only.yaml @@ -0,0 +1,23 @@ +defaults: + - model_checkpoint + #- early_stopping + - rich_progress_bar + - _self_ + +model_checkpoint: + dirpath: ${paths.output_dir}/checkpoints + filename: "epoch_{epoch:03d}" + monitor: "val/loss" + mode: "min" + save_last: True + auto_insert_metric_name: False + save_top_k: 5 # save k best models (determined by above metric) + save_weights_only: True + +early_stopping: + monitor: "val/loss" + patience: 20 + mode: "min" + +model_summary: + max_depth: -1 diff --git a/configs/experiment/aasistssl_multiview_conf-2-var.yaml b/configs/experiment/aasistssl_multiview_conf-2-var.yaml new file mode 100644 index 0000000..10686a6 --- /dev/null +++ b/configs/experiment/aasistssl_multiview_conf-2-var.yaml @@ -0,0 +1,50 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: asvspoof_multiview + - override /model: xlsr_aasist_var_multiview + - override /callbacks: default_loss + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["asvspoof_multiview", "xlsr_aasist_var_multiview"] + +seed: 1234 + +trainer: + max_epochs: 100 + gradient_clip_val: 0.0 + accelerator: cuda + +model: + optimizer: + lr: 0.000001 + weight_decay: 0.0001 + net: null + scheduler: null + compile: true + +data: + batch_size: 14 + num_workers: 8 + pin_memory: true + args: + padding_type: repeat + random_start: False + is_variable_multi_view: True + top_k: 4 + min_duration: 16000 + max_duration: 64600 + + +logger: + wandb: + tags: ${tags} + group: "asvspoof_multiview" + aim: + experiment: "asvspoof_multiview" diff --git a/configs/model/xlsr_aasist_var_multiview.yaml b/configs/model/xlsr_aasist_var_multiview.yaml new file mode 100644 index 0000000..513b0d3 --- /dev/null +++ b/configs/model/xlsr_aasist_var_multiview.yaml @@ -0,0 +1,16 @@ +_target_: src.models.aasistssl_var_multiview_module.AASISTSSLLitModule + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.0001 + weight_decay: 0.0001 + +scheduler: null + +net: null + +ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH} +cross_entropy_weight: [0.1, 0.9] # weight for cross entropy loss 0.1 for spoof and 0.9 for bonafide +# compile model for faster training with pytorch 2.0 +compile: true diff --git a/notebooks/eval.ipynb b/notebooks/eval.ipynb index c60356e..19a2fdd 100644 --- a/notebooks/eval.ipynb +++ b/notebooks/eval.ipynb @@ -9,16 +9,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "eer: 9.326796296235376\tthreshold: -2.986328125\n", + "eer: 8.955265192535776\tthreshold: -2.849609375\n", "\n", - "0.09326796296235375\n" + "0.08955265192535777\n" ] } ], @@ -51,7 +51,7 @@ " print(out_data)\n", " return eer_cm\n", "\n", - "print(eval_to_score_file(\"/data/hungdx/Lightning-hydra/logs/eval/itw_xlsr_aasist_multiview_conf-2_epoch15_3s.txt\", \"/dataa/Datasets/in_the_wild.txt\"))" + "print(eval_to_score_file(\"/data/hungdx/Lightning-hydra/logs/eval/in_the_wild_xlsr_aasist_multiview_conf-2_4s_var.txt\", \"/dataa/Datasets/in_the_wild.txt\"))" ] } ], diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb index 30707fd..54bf5bd 100644 --- a/notebooks/test.ipynb +++ b/notebooks/test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 10, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -14,25 +14,15 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/hungdx/miniconda3/envs/asvspoof5/lib/python3.9/site-packages/lazy_loader/__init__.py:202: RuntimeWarning: subpackages can technically be lazily loaded, but it causes the package to be eagerly loaded even if it is already lazily loaded.So, you probably shouldn't use subpackages with this lazy feature.\n", - " warnings.warn(msg, RuntimeWarning)\n", - "/home/hungdx/miniconda3/envs/asvspoof5/lib/python3.9/site-packages/lazy_loader/__init__.py:202: RuntimeWarning: subpackages can technically be lazily loaded, but it causes the package to be eagerly loaded even if it is already lazily loaded.So, you probably shouldn't use subpackages with this lazy feature.\n", - " warnings.warn(msg, RuntimeWarning)\n" - ] - } - ], + "outputs": [], "source": [ "from IPython.display import Audio, display\n", "import librosa.display\n", "import matplotlib.pyplot as plt\n", "import random\n", + "import numpy as np\n", "\n", "\n", "def play_and_show(file_path):\n", @@ -388,9 +378,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "611829\n" + ] + } + ], "source": [ "import pandas as pd\n", "import os\n", @@ -434,6 +432,403 @@ "\n", "fusion_score(list_score_files).to_csv(\"best_fusion_xlsr_aasist_multiview_conf-2.txt\", sep=' ', index=False, header=False)\n" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Views: [3.2545 2.62675 3.710375 3.814625]\n", + "Batch 1\n", + " View Duration: 3.2545s - Batch Size: 10, Sequence Length: 52072\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 2.62675s - Batch Size: 10, Sequence Length: 42028\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 3.710375s - Batch Size: 10, Sequence Length: 59366\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 3.814625s - Batch Size: 10, Sequence Length: 61034\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Views: [2.2813125 2.150125 1.235875 2.874375 ]\n", + "Batch 2\n", + " View Duration: 2.2813125s - Batch Size: 10, Sequence Length: 36501\n", + " Label: 1\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 2.150125s - Batch Size: 10, Sequence Length: 34402\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 1.235875s - Batch Size: 10, Sequence Length: 19774\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " View Duration: 2.874375s - Batch Size: 10, Sequence Length: 45990\n", + " Label: 0\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "import numpy as np\n", + "import librosa\n", + "import os\n", + "import random\n", + "\n", + "def genSpoof_list( dir_meta, is_train=False, is_eval=False):\n", + " \"\"\"\n", + " This function is from the following source: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17\n", + " Official source: https://arxiv.org/abs/2202.12233\n", + " Automatic speaker verification spoofing and deepfake detection using wav2vec 2.0 and data augmentation\n", + " \"\"\"\n", + " d_meta = {}\n", + " file_list=[]\n", + " with open(dir_meta, 'r') as f:\n", + " l_meta = f.readlines()\n", + "\n", + " if (is_train):\n", + " for line in l_meta:\n", + " _,key,_,_,label = line.strip().split()\n", + " \n", + " file_list.append(key)\n", + " d_meta[key] = 1 if label == 'bonafide' else 0\n", + " \n", + " return d_meta,file_list\n", + " \n", + " elif(is_eval):\n", + " for line in l_meta:\n", + " key= line.strip()\n", + " file_list.append(key)\n", + " return file_list\n", + " else:\n", + " for line in l_meta:\n", + " _,key,_,_,label = line.strip().split()\n", + " \n", + " file_list.append(key)\n", + " d_meta[key] = 1 if label == 'bonafide' else 0\n", + " return d_meta,file_list\n", + " \n", + "class Dataset_ASVspoof2019_train(Dataset):\n", + " def __init__(self,list_IDs, labels, base_dir):\n", + " '''self.list_IDs\t: list of strings (each string: utt key),\n", + " self.labels : dictionary (key: utt key, value: label integer)'''\n", + " \n", + " self.list_IDs = list_IDs\n", + " self.labels = labels\n", + " self.base_dir = base_dir\n", + " \n", + " def __len__(self):\n", + " return len(self.list_IDs)\n", + "\n", + " def __getitem__(self, index): \n", + " utt_id = self.list_IDs[index]\n", + " X,fs = librosa.load(self.base_dir+utt_id+'.flac', sr=16000) \n", + " x_inp= torch.from_numpy(X)\n", + " target = self.labels[utt_id]\n", + "\n", + " return x_inp, target\n", + "# Create a synthetic dataset\n", + "class SyntheticAudioDataset(Dataset):\n", + " def __init__(self, num_samples=100, sample_rate=16000):\n", + " self.num_samples = num_samples\n", + " self.sample_rate = sample_rate\n", + " # Generate random lengths between 0.5 seconds and 4 seconds\n", + " self.lengths = np.random.randint(low=sample_rate // 2, high=sample_rate * 4, size=self.num_samples)\n", + " self.labels = np.random.randint(low=0, high=2, size=self.num_samples) # Binary classification\n", + "\n", + " def __len__(self):\n", + " return self.num_samples\n", + "\n", + " def __getitem__(self, idx):\n", + " # Generate random audio data based on the length\n", + " length = self.lengths[idx]\n", + " audio_data = torch.randn(length)\n", + " label = self.labels[idx]\n", + " return audio_data, label\n", + "\n", + "\n", + "# Custom collate function\n", + "def collate_fn(batch, views=[1, 2, 3, 4], sample_rate=16000):\n", + " view_batches = {view: [] for view in views}\n", + "\n", + " # Process each sample in the batch\n", + " for x, label in batch:\n", + " # Pad each sample for each view\n", + " for view in views:\n", + " view_length = view * sample_rate\n", + " x_view = pad(x, padding_type='zero', max_len=view_length, random_start=True)\n", + " # Check if x_view is Tensor or numpy array and convert to Tensor if necessary\n", + " if not torch.is_tensor(x_view):\n", + " x_view = torch.from_numpy(x_view)\n", + " view_batches[view].append((x_view, label))\n", + "\n", + " # Convert lists to tensors\n", + " for view in views:\n", + " sequences, labels = zip(*view_batches[view])\n", + " padded_sequences = torch.stack(sequences)\n", + " labels = torch.tensor(labels, dtype=torch.long)\n", + " view_batches[view] = (padded_sequences, labels)\n", + "\n", + " return view_batches\n", + "\n", + "def variable_multi_view_collate_fn(batch, top_k=4, min_duration=16000, max_duration=64000, sample_rate=16000, padding_type='zero', random_start=True):\n", + " '''\n", + " Collate function to pad each sample in a batch to multiple views with variable duration\n", + " :param batch: list of tuples (x, label)\n", + " :param top_k: number of views to pad each sample to\n", + " :param min_duration: minimum duration of the audio\n", + " :param max_duration: maximum duration of the audio\n", + " :param sample_rate: sample rate of the audio\n", + " :param padding_type: padding type to use\n", + " :param random_start: whether to randomly start the sample\n", + " :return: dictionary with keys as views and values as tuples of padded sequences and labels\n", + "\n", + " Example:\n", + " batch = [([1, 2, 3], 0), ([1, 2, 3, 4], 1)]\n", + " variable_multi_view_collate_fn(batch, top_k=2, min_duration=16000, max_duration=32000, sample_rate=16000)\n", + " Output:\n", + " {\n", + " 1: (tensor([[1, 2, 3], [1, 2, 3, 4]]), tensor([0, 1])),\n", + " 2: (tensor([[1, 2, 3, 0], [1, 2, 3, 4]]), tensor([0, 1]))\n", + " }\n", + " '''\n", + " # Duration of each view should be picked from a range of min_duration to max_duration by a uniform distribution\n", + " views = np.random.randint(min_duration, max_duration, (top_k,)) / sample_rate\n", + "\n", + " print(f\"Views: {views}\")\n", + "\n", + " view_batches = {view: [] for view in views}\n", + " # Process each sample in the batch\n", + " for x, label in batch:\n", + " # Pad each sample for each view\n", + " for view in views:\n", + " view_length = view * sample_rate\n", + " x_view = pad(x, padding_type=padding_type, max_len=view_length, random_start=random_start)\n", + " # Check if x_view is Tensor or numpy array and convert to Tensor if necessary\n", + " if not torch.is_tensor(x_view):\n", + " x_view = torch.from_numpy(x_view)\n", + " view_batches[view].append((x_view, label))\n", + "\n", + " # Convert lists to tensors\n", + " for view in views:\n", + " sequences, labels = zip(*view_batches[view])\n", + " padded_sequences = torch.stack(sequences)\n", + " labels = torch.tensor(labels, dtype=torch.long)\n", + " view_batches[view] = (padded_sequences, labels)\n", + "\n", + " return view_batches\n", + "# Create the dataset and dataloader\n", + "protocols_path = \"/data/hungdx/Datasets/protocols/database/\"\n", + "database_path = \"/home/hungdx/code/Lightning-hydra/data/Datasets/\"\n", + "d_label_trn,file_train = genSpoof_list( dir_meta = os.path.join(protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False)\n", + "\n", + "\n", + "d_label_trn = {k: d_label_trn[k] for k in list(d_label_trn)}\n", + "file_train = file_train\n", + "\n", + "\n", + "data_train = Dataset_ASVspoof2019_train(list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(database_path+'ASVspoof2019_LA_train/'))\n", + "#dataset = SyntheticAudioDataset()\n", + "dataloader = DataLoader(data_train, batch_size=10, collate_fn=lambda x: variable_multi_view_collate_fn(x), shuffle=True)\n", + "\n", + "# Iterate through the DataLoader\n", + "for batch_idx, batch in enumerate(dataloader):\n", + " print(f\"Batch {batch_idx + 1}\")\n", + " for view, (data, labels) in batch.items():\n", + " print(f\" View Duration: {view}s - Batch Size: {data.size(0)}, Sequence Length: {data.size(1)}\")\n", + " # Play a random audio file from the batch\n", + " random_idx = random.randint(0, data.size(0) - 1)\n", + " # Show the label\n", + " print(f\" Label: {labels[random_idx]}\")\n", + " display(Audio(data[random_idx].numpy(), rate=16000)) \n", + " if batch_idx == 1: # Print only for the 2 batch for brevity\n", + " break\n", + "\n" + ] } ], "metadata": { diff --git a/src/data/asvspoof_multiview_datamodule.py b/src/data/asvspoof_multiview_datamodule.py index c6f522a..3bcfde8 100644 --- a/src/data/asvspoof_multiview_datamodule.py +++ b/src/data/asvspoof_multiview_datamodule.py @@ -11,7 +11,7 @@ import copy from src.data.components.RawBoost import process_Rawboost_feature from src.data.components.dataio import load_audio, pad -from src.data.components.collate_fn import multi_view_collate_fn +from src.data.components.collate_fn import multi_view_collate_fn, variable_multi_view_collate_fn ''' Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. @@ -174,6 +174,29 @@ def __init__( self.batch_size_per_device = batch_size self.data_dir = data_dir self.args = args + self.is_variable_multi_view = args.get('is_variable_multi_view', False) if args is not None else False + if self.is_variable_multi_view: + print('Using variable multi-view collate function') + self.top_k = self.args.get('top_k', 4) + self.min_duration = self.args.get('min_duration', 16000) + self.max_duration = self.args.get('max_duration', 64000) + self.collate_fn = lambda x: variable_multi_view_collate_fn( + x, + self.top_k, + self.min_duration, + self.max_duration, + self.args.sample_rate, + self.args.padding_type, + self.args.random_start + ) + else: + self.collate_fn = lambda x: multi_view_collate_fn( + x, + self.args.views, + self.args.sample_rate, + self.args.padding_type, + self.args.random_start + ) self.protocols_path = self.args.get('protocols_path', '/data/hungdx/Datasets/protocols/database/') if self.args is not None else '/data/hungdx/Datasets/protocols/database/' @property @@ -243,10 +266,10 @@ def train_dataloader(self) -> DataLoader[Any]: num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=True, - collate_fn=lambda x: multi_view_collate_fn(x, self.args.views, self.args.sample_rate, self.args.padding_type, self.args.random_start), + collate_fn=self.collate_fn, drop_last=True, persistent_workers=True - ) + ) def val_dataloader(self) -> DataLoader[Any]: """Create and return the validation dataloader. @@ -259,7 +282,7 @@ def val_dataloader(self) -> DataLoader[Any]: num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, shuffle=False, - collate_fn=lambda x: multi_view_collate_fn(x, self.args.views, self.args.sample_rate, self.args.padding_type, self.args.random_start), + collate_fn=self.collate_fn, ) def test_dataloader(self) -> DataLoader[Any]: diff --git a/src/data/components/collate_fn.py b/src/data/components/collate_fn.py index 4987b4c..f8b19fb 100644 --- a/src/data/components/collate_fn.py +++ b/src/data/components/collate_fn.py @@ -1,7 +1,26 @@ import torch from src.data.components.dataio import pad +import numpy as np def multi_view_collate_fn(batch, views=[1, 2, 3, 4], sample_rate=16000, padding_type='zero', random_start=True): + ''' + Collate function to pad each sample in a batch to multiple views + :param batch: list of tuples (x, label) + :param views: list of views to pad each sample to + :param sample_rate: sample rate of the audio + :param padding_type: padding type to use + :param random_start: whether to randomly start the sample + :return: dictionary with keys as views and values as tuples of padded sequences and labels + + Example: + batch = [([1, 2, 3], 0), ([1, 2, 3, 4], 1)] + multi_view_collate_fn(batch, views=[1, 2], sample_rate=16000) + Output: + { + 1: (tensor([[1, 2, 3], [1, 2, 3, 4]]), tensor([0, 1])), + 2: (tensor([[1, 2, 3, 0], [1, 2, 3, 4]]), tensor([0, 1])) + } + ''' view_batches = {view: [] for view in views} # Process each sample in the batch @@ -22,4 +41,51 @@ def multi_view_collate_fn(batch, views=[1, 2, 3, 4], sample_rate=16000, padding_ labels = torch.tensor(labels, dtype=torch.long) view_batches[view] = (padded_sequences, labels) + return view_batches + +def variable_multi_view_collate_fn(batch, top_k=4, min_duration=16000, max_duration=64000, sample_rate=16000, padding_type='zero', random_start=True): + ''' + Collate function to pad each sample in a batch to multiple views with variable duration + :param batch: list of tuples (x, label) + :param top_k: number of views to pad each sample to + :param min_duration: minimum duration of the audio + :param max_duration: maximum duration of the audio + :param sample_rate: sample rate of the audio + :param padding_type: padding type to use + :param random_start: whether to randomly start the sample + :return: dictionary with keys as views and values as tuples of padded sequences and labels + + Example: + batch = [([1, 2, 3], 0), ([1, 2, 3, 4], 1)] + variable_multi_view_collate_fn(batch, top_k=2, min_duration=16000, max_duration=32000, sample_rate=16000) + Output: + { + 1: (tensor([[1, 2, 3], [1, 2, 3, 4]]), tensor([0, 1])), + 2: (tensor([[1, 2, 3, 0], [1, 2, 3, 4]]), tensor([0, 1])) + } + ''' + # Duration of each view should be picked from a range of min_duration to max_duration by a uniform distribution + # Duration in seconds for each view + durations = np.random.uniform(min_duration, max_duration, top_k).astype(int) + # Ensure unique durations to avoid key collisions + views = np.unique(durations) + view_batches = {view: [] for view in views} + # Process each sample in the batch + for x, label in batch: + # Pad each sample for each view + for view in views: + view_length = view + x_view = pad(x, padding_type=padding_type, max_len=view_length, random_start=random_start) + # Check if x_view is Tensor or numpy array and convert to Tensor if necessary + if not torch.is_tensor(x_view): + x_view = torch.from_numpy(x_view) + view_batches[view].append((x_view, label)) + + # Convert lists to tensors + for view in views: + sequences, labels = zip(*view_batches[view]) + padded_sequences = torch.stack(sequences) + labels = torch.tensor(labels, dtype=torch.long) + view_batches[view] = (padded_sequences, labels) + return view_batches \ No newline at end of file diff --git a/src/models/aasistssl_var_multiview_module.py b/src/models/aasistssl_var_multiview_module.py new file mode 100644 index 0000000..ee0fd16 --- /dev/null +++ b/src/models/aasistssl_var_multiview_module.py @@ -0,0 +1,299 @@ +from typing import Any, Dict, List, Tuple + +import torch +from lightning import LightningModule +from torchmetrics import MaxMetric, MeanMetric +from torchmetrics.classification.accuracy import BinaryAccuracy + +import torch +from src.models.components.xlsr_aasist import XlsrAasist + + + +class AASISTSSLLitModule(LightningModule): + """Example of a `LightningModule` for MNIST classification. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + ssl_pretrained_path: str, + cross_entropy_weight: list[float] = [0.5, 0.5], + score_save_path: str = None, + + ) -> None: + """Initialize a `MNISTLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + self.score_save_path = score_save_path + + self.net = XlsrAasist(ssl_pretrained_path) + + self.save_hyperparameters(ignore=['net']) + + # loss function + cross_entropy_weight = torch.tensor(cross_entropy_weight) + self.criterion = torch.nn.CrossEntropyLoss(cross_entropy_weight) + + # metric objects for calculating and averaging accuracy across batches + self.train_acc = BinaryAccuracy() + self.val_acc = BinaryAccuracy() + self.test_acc = BinaryAccuracy() + + self.test_view_acc = {} + + # for averaging loss across batches + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + self.view_loss = MeanMetric() + + # for tracking best so far validation accuracy + self.val_acc_best = MaxMetric() + # for tracking best so far validation accuracy for each view + + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the model `self.net`. + + :param x: A tensor of audio. + :return: A tensor of logits. + """ + return self.net(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() + self.val_acc_best.reset() + + + + + + def model_step( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + - A dictionary of loss details for each view. + - A dictionary of accuracy metrics for each view. + """ + train_loss = 0.0 + all_preds = [] + all_labels = [] + + for view, (x, y) in batch.items(): + ''' + In the case of multi-view, the batch is a dictionary with the view as the key and the value + is a tuple of the input tensor and target tensor. + + + view: int + x: sub-minibatch of input tensor + y: sub-minibatch of target tensor + + Example: + view: 1 + x: shape (batch_size, view * sample_rate) + y: shape (batch_size, 1) + ''' + self.batch_size = view * x.size(0) # Update batch size for each view + + view = str(view) # Convert view to string for indexing + + # Ensure the input tensor is of type float + x = x.float() + + logits = self.forward(x) + loss = self.criterion(logits, y) # Calculate loss + train_loss += loss + preds = torch.argmax(logits, dim=1) # Get predictions + # Collect predictions and labels + all_preds.append(preds) + all_labels.append(y) + + + + # Concatenate all predictions and labels + all_preds = torch.cat(all_preds) + all_labels = torch.cat(all_labels) + + return train_loss, all_preds, all_labels + + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + loss, preds, targets = self.model_step(batch) + + + # Update and log train loss and accuracy + self.train_loss(loss) + self.train_acc(preds, targets) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size) + + # Return loss for backpropagation + return loss + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + pass + + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets = self.model_step(batch) + + # Update and log val loss and accuracy + self.val_loss(loss) + self.val_acc(preds, targets) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=self.batch_size) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." + acc = self.val_acc.compute() # get current val acc + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True, batch_size=self.batch_size) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + if self.score_save_path is not None: + self._export_score_file(batch) + else: + raise ValueError("score_save_path is not provided") + + + def _export_score_file(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> None: + """Get the score file for the batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + """ + batch_x, utt_id = batch + batch_out = self.net(batch_x) + + fname_list = list(utt_id) + score_list = batch_out.data.cpu().numpy().tolist() + + # with open(self.score_save_path, 'a+') as fh: + # for f, cm in zip(fname_list, score_list): + # fh.write('{} {} {}\n'.format(f, cm[0], cm[1])) + with open(self.score_save_path, 'a+') as fh: + for f, cm in zip(fname_list, score_list): + fh.write('{} {}\n'.format(f, cm[1])) + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + pass + + def setup(self, stage: str) -> None: + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + # if self.hparams.compile and stage == "fit": + # self.net = torch.compile(self.net) + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + def optimizer_zero_grad(self, epoch, batch_idx, optimizer): + optimizer.zero_grad(set_to_none=True) + +if __name__ == "__main__": + _ = AASISTSSLLitModule(None, None, None, None)