From e43863e4a0c1b2c804c24b8402a4aadad357894e Mon Sep 17 00:00:00 2001 From: Hung Dinh Xuan Date: Wed, 11 Dec 2024 13:28:57 +0900 Subject: [PATCH] updare code for normal multiview da --- notebooks/prepare_dataset.ipynb | 139 ++++++++++++++++++++++++ src/data/normal_multiview_datamodule.py | 17 ++- 2 files changed, 146 insertions(+), 10 deletions(-) create mode 100644 notebooks/prepare_dataset.ipynb diff --git a/notebooks/prepare_dataset.ipynb b/notebooks/prepare_dataset.ipynb new file mode 100644 index 0000000..033ae53 --- /dev/null +++ b/notebooks/prepare_dataset.ipynb @@ -0,0 +1,139 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import os\n", + "import librosa\n", + "from tqdm import tqdm\n", + "from multiprocessing import Pool, cpu_count\n", + "from functools import partial\n", + "\n", + "\n", + "def get_audio_duration(row, base_dir):\n", + " \"\"\"Calculate duration for a single audio file\"\"\"\n", + " try:\n", + " file_path = os.path.join(base_dir, row['utt_id'])\n", + " duration = librosa.get_duration(path=file_path)\n", + " return {\n", + " 'utt_id': row['utt_id'],\n", + " 'subset': row['subset'],\n", + " 'label': row['label'],\n", + " 'duration': duration\n", + " }\n", + " except Exception as e:\n", + " print(f\"Error processing {row['utt_id']}: {str(e)}\")\n", + " return {\n", + " 'utt_id': row['utt_id'],\n", + " 'subset': row['subset'],\n", + " 'label': row['label'],\n", + " 'duration': -1 # Mark failed files with -1\n", + " }\n", + "\n", + "\n", + "def process_chunk(chunk, base_dir):\n", + " \"\"\"Process a chunk of the dataframe\"\"\"\n", + " return [get_audio_duration(row, base_dir) for row in chunk.to_dict('records')]\n", + "\n", + "\n", + "def calculate_durations(protocol_file, base_dir, output_file, n_workers=None):\n", + " \"\"\"\n", + " Calculate durations for all audio files in parallel\n", + " \n", + " Args:\n", + " protocol_file: Path to protocol file\n", + " base_dir: Base directory containing audio files\n", + " output_file: Path to output CSV file\n", + " n_workers: Number of worker processes (default: CPU count - 1)\n", + " \"\"\"\n", + " if n_workers is None:\n", + " n_workers = cpu_count() - 1\n", + "\n", + " print(\"Reading protocol file...\")\n", + " protocol = pd.read_csv(protocol_file, sep=\" \", header=None)\n", + " protocol.columns = [\"utt_id\", \"subset\", \"label\"]\n", + "\n", + " # Split dataframe into chunks for parallel processing\n", + " chunk_size = len(protocol) // n_workers + 1\n", + " chunks = np.array_split(protocol, n_workers)\n", + "\n", + " print(f\"Processing {len(protocol)} files using {n_workers} workers...\")\n", + "\n", + " # Process chunks in parallel\n", + " with Pool(n_workers) as pool:\n", + " partial_process = partial(process_chunk, base_dir=base_dir)\n", + " results = list(tqdm(\n", + " pool.imap(partial_process, chunks),\n", + " total=len(chunks),\n", + " desc=\"Calculating durations\"\n", + " ))\n", + "\n", + " # Flatten results and convert to dataframe\n", + " all_results = [item for sublist in results for item in sublist]\n", + " df_results = pd.DataFrame(all_results)\n", + "\n", + " # Calculate statistics\n", + " valid_durations = df_results[df_results['duration'] != -1]['duration']\n", + " stats = {\n", + " 'total_files': len(df_results),\n", + " 'failed_files': len(df_results[df_results['duration'] == -1]),\n", + " 'total_duration_hours': valid_durations.sum() / 3600,\n", + " 'mean_duration': valid_durations.mean(),\n", + " 'min_duration': valid_durations.min(),\n", + " 'max_duration': valid_durations.max()\n", + " }\n", + "\n", + " # Save results\n", + " print(\"\\nSaving results...\")\n", + " df_results.to_csv(output_file, index=False)\n", + "\n", + " # Print statistics\n", + " print(\"\\nProcessing Statistics:\")\n", + " print(f\"Total files processed: {stats['total_files']}\")\n", + " print(f\"Failed files: {stats['failed_files']}\")\n", + " print(f\"Total duration: {stats['total_duration_hours']:.2f} hours\")\n", + " print(f\"Mean duration: {stats['mean_duration']:.2f} seconds\")\n", + " print(f\"Min duration: {stats['min_duration']:.2f} seconds\")\n", + " print(f\"Max duration: {stats['max_duration']:.2f} seconds\")\n", + "\n", + " return df_results, stats\n", + "\n", + "\n", + "# Usage\n", + "BASE_DIR = \"/nvme2/hungdx/Lightning-hydra/data/0_large-corpus\"\n", + "protocol_file = \"/nvme2/hungdx/Lightning-hydra/notebooks/new_protocol_trim_vocoded.txt\"\n", + "output_file = \"audio_durations.csv\"\n", + "\n", + "# Run the processing\n", + "df_results, stats = calculate_durations(\n", + " protocol_file=protocol_file,\n", + " base_dir=BASE_DIR,\n", + " output_file=output_file,\n", + " n_workers=8 # Adjust based on your system\n", + ")\n", + "\n", + "# Display first few rows of results\n", + "print(\"\\nFirst few rows of results:\")\n", + "print(df_results.head())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/data/normal_multiview_datamodule.py b/src/data/normal_multiview_datamodule.py index 0639788..8a87a09 100644 --- a/src/data/normal_multiview_datamodule.py +++ b/src/data/normal_multiview_datamodule.py @@ -43,12 +43,13 @@ def __getitem__(self, idx): augmethod_index = random.choice(range(len(self.augmentation_methods))) if len( self.augmentation_methods) > 0 else -1 if augmethod_index >= 0: + # print("Augmenting with", self.augmentation_methods[augmethod_index]) X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate, audio_path=filepath) x_inp = Tensor(X) target = self.labels[utt_id] - return idx, x_inp, target + return x_inp, target class Dataset_for_dev(Dataset_base): @@ -66,7 +67,7 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[], def __getitem__(self, index): utt_id = self.list_IDs[index] filepath = os.path.join(self.base_dir, utt_id) - X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000) + X, fs = librosa.load(filepath, sr=16000) augmethod_index = random.choice(range(len(self.augmentation_methods))) if len( self.augmentation_methods) > 0 else -1 if augmethod_index >= 0: @@ -74,7 +75,7 @@ def __getitem__(self, index): audio_path=filepath) x_inp = Tensor(X) target = self.labels[utt_id] - return index, x_inp, target + return x_inp, target class Dataset_for_eval(Dataset_base): @@ -160,10 +161,6 @@ def __init__( num_workers: int = 0, pin_memory: bool = False, args: Optional[Dict[str, Any]] = None, - # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [], - # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2, - # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None, - # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False ) -> None: """Initialize a `ASVSpoofDataModule`. @@ -199,7 +196,7 @@ def __init__( self.top_k, self.min_duration, self.max_duration, - self.args.sample_rate, + self.args.data.wav_samp_rate, self.args.padding_type, self.args.random_start ) @@ -207,7 +204,7 @@ def __init__( self.collate_fn = lambda x: multi_view_collate_fn( x, self.args.views, - self.args.sample_rate, + self.args.data.wav_samp_rate, self.args.padding_type, self.args.random_start ) @@ -333,7 +330,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ pass - def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False): + def genList(self, is_train=False, is_eval=False, is_dev=False): """ This function generates the list of files and their corresponding labels Specifically for the standard CNSL dataset