
Commit

update code for normal multiview datamodule
hungdinhxuan committed Dec 11, 2024
1 parent f45c6c6 commit e43863e
Showing 2 changed files with 146 additions and 10 deletions.
139 changes: 139 additions & 0 deletions notebooks/prepare_dataset.ipynb
@@ -0,0 +1,139 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import os\n",
"import librosa\n",
"from tqdm import tqdm\n",
"from multiprocessing import Pool, cpu_count\n",
"from functools import partial\n",
"\n",
"\n",
"def get_audio_duration(row, base_dir):\n",
" \"\"\"Calculate duration for a single audio file\"\"\"\n",
" try:\n",
" file_path = os.path.join(base_dir, row['utt_id'])\n",
" duration = librosa.get_duration(path=file_path)\n",
" return {\n",
" 'utt_id': row['utt_id'],\n",
" 'subset': row['subset'],\n",
" 'label': row['label'],\n",
" 'duration': duration\n",
" }\n",
" except Exception as e:\n",
" print(f\"Error processing {row['utt_id']}: {str(e)}\")\n",
" return {\n",
" 'utt_id': row['utt_id'],\n",
" 'subset': row['subset'],\n",
" 'label': row['label'],\n",
" 'duration': -1 # Mark failed files with -1\n",
" }\n",
"\n",
"\n",
"def process_chunk(chunk, base_dir):\n",
" \"\"\"Process a chunk of the dataframe\"\"\"\n",
" return [get_audio_duration(row, base_dir) for row in chunk.to_dict('records')]\n",
"\n",
"\n",
"def calculate_durations(protocol_file, base_dir, output_file, n_workers=None):\n",
" \"\"\"\n",
" Calculate durations for all audio files in parallel\n",
" \n",
" Args:\n",
" protocol_file: Path to protocol file\n",
" base_dir: Base directory containing audio files\n",
" output_file: Path to output CSV file\n",
" n_workers: Number of worker processes (default: CPU count - 1)\n",
" \"\"\"\n",
" if n_workers is None:\n",
" n_workers = cpu_count() - 1\n",
"\n",
" print(\"Reading protocol file...\")\n",
" protocol = pd.read_csv(protocol_file, sep=\" \", header=None)\n",
" protocol.columns = [\"utt_id\", \"subset\", \"label\"]\n",
"\n",
" # Split dataframe into chunks for parallel processing\n",
" chunk_size = len(protocol) // n_workers + 1\n",
" chunks = np.array_split(protocol, n_workers)\n",
"\n",
" print(f\"Processing {len(protocol)} files using {n_workers} workers...\")\n",
"\n",
" # Process chunks in parallel\n",
" with Pool(n_workers) as pool:\n",
" partial_process = partial(process_chunk, base_dir=base_dir)\n",
" results = list(tqdm(\n",
" pool.imap(partial_process, chunks),\n",
" total=len(chunks),\n",
" desc=\"Calculating durations\"\n",
" ))\n",
"\n",
" # Flatten results and convert to dataframe\n",
" all_results = [item for sublist in results for item in sublist]\n",
" df_results = pd.DataFrame(all_results)\n",
"\n",
" # Calculate statistics\n",
" valid_durations = df_results[df_results['duration'] != -1]['duration']\n",
" stats = {\n",
" 'total_files': len(df_results),\n",
" 'failed_files': len(df_results[df_results['duration'] == -1]),\n",
" 'total_duration_hours': valid_durations.sum() / 3600,\n",
" 'mean_duration': valid_durations.mean(),\n",
" 'min_duration': valid_durations.min(),\n",
" 'max_duration': valid_durations.max()\n",
" }\n",
"\n",
" # Save results\n",
" print(\"\\nSaving results...\")\n",
" df_results.to_csv(output_file, index=False)\n",
"\n",
" # Print statistics\n",
" print(\"\\nProcessing Statistics:\")\n",
" print(f\"Total files processed: {stats['total_files']}\")\n",
" print(f\"Failed files: {stats['failed_files']}\")\n",
" print(f\"Total duration: {stats['total_duration_hours']:.2f} hours\")\n",
" print(f\"Mean duration: {stats['mean_duration']:.2f} seconds\")\n",
" print(f\"Min duration: {stats['min_duration']:.2f} seconds\")\n",
" print(f\"Max duration: {stats['max_duration']:.2f} seconds\")\n",
"\n",
" return df_results, stats\n",
"\n",
"\n",
"# Usage\n",
"BASE_DIR = \"/nvme2/hungdx/Lightning-hydra/data/0_large-corpus\"\n",
"protocol_file = \"/nvme2/hungdx/Lightning-hydra/notebooks/new_protocol_trim_vocoded.txt\"\n",
"output_file = \"audio_durations.csv\"\n",
"\n",
"# Run the processing\n",
"df_results, stats = calculate_durations(\n",
" protocol_file=protocol_file,\n",
" base_dir=BASE_DIR,\n",
" output_file=output_file,\n",
" n_workers=8 # Adjust based on your system\n",
")\n",
"\n",
"# Display first few rows of results\n",
"print(\"\\nFirst few rows of results:\")\n",
"print(df_results.head())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
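The datamodule below exposes min_duration/max_duration arguments, so one natural use of the audio_durations.csv produced by this notebook is filtering the protocol by duration before training. A minimal sketch, assuming hypothetical bounds (0.5–30 s) and a hypothetical output name filtered_protocol.txt:

import pandas as pd

# Load the durations produced by the notebook above
df = pd.read_csv("audio_durations.csv")

# Drop files that failed to load (marked with duration == -1)
df = df[df["duration"] != -1]

# Hypothetical bounds in seconds; tune for your corpus
MIN_DURATION, MAX_DURATION = 0.5, 30.0
df = df[df["duration"].between(MIN_DURATION, MAX_DURATION)]

# Write a filtered protocol in the same space-separated, headerless
# layout the notebook reads: "utt_id subset label"
df[["utt_id", "subset", "label"]].to_csv(
    "filtered_protocol.txt", sep=" ", header=False, index=False)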
17 changes: 7 additions & 10 deletions src/data/normal_multiview_datamodule.py
@@ -43,12 +43,13 @@ def __getitem__(self, idx):
         augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(
             self.augmentation_methods) > 0 else -1
         if augmethod_index >= 0:
+            # print("Augmenting with", self.augmentation_methods[augmethod_index])
             X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate,
                                                                       audio_path=filepath)
 
         x_inp = Tensor(X)
         target = self.labels[utt_id]
-        return idx, x_inp, target
+        return x_inp, target
 
 
 class Dataset_for_dev(Dataset_base):
@@ -66,15 +67,15 @@ def __init__(self, args, list_IDs, labels, base_dir, algo=5, vocoders=[],
     def __getitem__(self, index):
         utt_id = self.list_IDs[index]
         filepath = os.path.join(self.base_dir, utt_id)
-        X, fs = librosa.load(self.base_dir + "/" + utt_id, sr=16000)
+        X, fs = librosa.load(filepath, sr=16000)
         augmethod_index = random.choice(range(len(self.augmentation_methods))) if len(
             self.augmentation_methods) > 0 else -1
         if augmethod_index >= 0:
             X = globals()[self.augmentation_methods[augmethod_index]](X, self.args, self.sample_rate,
                                                                       audio_path=filepath)
         x_inp = Tensor(X)
         target = self.labels[utt_id]
-        return index, x_inp, target
+        return x_inp, target
 
 
 class Dataset_for_eval(Dataset_base):
@@ -160,10 +161,6 @@ def __init__(
         num_workers: int = 0,
         pin_memory: bool = False,
         args: Optional[Dict[str, Any]] = None,
-        # list_IDs: list = [], labels: list = [], base_dir: str = '', algo: int = 5, vocoders: list = [],
-        # augmentation_methods: list = [], eval_augment: Optional[str] = None, num_additional_real: int = 2, num_additional_spoof: int = 2,
-        # trim_length: int = 64000, wav_samp_rate: int = 16000, noise_path: Optional[str] = None, rir_path: Optional[str] = None,
-        # aug_dir: Optional[str] = None, online_aug: bool = False, repeat_pad: bool = True, is_train: bool = True, enable_chunking: bool = False
     ) -> None:
         """Initialize a `ASVSpoofDataModule`.
@@ -199,15 +196,15 @@ def __init__(
                 self.top_k,
                 self.min_duration,
                 self.max_duration,
-                self.args.sample_rate,
+                self.args.data.wav_samp_rate,
                 self.args.padding_type,
                 self.args.random_start
             )
         else:
             self.collate_fn = lambda x: multi_view_collate_fn(
                 x,
                 self.args.views,
-                self.args.sample_rate,
+                self.args.data.wav_samp_rate,
                 self.args.padding_type,
                 self.args.random_start
             )
@@ -333,7 +330,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
         pass
 
-    def genList(self, dir_meta, is_train=False, is_eval=False, is_dev=False):
+    def genList(self, is_train=False, is_eval=False, is_dev=False):
         """
         This function generates the list of files and their corresponding labels
         Specifically for the standard CNSL dataset
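For context on the collate_fn changes above: the repository's multi_view_collate_fn is not shown in this diff, so the sketch below is an assumption about what a collate function taking (views, sample_rate, padding_type, random_start) might do, not the project's actual implementation. It pads or crops each waveform to one target length per view, which is exactly the kind of batching that the simplified __getitem__ return value (x_inp, target) feeds cleanly.

import random
import torch

def multi_view_collate_fn_sketch(batch, views, sample_rate,
                                 padding_type="repeat", random_start=True):
    """Sketch only: pad/crop each waveform to several view lengths (seconds).

    `batch` is a list of (waveform, label) pairs, matching the updated
    __getitem__ above. Returns one (inputs, labels) pair per view.
    """
    out = {}
    for view in views:
        target_len = int(view * sample_rate)
        xs, ys = [], []
        for x, y in batch:
            n = x.shape[-1]
            if n < target_len:
                if padding_type == "repeat":
                    # Tile the waveform until it covers target_len, then trim
                    reps = target_len // n + 1
                    x = x.repeat(reps)[:target_len]
                else:
                    # Zero-pad on the right
                    x = torch.nn.functional.pad(x, (0, target_len - n))
            else:
                # Crop a window, optionally from a random start offset
                start = random.randint(0, n - target_len) if random_start else 0
                x = x[start:start + target_len]
            xs.append(x)
            ys.append(y)
        out[view] = (torch.stack(xs), torch.tensor(ys))
    return out

Dropping the index from __getitem__ simplifies this kind of collate: the batch arrives as plain (waveform, label) pairs with no positional bookkeeping.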
