generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7b96ffa
commit b585d10
Showing
10 changed files
with
1,051 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# @package _global_ | ||
|
||
# to execute this experiment run: | ||
# python train.py experiment=example | ||
|
||
defaults: | ||
- override /data: asvspoof_multiview | ||
- override /model: xlsr_aasist_multiview | ||
- override /callbacks: default_loss | ||
- override /trainer: default | ||
|
||
# all parameters below will be merged with parameters from default configurations set above | ||
# this allows you to overwrite only specified parameters | ||
|
||
tags: ["asvspoof_multiview", "xlsr_aasist_multiview"] | ||
|
||
seed: 1234 | ||
|
||
trainer: | ||
max_epochs: 100 | ||
gradient_clip_val: 0.0 | ||
accelerator: cuda | ||
|
||
model: | ||
optimizer: | ||
lr: 0.000001 | ||
weight_decay: 0.0001 | ||
net: null | ||
scheduler: null | ||
compile: true | ||
|
||
data: | ||
batch_size: 14 | ||
num_workers: 8 | ||
pin_memory: true | ||
args: | ||
padding_type: repeat | ||
random_start: False | ||
algo: 5 # Optimised for 5th algorithm LA | ||
|
||
|
||
logger: | ||
wandb: | ||
tags: ${tags} | ||
group: "asvspoof_multiview" | ||
aim: | ||
experiment: "asvspoof_multiview" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
_target_: src.models.xlsr_conformer_reproduce_module.XLSRConformerTCMLitModule | ||
|
||
optimizer: | ||
_target_: torch.optim.Adam | ||
_partial_: true | ||
lr: 0.000001 | ||
weight_decay: 0.0001 | ||
|
||
scheduler: null | ||
|
||
args: | ||
conformer: | ||
emb_size: 144 | ||
heads: 4 | ||
kernel_size: 31 | ||
n_encoders: 4 | ||
|
||
ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH} | ||
cross_entropy_weight: [0.1, 0.9] # weight for cross entropy loss 0.1 for spoof and 0.9 for bonafide | ||
# compile model for faster training with pytorch 2.0 | ||
compile: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Processing audio files: 0%| | 0/31779 [00:00<?, ?it/s]/tmp/ipykernel_3393222/3850030267.py:28: FutureWarning: get_duration() keyword argument 'filename' has been renamed to 'path' in version 0.10.0.\n", | ||
"\tThis alias will be removed in version 1.0.\n", | ||
" duration = librosa.get_duration(filename=audio_path)\n", | ||
"Processing audio files: 100%|██████████| 31779/31779 [02:46<00:00, 190.63it/s]" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Results written to in_the_wild_durations.csv\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import os\n", | ||
"import librosa\n", | ||
"import csv\n", | ||
"from tqdm import tqdm\n", | ||
"\n", | ||
"def get_audio_durations(audio_folder, output_csv):\n", | ||
" \"\"\"\n", | ||
" Calculate duration of each audio file in the folder and write results to CSV\n", | ||
" \n", | ||
" Parameters:\n", | ||
" audio_folder (str): Path to folder containing audio files\n", | ||
" output_csv (str): Path to output CSV file\n", | ||
" \"\"\"\n", | ||
" \n", | ||
" # Get list of audio files\n", | ||
" audio_files = [f for f in os.listdir(audio_folder) if f.endswith(('.wav', '.mp3', '.flac'))]\n", | ||
" \n", | ||
" # Open CSV file to write results\n", | ||
" with open(output_csv, 'w', newline='') as csvfile:\n", | ||
" writer = csv.writer(csvfile)\n", | ||
" writer.writerow(['filename', 'duration']) # Write header\n", | ||
" \n", | ||
" # Process each audio file with progress bar\n", | ||
" for audio_file in tqdm(audio_files, desc=\"Processing audio files\"):\n", | ||
" try:\n", | ||
" # Load audio file and get duration\n", | ||
" audio_path = os.path.join(audio_folder, audio_file)\n", | ||
" duration = librosa.get_duration(filename=audio_path)\n", | ||
" \n", | ||
" # Write result to CSV\n", | ||
" writer.writerow([audio_file, f\"{duration:.2f}\"])\n", | ||
" \n", | ||
" except Exception as e:\n", | ||
" print(f\"Error processing {audio_file}: {str(e)}\")\n", | ||
"\n", | ||
"if __name__ == \"__main__\":\n", | ||
" # Example usage\n", | ||
" audio_folder = \"/data/hungdx/Lightning-hydra/data/in_the_wild\"\n", | ||
" output_csv = \"in_the_wild_durations.csv\"\n", | ||
" \n", | ||
" get_audio_durations(audio_folder, output_csv)\n", | ||
" print(f\"Results written to {output_csv}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Average duration: 4.287989552849366\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"\n", | ||
"df = pd.read_csv('in_the_wild_durations.csv')\n", | ||
"\n", | ||
"print(\"Average duration: \", df['duration'].mean())\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "asvspoof5", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.19" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"eer: 9.326796296235376\tthreshold: -2.986328125\n", | ||
"\n", | ||
"0.09326796296235375\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import sys\n", | ||
"import os.path\n", | ||
"import numpy as np\n", | ||
"import pandas\n", | ||
"import eval_metrics_DF as em\n", | ||
"\n", | ||
"\n", | ||
"def eval_to_score_file(score_file, cm_key_file):\n", | ||
" # CM key file is the metadata file that contains the ground truth labels for the eval set\n", | ||
" # score file is the output of the system that contains the scores for the eval set\n", | ||
" # phase is the phase of the eval set (dev or eval)\n", | ||
"\n", | ||
" cm_data = pandas.read_csv(cm_key_file, sep=' ', header=None)\n", | ||
" submission_scores = pandas.read_csv(\n", | ||
" score_file, sep=' ', header=None, skipinitialspace=True)\n", | ||
" # check here for progress vs eval set\n", | ||
" cm_scores = submission_scores.merge(cm_data, left_on=0, right_on=0, how='inner')\n", | ||
" # cm_scores.head()\n", | ||
" # 0 1_x 1_y 2 3\n", | ||
" # a.wav 1.234 eval Music spoof\n", | ||
" bona_cm = cm_scores[cm_scores['1_y'] == 'bonafide']['1_x'].values\n", | ||
" spoof_cm = cm_scores[cm_scores['1_y'] == 'spoof']['1_x'].values\n", | ||
"\n", | ||
" eer_cm, th = em.compute_eer(bona_cm, spoof_cm)\n", | ||
" out_data = \"eer: {}\\tthreshold: {}\\n\".format(100*eer_cm, th)\n", | ||
" print(out_data)\n", | ||
" return eer_cm\n", | ||
"\n", | ||
"print(eval_to_score_file(\"/data/hungdx/Lightning-hydra/logs/eval/itw_xlsr_aasist_multiview_conf-2_epoch15_3s.txt\", \"/dataa/Datasets/in_the_wild.txt\"))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.