generated from ashleve/lightning-hydra-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f45c6c6
commit e43863e
Showing
2 changed files
with
146 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"import numpy as np\n", | ||
"import os\n", | ||
"import librosa\n", | ||
"from tqdm import tqdm\n", | ||
"from multiprocessing import Pool, cpu_count\n", | ||
"from functools import partial\n", | ||
"\n", | ||
"\n", | ||
"def get_audio_duration(row, base_dir):\n", | ||
" \"\"\"Calculate duration for a single audio file\"\"\"\n", | ||
" try:\n", | ||
" file_path = os.path.join(base_dir, row['utt_id'])\n", | ||
" duration = librosa.get_duration(path=file_path)\n", | ||
" return {\n", | ||
" 'utt_id': row['utt_id'],\n", | ||
" 'subset': row['subset'],\n", | ||
" 'label': row['label'],\n", | ||
" 'duration': duration\n", | ||
" }\n", | ||
" except Exception as e:\n", | ||
" print(f\"Error processing {row['utt_id']}: {str(e)}\")\n", | ||
" return {\n", | ||
" 'utt_id': row['utt_id'],\n", | ||
" 'subset': row['subset'],\n", | ||
" 'label': row['label'],\n", | ||
" 'duration': -1 # Mark failed files with -1\n", | ||
" }\n", | ||
"\n", | ||
"\n", | ||
"def process_chunk(chunk, base_dir):\n", | ||
" \"\"\"Process a chunk of the dataframe\"\"\"\n", | ||
" return [get_audio_duration(row, base_dir) for row in chunk.to_dict('records')]\n", | ||
"\n", | ||
"\n", | ||
"def calculate_durations(protocol_file, base_dir, output_file, n_workers=None):\n", | ||
" \"\"\"\n", | ||
" Calculate durations for all audio files in parallel\n", | ||
" \n", | ||
" Args:\n", | ||
" protocol_file: Path to protocol file\n", | ||
" base_dir: Base directory containing audio files\n", | ||
" output_file: Path to output CSV file\n", | ||
" n_workers: Number of worker processes (default: CPU count - 1)\n", | ||
" \"\"\"\n", | ||
" if n_workers is None:\n", | ||
" n_workers = cpu_count() - 1\n", | ||
"\n", | ||
" print(\"Reading protocol file...\")\n", | ||
" protocol = pd.read_csv(protocol_file, sep=\" \", header=None)\n", | ||
" protocol.columns = [\"utt_id\", \"subset\", \"label\"]\n", | ||
"\n", | ||
" # Split dataframe into chunks for parallel processing\n", | ||
" chunk_size = len(protocol) // n_workers + 1\n", | ||
" chunks = np.array_split(protocol, n_workers)\n", | ||
"\n", | ||
" print(f\"Processing {len(protocol)} files using {n_workers} workers...\")\n", | ||
"\n", | ||
" # Process chunks in parallel\n", | ||
" with Pool(n_workers) as pool:\n", | ||
" partial_process = partial(process_chunk, base_dir=base_dir)\n", | ||
" results = list(tqdm(\n", | ||
" pool.imap(partial_process, chunks),\n", | ||
" total=len(chunks),\n", | ||
" desc=\"Calculating durations\"\n", | ||
" ))\n", | ||
"\n", | ||
" # Flatten results and convert to dataframe\n", | ||
" all_results = [item for sublist in results for item in sublist]\n", | ||
" df_results = pd.DataFrame(all_results)\n", | ||
"\n", | ||
" # Calculate statistics\n", | ||
" valid_durations = df_results[df_results['duration'] != -1]['duration']\n", | ||
" stats = {\n", | ||
" 'total_files': len(df_results),\n", | ||
" 'failed_files': len(df_results[df_results['duration'] == -1]),\n", | ||
" 'total_duration_hours': valid_durations.sum() / 3600,\n", | ||
" 'mean_duration': valid_durations.mean(),\n", | ||
" 'min_duration': valid_durations.min(),\n", | ||
" 'max_duration': valid_durations.max()\n", | ||
" }\n", | ||
"\n", | ||
" # Save results\n", | ||
" print(\"\\nSaving results...\")\n", | ||
" df_results.to_csv(output_file, index=False)\n", | ||
"\n", | ||
" # Print statistics\n", | ||
" print(\"\\nProcessing Statistics:\")\n", | ||
" print(f\"Total files processed: {stats['total_files']}\")\n", | ||
" print(f\"Failed files: {stats['failed_files']}\")\n", | ||
" print(f\"Total duration: {stats['total_duration_hours']:.2f} hours\")\n", | ||
" print(f\"Mean duration: {stats['mean_duration']:.2f} seconds\")\n", | ||
" print(f\"Min duration: {stats['min_duration']:.2f} seconds\")\n", | ||
" print(f\"Max duration: {stats['max_duration']:.2f} seconds\")\n", | ||
"\n", | ||
" return df_results, stats\n", | ||
"\n", | ||
"\n", | ||
"# Usage\n", | ||
"BASE_DIR = \"/nvme2/hungdx/Lightning-hydra/data/0_large-corpus\"\n", | ||
"protocol_file = \"/nvme2/hungdx/Lightning-hydra/notebooks/new_protocol_trim_vocoded.txt\"\n", | ||
"output_file = \"audio_durations.csv\"\n", | ||
"\n", | ||
"# Run the processing\n", | ||
"df_results, stats = calculate_durations(\n", | ||
" protocol_file=protocol_file,\n", | ||
" base_dir=BASE_DIR,\n", | ||
" output_file=output_file,\n", | ||
" n_workers=8 # Adjust based on your system\n", | ||
")\n", | ||
"\n", | ||
"# Display first few rows of results\n", | ||
"print(\"\\nFirst few rows of results:\")\n", | ||
"print(df_results.head())" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "base", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"name": "python", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters