Skip to content

Commit

Permalink
[FIX] better handling of selecting random files
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Jan 7, 2025
1 parent 1f8db62 commit 2739e12
Showing 1 changed file with 42 additions and 18 deletions.
60 changes: 42 additions & 18 deletions src/extract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import os
import pandas as pd

import fs
import pyarrow.parquet as pq
Expand All @@ -9,7 +10,6 @@

from utils import openAudioFile, openCachedFile, saveSignal


def setup_logging():
logging.basicConfig(
filename="audio_processing.log",
Expand All @@ -18,7 +18,6 @@ def setup_logging():
datefmt="%Y-%m-%d %H:%M:%S",
)


@retry(wait=wait_exponential(multiplier=1, min=4, max=120))
def do_connection(connection_string):
"""Establish a connection to the filesystem with retries."""
Expand All @@ -31,7 +30,6 @@ def do_connection(connection_string):
# logging.info("Retrying connection...")
raise


# @retry(wait=wait_exponential(multiplier=5, min=60, max=600))
def extract_segments(
item, sample_rate, out_path, filesystem, connection_string, seg_length=3
Expand All @@ -49,7 +47,6 @@ def extract_segments(
save_extracted_segments(signal, rate, segments, out_path, seg_length)
# logging.info(f"Segments extracted from {audio_file}")


def save_extracted_segments(signal, rate, segment, out_path, seg_length):
"""Save the extracted segments to the output path."""
# for segment in segments:
Expand All @@ -62,7 +59,6 @@ def save_extracted_segments(signal, rate, segment, out_path, seg_length):
segment_signal = signal[start:end]
save_segment(segment_signal, segment, out_path)


def save_segment(segment_signal, segment, out_path):
"""Save an individual segment."""
species_path = os.path.join(out_path, segment["species"])
Expand All @@ -73,7 +69,6 @@ def save_segment(segment_signal, segment, out_path):
print(f"Segment {segment_path} saved")
saveSignal(segment_signal, segment_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -90,15 +85,44 @@ def save_segment(segment_signal, segment, out_path):

myfs = do_connection(config["CONNECTION_STRING"])

items = pq.read_table(args.parquet_file, filters=[["audio", "=", args.audio_file]])

for item in items.to_pylist():
print(f"Extracting segments from {item}")
extract_segments(
item,
config["SAMPLE_RATE"],
config["OUT_PATH_SEGMENTS"],
myfs,
config["CONNECTION_STRING"],
seg_length=3,
)
# Read the Parquet file into a DataFrame
parquet_df = pd.read_parquet(args.parquet_file)

# Log the total number of detections in the Parquet file
print(f"Total number of detections in the Parquet file: {len(parquet_df)}")

# Cap the number of sampled detections per species (at most NUM_SEGMENTS each), drawn over the whole Parquet file before filtering by audio file
num_segments_per_species = config["NUM_SEGMENTS"]
sampled_items = parquet_df.groupby("species").apply(
lambda x: x.sample(min(len(x), num_segments_per_species), random_state=None)
).reset_index(drop=True)

# Log the number of sampled detections
print(f"Number of sampled detections across all species: {len(sampled_items)}")

# Keep only the sampled detections whose 'audio' value contains this file's basename (case-insensitive pattern match, not exact equality)
audio_file_basename = os.path.basename(args.audio_file)
filtered_items = sampled_items[sampled_items['audio'].str.contains(audio_file_basename, na=False, case=False)]

# Skip processing if there are no relevant segments for the file
if filtered_items.empty:
print(f"No detections found for {args.audio_file}. Skipping...")
exit(0)

# Log the number of detections for the specific audio file
print(f"Number of detections for {args.audio_file}: {len(filtered_items)}")

# Process each detection
for _, item in filtered_items.iterrows():
try:
logging.info(f"Processing item: {item}")
extract_segments(
item,
config["SAMPLE_RATE"],
config["OUT_PATH_SEGMENTS"],
myfs,
config["CONNECTION_STRING"],
seg_length=3,
)
except Exception as e:
logging.error(f"Error processing segment {item}: {e}")

0 comments on commit 2739e12

Please sign in to comment.