From b585d10535c964480dcc3bef02c1bcfc888f2faa Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Wed, 30 Oct 2024 15:30:27 +0900 Subject: [PATCH] Add code for variable-length --- .gitignore | 6 +- .../aasistssl_multiview_conf-2.1.yaml | 47 +++ configs/model/xlsr_conformer_reproduce.yaml | 21 ++ notebooks/analyze.ipynb | 121 +++++++ notebooks/eval.ipynb | 79 +++++ notebooks/eval_metrics_DF.py | 318 ++++++++++++++++++ notebooks/test.ipynb | 144 ++++++++ src/data/asvspoof_multiview_datamodule.py | 88 ++++- src/data/normal_datamodule.py | 3 +- src/models/xlsr_conformer_reproduce_module.py | 244 ++++++++++++++ 10 files changed, 1051 insertions(+), 20 deletions(-) create mode 100644 configs/experiment/aasistssl_multiview_conf-2.1.yaml create mode 100644 configs/model/xlsr_conformer_reproduce.yaml create mode 100644 notebooks/analyze.ipynb create mode 100644 notebooks/eval.ipynb create mode 100644 notebooks/eval_metrics_DF.py create mode 100644 src/models/xlsr_conformer_reproduce_module.py diff --git a/.gitignore b/.gitignore index 29ec039..607232d 100644 --- a/.gitignore +++ b/.gitignore @@ -154,4 +154,8 @@ configs/local/default.yaml .aim # Neptune logging -.neptune \ No newline at end of file +.neptune + +# Notebooks +notebooks/**/*.txt +notebooks/**/*.csv \ No newline at end of file diff --git a/configs/experiment/aasistssl_multiview_conf-2.1.yaml b/configs/experiment/aasistssl_multiview_conf-2.1.yaml new file mode 100644 index 0000000..afa5e63 --- /dev/null +++ b/configs/experiment/aasistssl_multiview_conf-2.1.yaml @@ -0,0 +1,47 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: asvspoof_multiview + - override /model: xlsr_aasist_multiview + - override /callbacks: default_loss + - override /trainer: default + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["asvspoof_multiview", "xlsr_aasist_multiview"] + +seed: 1234 + +trainer: + max_epochs: 100 + gradient_clip_val: 0.0 + accelerator: cuda + +model: + optimizer: + lr: 0.000001 + weight_decay: 0.0001 + net: null + scheduler: null + compile: true + +data: + batch_size: 14 + num_workers: 8 + pin_memory: true + args: + padding_type: repeat + random_start: False + algo: 5 # Optimised for 5th algorithm LA + + +logger: + wandb: + tags: ${tags} + group: "asvspoof_multiview" + aim: + experiment: "asvspoof_multiview" diff --git a/configs/model/xlsr_conformer_reproduce.yaml b/configs/model/xlsr_conformer_reproduce.yaml new file mode 100644 index 0000000..eafa5cc --- /dev/null +++ b/configs/model/xlsr_conformer_reproduce.yaml @@ -0,0 +1,21 @@ +_target_: src.models.xlsr_conformer_reproduce_module.XLSRConformerTCMLitModule + +optimizer: + _target_: torch.optim.Adam + _partial_: true + lr: 0.000001 + weight_decay: 0.0001 + +scheduler: null + +args: + conformer: + emb_size: 144 + heads: 4 + kernel_size: 31 + n_encoders: 4 + +ssl_pretrained_path: ${oc.env:XLSR_PRETRAINED_MODEL_PATH} +cross_entropy_weight: [0.1, 0.9] # weight for cross entropy loss 0.1 for spoof and 0.9 for bonafide +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/notebooks/analyze.ipynb b/notebooks/analyze.ipynb new file mode 100644 index 0000000..7c9c65e --- /dev/null +++ b/notebooks/analyze.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing audio files: 0%| | 0/31779 [00:00= asv_threshold) / non_asv.size + Pmiss_asv = sum(tar_asv < asv_threshold) / tar_asv.size + + # Rate of rejecting spoofs in ASV + if spoof_asv.size == 0: + Pmiss_spoof_asv = None + Pfa_spoof_asv = None + else: + Pmiss_spoof_asv = np.sum(spoof_asv < asv_threshold) / spoof_asv.size + Pfa_spoof_asv = np.sum(spoof_asv >= asv_threshold) / spoof_asv.size + + return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, Pfa_spoof_asv + + +def compute_det_curve(target_scores, nontarget_scores): + + n_scores = target_scores.size + nontarget_scores.size + all_scores = np.concatenate((target_scores, nontarget_scores)) + labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size))) + + # Sort labels based on scores + indices = np.argsort(all_scores, kind='mergesort') + labels = labels[indices] + + # Compute false rejection and false acceptance rates + tar_trial_sums = np.cumsum(labels) + nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums) + + frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size)) # false rejection rates + far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size)) # false acceptance rates + thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices])) # Thresholds are the sorted scores + + return frr, far, thresholds + + +def compute_eer(target_scores, nontarget_scores): + """ Returns equal error rate (EER) and the corresponding threshold. """ + frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores) + abs_diffs = np.abs(frr - far) + min_index = np.argmin(abs_diffs) + eer = np.mean((frr[min_index], far[min_index])) + return eer, thresholds[min_index] + + +def compute_tDCF(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pfa_spoof_asv, cost_model, print_cost): + """ + Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system. + In brief, t-DCF returns a detection cost of a cascaded system of this form, + Speech waveform -> [CM] -> [ASV] -> decision + where CM stands for countermeasure and ASV for automatic speaker + verification. The CM is therefore used as a 'gate' to decided whether or + not the input speech sample should be passed onwards to the ASV system. + Generally, both CM and ASV can do detection errors. Not all those errors + are necessarily equally cost, and not all types of users are necessarily + equally likely. The tandem t-DCF gives a principled with to compare + different spoofing countermeasures under a detection cost function + framework that takes that information into account. + INPUTS: + bonafide_score_cm A vector of POSITIVE CLASS (bona fide or human) + detection scores obtained by executing a spoofing + countermeasure (CM) on some positive evaluation trials. + trial represents a bona fide case. + spoof_score_cm A vector of NEGATIVE CLASS (spoofing attack) + detection scores obtained by executing a spoofing + CM on some negative evaluation trials. + Pfa_asv False alarm (false acceptance) rate of the ASV + system that is evaluated in tandem with the CM. + Assumed to be in fractions, not percentages. + Pmiss_asv Miss (false rejection) rate of the ASV system that + is evaluated in tandem with the spoofing CM. + Assumed to be in fractions, not percentages. + Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that + is evaluated in tandem with the spoofing CM. That + is, the fraction of spoof samples that were + rejected by the ASV system. + cost_model A struct that contains the parameters of t-DCF, + with the following fields. + Ptar Prior probability of target speaker. + Pnon Prior probability of nontarget speaker (zero-effort impostor) + Psoof Prior probability of spoofing attack. + Cmiss Cost of tandem system falsely rejecting target speaker. + Cfa Cost of tandem system falsely accepting nontarget speaker. + Cfa_spoof Cost of tandem system falsely accepting spoof. + print_cost Print a summary of the cost parameters and the + implied t-DCF cost function? + OUTPUTS: + tDCF_norm Normalized t-DCF curve across the different CM + system operating points; see [2] for more details. + Normalized t-DCF > 1 indicates a useless + countermeasure (as the tandem system would do + better without it). min(tDCF_norm) will be the + minimum t-DCF used in ASVspoof 2019 [2]. + CM_thresholds Vector of same size as tDCF_norm corresponding to + the CM threshold (operating point). + NOTE: + o In relative terms, higher detection scores values are assumed to + indicate stronger support for the bona fide hypothesis. + o You should provide real-valued soft scores, NOT hard decisions. The + recommendation is that the scores are log-likelihood ratios (LLRs) + from a bonafide-vs-spoof hypothesis based on some statistical model. + This, however, is NOT required. The scores can have arbitrary range + and scaling. + o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages. + References: + [1] T. Kinnunen, H. Delgado, N. Evans,K.-A. Lee, V. Vestman, + A. Nautsch, M. Todisco, X. Wang, M. Sahidullah, J. Yamagishi, + and D.-A. Reynolds, "Tandem Assessment of Spoofing Countermeasures + and Automatic Speaker Verification: Fundamentals," IEEE/ACM Transaction on + Audio, Speech and Language Processing (TASLP). + [2] ASVspoof 2019 challenge evaluation plan + https://www.asvspoof.org/asvspoof2019/asvspoof2019_evaluation_plan.pdf + """ + + + # Sanity check of cost parameters + if cost_model['Cfa'] < 0 or cost_model['Cmiss'] < 0 or \ + cost_model['Cfa'] < 0 or cost_model['Cmiss'] < 0: + print('WARNING: Usually the cost values should be positive!') + + if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \ + np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10: + sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.') + + # Unless we evaluate worst-case model, we need to have some spoof tests against asv + if Pfa_spoof_asv is None: + sys.exit('ERROR: you should provide false alarm rate of spoof tests against your ASV system.') + + # Sanity check of scores + combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm)) + if np.isnan(combined_scores).any() or np.isinf(combined_scores).any(): + sys.exit('ERROR: Your scores contain nan or inf.') + + # Sanity check that inputs are scores and not decisions + n_uniq = np.unique(combined_scores).size + if n_uniq < 3: + sys.exit('ERROR: You should provide soft CM scores - not binary decisions') + + # Obtain miss and false alarm rates of CM + Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm) + + # Constants - see ASVspoof 2019 evaluation plan + + C0 = cost_model['Ptar'] * cost_model['Cmiss'] * Pmiss_asv + cost_model['Pnon']*cost_model['Cfa']*Pfa_asv + C1 = cost_model['Ptar'] * cost_model['Cmiss'] - (cost_model['Ptar'] * cost_model['Cmiss'] * Pmiss_asv + cost_model['Pnon'] * cost_model['Cfa'] * Pfa_asv) + C2 = cost_model['Pspoof'] * cost_model['Cfa_spoof'] * Pfa_spoof_asv; + + + # Sanity check of the weights + if C0 < 0 or C1 < 0 or C2 < 0: + sys.exit('You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?') + + # Obtain t-DCF curve for all thresholds + tDCF = C0 + C1 * Pmiss_cm + C2 * Pfa_cm + + # Obtain default t-DCF + tDCF_default = C0 + np.minimum(C1, C2) + + # Normalized t-DCF + tDCF_norm = tDCF / tDCF_default + + # Everything should be fine if reaching here. + if print_cost: + + print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(bonafide_score_cm.size, spoof_score_cm.size)) + print('t-DCF MODEL') + print(' Ptar = {:8.5f} (Prior probability of target user)'.format(cost_model['Ptar'])) + print(' Pnon = {:8.5f} (Prior probability of nontarget user)'.format(cost_model['Pnon'])) + print(' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.format(cost_model['Pspoof'])) + print(' Cfa = {:8.5f} (Cost of tandem system falsely accepting a nontarget)'.format(cost_model['Cfa'])) + print(' Cmiss = {:8.5f} (Cost of tandem system falsely rejecting target speaker)'.format(cost_model['Cmiss'])) + print(' Cfa_spoof = {:8.5f} (Cost of tandem sysmte falsely accepting spoof)'.format(cost_model['Cfa_spoof'])) + print('\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), t_CM=CM threshold)') + print(' tDCF_norm(t_CM) = {:8.5f} + {:8.5f} x Pmiss_cm(t_CM) + {:8.5f} x Pfa_cm(t_CM)\n'.format(C0/tDCF_default, C1/tDCF_default, C2/tDCF_default)) + print(' * The optimum value is given by the first term (0.06273). This is the normalized t-DCF obtained with an error-free CM system.') + print(' * The minimum normalized cost (minimum over all possible thresholds) is always <= 1.00.') + print('') + + return tDCF_norm, CM_thresholds + +def compute_tDCF_legacy(bonafide_score_cm, spoof_score_cm, Pfa_asv, Pmiss_asv, Pmiss_spoof_asv, cost_model, print_cost): + """ + Compute Tandem Detection Cost Function (t-DCF) [1] for a fixed ASV system. + In brief, t-DCF returns a detection cost of a cascaded system of this form, + Speech waveform -> [CM] -> [ASV] -> decision + where CM stands for countermeasure and ASV for automatic speaker + verification. The CM is therefore used as a 'gate' to decided whether or + not the input speech sample should be passed onwards to the ASV system. + Generally, both CM and ASV can do detection errors. Not all those errors + are necessarily equally cost, and not all types of users are necessarily + equally likely. The tandem t-DCF gives a principled with to compare + different spoofing countermeasures under a detection cost function + framework that takes that information into account. + INPUTS: + bonafide_score_cm A vector of POSITIVE CLASS (bona fide or human) + detection scores obtained by executing a spoofing + countermeasure (CM) on some positive evaluation trials. + trial represents a bona fide case. + spoof_score_cm A vector of NEGATIVE CLASS (spoofing attack) + detection scores obtained by executing a spoofing + CM on some negative evaluation trials. + Pfa_asv False alarm (false acceptance) rate of the ASV + system that is evaluated in tandem with the CM. + Assumed to be in fractions, not percentages. + Pmiss_asv Miss (false rejection) rate of the ASV system that + is evaluated in tandem with the spoofing CM. + Assumed to be in fractions, not percentages. + Pmiss_spoof_asv Miss rate of spoof samples of the ASV system that + is evaluated in tandem with the spoofing CM. That + is, the fraction of spoof samples that were + rejected by the ASV system. + cost_model A struct that contains the parameters of t-DCF, + with the following fields. + Ptar Prior probability of target speaker. + Pnon Prior probability of nontarget speaker (zero-effort impostor) + Psoof Prior probability of spoofing attack. + Cmiss_asv Cost of ASV falsely rejecting target. + Cfa_asv Cost of ASV falsely accepting nontarget. + Cmiss_cm Cost of CM falsely rejecting target. + Cfa_cm Cost of CM falsely accepting spoof. + print_cost Print a summary of the cost parameters and the + implied t-DCF cost function? + OUTPUTS: + tDCF_norm Normalized t-DCF curve across the different CM + system operating points; see [2] for more details. + Normalized t-DCF > 1 indicates a useless + countermeasure (as the tandem system would do + better without it). min(tDCF_norm) will be the + minimum t-DCF used in ASVspoof 2019 [2]. + CM_thresholds Vector of same size as tDCF_norm corresponding to + the CM threshold (operating point). + NOTE: + o In relative terms, higher detection scores values are assumed to + indicate stronger support for the bona fide hypothesis. + o You should provide real-valued soft scores, NOT hard decisions. The + recommendation is that the scores are log-likelihood ratios (LLRs) + from a bonafide-vs-spoof hypothesis based on some statistical model. + This, however, is NOT required. The scores can have arbitrary range + and scaling. + o Pfa_asv, Pmiss_asv, Pmiss_spoof_asv are in fractions, not percentages. + References: + [1] T. Kinnunen, K.-A. Lee, H. Delgado, N. Evans, M. Todisco, + M. Sahidullah, J. Yamagishi, D.A. Reynolds: "t-DCF: a Detection + Cost Function for the Tandem Assessment of Spoofing Countermeasures + and Automatic Speaker Verification", Proc. Odyssey 2018: the + Speaker and Language Recognition Workshop, pp. 312--319, Les Sables d'Olonne, + France, June 2018 (https://www.isca-speech.org/archive/Odyssey_2018/pdfs/68.pdf) + [2] ASVspoof 2019 challenge evaluation plan + https://www.asvspoof.org/asvspoof2019/asvspoof2019_evaluation_plan.pdf + """ + + + # Sanity check of cost parameters + if cost_model['Cfa_asv'] < 0 or cost_model['Cmiss_asv'] < 0 or \ + cost_model['Cfa_cm'] < 0 or cost_model['Cmiss_cm'] < 0: + print('WARNING: Usually the cost values should be positive!') + + if cost_model['Ptar'] < 0 or cost_model['Pnon'] < 0 or cost_model['Pspoof'] < 0 or \ + np.abs(cost_model['Ptar'] + cost_model['Pnon'] + cost_model['Pspoof'] - 1) > 1e-10: + sys.exit('ERROR: Your prior probabilities should be positive and sum up to one.') + + # Unless we evaluate worst-case model, we need to have some spoof tests against asv + if Pmiss_spoof_asv is None: + sys.exit('ERROR: you should provide miss rate of spoof tests against your ASV system.') + + # Sanity check of scores + combined_scores = np.concatenate((bonafide_score_cm, spoof_score_cm)) + if np.isnan(combined_scores).any() or np.isinf(combined_scores).any(): + sys.exit('ERROR: Your scores contain nan or inf.') + + # Sanity check that inputs are scores and not decisions + n_uniq = np.unique(combined_scores).size + if n_uniq < 3: + sys.exit('ERROR: You should provide soft CM scores - not binary decisions') + + # Obtain miss and false alarm rates of CM + Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(bonafide_score_cm, spoof_score_cm) + + # Constants - see ASVspoof 2019 evaluation plan + C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \ + cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv + C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv) + + # Sanity check of the weights + if C1 < 0 or C2 < 0: + sys.exit('You should never see this error but I cannot evalute tDCF with negative weights - please check whether your ASV error rates are correctly computed?') + + # Obtain t-DCF curve for all thresholds + tDCF = C1 * Pmiss_cm + C2 * Pfa_cm + + # Normalized t-DCF + tDCF_norm = tDCF / np.minimum(C1, C2) + + # Everything should be fine if reaching here. + if print_cost: + + print('t-DCF evaluation from [Nbona={}, Nspoof={}] trials\n'.format(bonafide_score_cm.size, spoof_score_cm.size)) + print('t-DCF MODEL') + print(' Ptar = {:8.5f} (Prior probability of target user)'.format(cost_model['Ptar'])) + print(' Pnon = {:8.5f} (Prior probability of nontarget user)'.format(cost_model['Pnon'])) + print(' Pspoof = {:8.5f} (Prior probability of spoofing attack)'.format(cost_model['Pspoof'])) + print(' Cfa_asv = {:8.5f} (Cost of ASV falsely accepting a nontarget)'.format(cost_model['Cfa_asv'])) + print(' Cmiss_asv = {:8.5f} (Cost of ASV falsely rejecting target speaker)'.format(cost_model['Cmiss_asv'])) + print(' Cfa_cm = {:8.5f} (Cost of CM falsely passing a spoof to ASV system)'.format(cost_model['Cfa_cm'])) + print(' Cmiss_cm = {:8.5f} (Cost of CM falsely blocking target utterance which never reaches ASV)'.format(cost_model['Cmiss_cm'])) + print('\n Implied normalized t-DCF function (depends on t-DCF parameters and ASV errors), s=CM threshold)') + + if C2 == np.minimum(C1, C2): + print(' tDCF_norm(s) = {:8.5f} x Pmiss_cm(s) + Pfa_cm(s)\n'.format(C1 / C2)) + else: + print(' tDCF_norm(s) = Pmiss_cm(s) + {:8.5f} x Pfa_cm(s)\n'.format(C2 / C1)) + + return tDCF_norm, CM_thresholds \ No newline at end of file diff --git a/notebooks/test.ipynb b/notebooks/test.ipynb index e3d7c55..30707fd 100644 --- a/notebooks/test.ipynb +++ b/notebooks/test.ipynb @@ -290,6 +290,150 @@ " break\n", "\n" ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('xlsr_conformertcm/2s.txt',)\n", + "611829\n", + "('xlsr_conformertcm/1s.txt',)\n", + "611829\n", + "('xlsr_conformertcm/3s.txt',)\n", + "611829\n", + "('xlsr_conformertcm/4s.txt',)\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/1s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/3s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/3s.txt')\n", + "611829\n", + "('xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/3s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/3s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/3s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/3s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n", + "('xlsr_conformertcm/2s.txt', 'xlsr_conformertcm/1s.txt', 'xlsr_conformertcm/3s.txt', 'xlsr_conformertcm/4s.txt')\n", + "611829\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import os\n", + "from itertools import combinations\n", + "\n", + "# Score fusion\n", + "def fusion_score(score_files: list):\n", + " \"\"\"Fuse the score of all files in the list, calculate EER, ACC, etc. after fusion\n", + "\n", + " Args:\n", + " score_files (list): list of score file\n", + "\n", + " \"\"\"\n", + "\n", + " all_scores = pd.DataFrame()\n", + " for score_file in score_files:\n", + " if not os.path.exists(score_file):\n", + " print(f\"File not found: {score_file}\")\n", + " continue\n", + " df = pd.read_csv(score_file, sep=' ', header=None, names=['path', 'score'])\n", + " df['utt'] = df['path'].apply(lambda x: x.split('/')[-1].split('.')[0])\n", + " df.sort_values(by='utt', inplace=True)\n", + " all_scores = pd.concat([all_scores, df])\n", + " \n", + " if all_scores.empty:\n", + " print(\"No valid score files found.\")\n", + " return pd.DataFrame()\n", + "\n", + " grouped_scores = all_scores.groupby('utt').mean().reset_index()\n", + " res_df = grouped_scores\n", + " print(len(res_df))\n", + " return res_df\n", + "\n", + "\n", + "list_score_files = [\n", + " os.path.join(\"xlsr_conformertcm\", file) for file in os.listdir(\"xlsr_conformertcm\") if file.endswith(\".txt\")\n", + "]\n", + "\n", + "# Write fusion scores as a combination of all scores\n", + "# For example if we have 3 scores, we will have 3! = 6 combinations\n", + "\n", + "output_dir = \"fusion_scores_xlsr_conformertcm\"\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "for i in range(1, len(list_score_files) + 1):\n", + " for comb in combinations(list_score_files, i):\n", + " print(comb)\n", + " fusion_result = fusion_score(list(comb))\n", + " if not fusion_result.empty:\n", + " output_file = os.path.join(output_dir, f\"fusion_score_{'_'.join([os.path.basename(file).split('.')[0] for file in comb])}.txt\")\n", + " fusion_result.to_csv(output_file, sep=' ', index=False, header=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os\n", + "from itertools import combinations\n", + "\n", + "# Score fusion\n", + "def fusion_score(score_files: list):\n", + " \"\"\"Fuse the score of all files in the list, calculate EER, ACC, etc. after fusion\n", + "\n", + " Args:\n", + " score_files (list): list of score file\n", + "\n", + " \"\"\"\n", + "\n", + " all_scores = pd.DataFrame()\n", + " for score_file in score_files:\n", + " if not os.path.exists(score_file):\n", + " print(f\"File not found: {score_file}\")\n", + " continue\n", + " df = pd.read_csv(score_file, sep=' ', header=None, names=['path', 'score'])\n", + " df['utt'] = df['path'].apply(lambda x: x.split('/')[-1].split('.')[0])\n", + " df.sort_values(by='utt', inplace=True)\n", + " all_scores = pd.concat([all_scores, df])\n", + " \n", + " if all_scores.empty:\n", + " print(\"No valid score files found.\")\n", + " return pd.DataFrame()\n", + "\n", + " grouped_scores = all_scores.groupby('utt').mean().reset_index()\n", + " res_df = grouped_scores\n", + " print(len(res_df))\n", + " return res_df\n", + "\n", + "\n", + "list_score_files = [\n", + " \"/data/hungdx/Lightning-hydra/logs/eval/xlsr_aasist_multiview_conf-2_epoch15_var.txt\",\n", + " \"/data/hungdx/Lightning-hydra/logs/eval/xlsr_aasist_multiview_conf-2_epoch15.txt\"\n", + "]\n", + "\n", + "# Export fusion\n", + "\n", + "fusion_score(list_score_files).to_csv(\"best_fusion_xlsr_aasist_multiview_conf-2.txt\", sep=' ', index=False, header=False)\n" + ] } ], "metadata": { diff --git a/src/data/asvspoof_multiview_datamodule.py b/src/data/asvspoof_multiview_datamodule.py index c4fb939..c6f522a 100644 --- a/src/data/asvspoof_multiview_datamodule.py +++ b/src/data/asvspoof_multiview_datamodule.py @@ -48,13 +48,17 @@ def __init__(self, args, list_IDs, base_dir): self.base_dir = base_dir # Sampling rate and cut-off + self.no_pad = args.get('no_pad', False) if args is not None else False self.fs = args.get('sampling_rate', 16000) if args is not None else 16000 - self.cut = args.get('cut', 64600) if args is not None else 64600 - self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero' - self.random_start = args.get('random_start', False) if args is not None else False - print('padding_type:',self.padding_type) - print('cut:',self.cut) - print('random_start:',self.random_start) + if self.no_pad: + print('No padding') + else: + self.cut = args.get('cut', 64600) if args is not None else 64600 + self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero' + self.random_start = args.get('random_start', False) if args is not None else False + print('padding_type:',self.padding_type) + print('cut:',self.cut) + print('random_start:',self.random_start) def __len__(self): return len(self.list_IDs) @@ -62,10 +66,41 @@ def __len__(self): def __getitem__(self, index): utt_id = self.list_IDs[index] X, fs = librosa.load(self.base_dir+utt_id+'.flac', sr=self.fs) - X_pad = pad(X,self.padding_type, self.cut, self.random_start) - x_inp = Tensor(X_pad) + if not self.no_pad: + X_pad = pad(X,self.padding_type, self.cut, self.random_start) + x_inp = Tensor(X_pad) if not self.no_pad else Tensor(X) return x_inp,utt_id +class Dataset_Normal_eval(Dataset): + def __init__(self, args, list_IDs, base_dir): + self.list_IDs = list_IDs + self.base_dir = base_dir + + # Sampling rate and cut-off + self.no_pad = args.get('no_pad', False) if args is not None else False + self.fs = args.get('sampling_rate', 16000) if args is not None else 16000 + if self.no_pad: + print('No padding') + else: + self.cut = args.get('cut', 64600) if args is not None else 64600 + self.padding_type = args.get('padding_type', 'zero') if args is not None else 'zero' + self.random_start = args.get('random_start', False) if args is not None else False + print('padding_type:',self.padding_type) + print('cut:',self.cut) + print('random_start:',self.random_start) + + def __len__(self): + return len(self.list_IDs) + + def __getitem__(self, index): + utt_id = self.list_IDs[index] + X, fs = librosa.load(os.path.join( + self.base_dir, utt_id + ), sr=self.fs) + if not self.no_pad: + X_pad = pad(X,self.padding_type, self.cut, self.random_start) + x_inp = Tensor(X_pad) if not self.no_pad else Tensor(X) + return x_inp,utt_id class ASVSpoofDataModule(LightningDataModule): """`LightningDataModule` for the ASVSpoof dataset. @@ -181,14 +216,21 @@ def setup(self, stage: Optional[str] = None) -> None: prefix_2021 = 'ASVspoof2021.{}'.format(track) self.algo = self.args.get('algo', -1) if self.args is not None else -1 - d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) - d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) - file_eval = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) - - self.data_train = Dataset_ASVspoof2019_train(self.args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_train/'),algo=self.algo) - self.data_val = Dataset_ASVspoof2019_train(self.args,list_IDs = file_dev,labels = d_label_dev,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_dev/'),algo=self.algo) - self.data_test = Dataset_ASVspoof2021_eval(self.args, list_IDs = file_eval,base_dir = os.path.join(self.database_path+'ASVspoof2021_{}_eval/'.format(track))) - + if not self.args.get('eval', False): + d_label_trn,file_train = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'),is_train=True,is_eval=False) + d_label_dev,file_dev = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'),is_train=False,is_eval=False) + + self.data_train = Dataset_ASVspoof2019_train(self.args,list_IDs = file_train,labels = d_label_trn,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_train/'),algo=self.algo) + self.data_val = Dataset_ASVspoof2019_train(self.args,list_IDs = file_dev,labels = d_label_dev,base_dir = os.path.join(self.database_path+'ASVspoof2019_LA_dev/'),algo=self.algo) + + if self.args.get('eval_set', 'DF21') == 'DF21': + print('Using ASVspoof2021 evaluation set') + file_eval = self.genSpoof_list( dir_meta = os.path.join(self.protocols_path+'ASVspoof_{}_cm_protocols/{}.cm.eval.trl.txt'.format(track,prefix_2021)),is_train=False,is_eval=True) + self.data_test = Dataset_ASVspoof2021_eval(self.args, list_IDs = file_eval,base_dir = os.path.join(self.database_path+'ASVspoof2021_{}_eval/'.format(track))) + else: # Using in-the-wild evaluation set + print('Using in-the-wild evaluation set') + file_eval = self.genInTheWild_list( dir_meta = self.protocols_path) + self.data_test = Dataset_Normal_eval(self.args,list_IDs = file_eval,base_dir = self.database_path) def train_dataloader(self) -> DataLoader[Any]: """Create and return the train dataloader. @@ -256,7 +298,19 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: :param state_dict: The datamodule state returned by `self.state_dict()`. """ pass - + + def genInTheWild_list(self, dir_meta): + """ + This function is from the following source: + """ + file_list=[] + with open(dir_meta, 'r') as f: + l_meta = f.readlines() + for line in l_meta: + key,label = line.strip().split() + file_list.append(key) + return file_list + def genSpoof_list(self, dir_meta, is_train=False, is_eval=False): """ This function is from the following source: https://github.com/TakHemlata/SSL_Anti-spoofing/blob/main/data_utils_SSL.py#L17 diff --git a/src/data/normal_datamodule.py b/src/data/normal_datamodule.py index e6bf350..91e107a 100644 --- a/src/data/normal_datamodule.py +++ b/src/data/normal_datamodule.py @@ -334,8 +334,7 @@ def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): utt, subset, label = line.strip().split() if subset == 'eval': file_list.append(utt) - file_list.append(utt) - d_meta[utt] = 1 if label == 'bonafide' else 0 + d_meta[utt] = 1 if label == 'bonafide' else 0 # return d_meta, file_list return d_meta, file_list diff --git a/src/models/xlsr_conformer_reproduce_module.py b/src/models/xlsr_conformer_reproduce_module.py new file mode 100644 index 0000000..b33f908 --- /dev/null +++ b/src/models/xlsr_conformer_reproduce_module.py @@ -0,0 +1,244 @@ +from typing import Any, Dict, Tuple + +import torch +from lightning import LightningModule +from torchmetrics import MaxMetric, MeanMetric +from torchmetrics.classification.accuracy import BinaryAccuracy + +from typing import Union + +import torch +from src.models.components.xlsr_conformer_reproduce import Model as XLSRConformer + +class XLSRConformerLitModule(LightningModule): + """Example of a `LightningModule` for MNIST classification. + + A `LightningModule` implements 8 key methods: + + ```python + def __init__(self): + # Define initialization code here. + + def setup(self, stage): + # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'. + # This hook is called on every process when using DDP. + + def training_step(self, batch, batch_idx): + # The complete training step. + + def validation_step(self, batch, batch_idx): + # The complete validation step. + + def test_step(self, batch, batch_idx): + # The complete test step. + + def predict_step(self, batch, batch_idx): + # The complete predict step. + + def configure_optimizers(self): + # Define and configure optimizers and LR schedulers. + ``` + + Docs: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + """ + + def __init__( + self, + net: torch.nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + compile: bool, + args: Union[Dict[str, Any], None] = None, + ssl_pretrained_path: str = None, + score_save_path: str = None, + cross_entropy_weight: list[float] = [0.1, 0.9], + ) -> None: + """Initialize a `MNISTLitModule`. + + :param net: The model to train. + :param optimizer: The optimizer to use for training. + :param scheduler: The learning rate scheduler to use for training. + """ + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False) + + self.net = XLSRConformer(args['conformer'], ssl_pretrained_path) + # loss function + cross_entropy_weight = torch.tensor(cross_entropy_weight) + self.criterion = torch.nn.CrossEntropyLoss(cross_entropy_weight) + + # metric objects for calculating and averaging accuracy across batches + self.train_acc = BinaryAccuracy() + self.val_acc = BinaryAccuracy() + self.test_acc = BinaryAccuracy() + self.score_save_path = score_save_path + #self.test_eer = EERMetric() + + # for averaging loss across batches + self.train_loss = MeanMetric() + self.val_loss = MeanMetric() + self.test_loss = MeanMetric() + + # for tracking best so far validation accuracy + self.val_acc_best = MaxMetric() + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Perform a forward pass through the model `self.net`. + + :param x: A tensor of images. + :return: A tensor of logits. + """ + return self.net(x) + + def on_train_start(self) -> None: + """Lightning hook that is called when training begins.""" + # by default lightning executes validation step sanity checks before training starts, + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() + self.val_acc_best.reset() + + def model_step( + self, batch: Tuple[torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Perform a single model step on a batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + + :return: A tuple containing (in order): + - A tensor of losses. + - A tensor of predictions. + - A tensor of target labels. + - A dictionary of detailed losses. + """ + + x, y = batch + logits = self.forward(x) + loss = self.criterion(logits, y) + preds = torch.argmax(logits, dim=1) + return loss, preds, y + + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Perform a single training step on a batch of data from the training set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + :return: A tensor of losses between model predictions and targets. + """ + loss, preds, targets = self.model_step(batch) + + # update and log metrics + self.train_loss(loss) + self.train_acc(preds, targets) + self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) + + # return loss or backpropagation will fail + return loss + + def on_train_epoch_end(self) -> None: + "Lightning hook that is called when a training epoch ends." + pass + + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single validation step on a batch of data from the validation set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + loss, preds, targets = self.model_step(batch) + + # update and log metrics + self.val_loss(loss) + self.val_acc(preds, targets) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + + def on_validation_epoch_end(self) -> None: + "Lightning hook that is called when a validation epoch ends." + acc = self.val_acc.compute() # get current val acc + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + """Perform a single test step on a batch of data from the test set. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + :param batch_idx: The index of the current batch. + """ + if self.score_save_path is not None: + self._export_score_file(batch) + else: + raise ValueError("score_save_path is not provided") + + def _export_score_file(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> None: + """Get the score file for the batch of data. + + :param batch: A batch of data (a tuple) containing the input tensor of images and target + labels. + """ + batch_x, utt_id = batch + batch_out = self.net(batch_x) + + fname_list = list(utt_id) + score_list = batch_out.data.cpu().numpy().tolist() + + with open(self.score_save_path, 'a+') as fh: + for f, cm in zip(fname_list, score_list): + fh.write('{} {}\n'.format(f, cm[1])) + + def on_test_epoch_end(self) -> None: + """Lightning hook that is called when a test epoch ends.""" + pass + + def setup(self, stage: str) -> None: + """Lightning hook that is called at the beginning of fit (train + validate), validate, + test, or predict. + + This is a good hook when you need to build models dynamically or adjust something about + them. This hook is called on every process when using DDP. + + :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + # if self.hparams.compile and stage == "fit": + # self.net = torch.compile(self.net) + pass + + def configure_optimizers(self) -> Dict[str, Any]: + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +# if __name__ == "__main__": +# _ = WAVLMVIBLLitModule(None, None, None, None)