Adding job_kwargs to all benchmarks
yger committed Jan 24, 2025
1 parent 664840b commit 63a58c1
Showing 7 changed files with 16 additions and 12 deletions.
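The pattern is the same in every changed file: each benchmark test already builds a job_kwargs dict, and this commit threads it through make_dataset (normalized with fix_job_kwargs) and into each parallelizable call. Below is a minimal sketch of that flow using only names that appear in the diff; exactly which defaults fix_job_kwargs fills in (worker count from a fractional n_jobs, chunking, progress bar) is an assumption here, not part of the commit.

from spikeinterface.core.job_tools import fix_job_kwargs

job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")  # as defined in the tests below
job_kwargs = fix_job_kwargs(job_kwargs)             # assumed to resolve fractional n_jobs and fill in defaults

# every chunk-parallel step then receives the same keyword arguments:
# gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory", **job_kwargs)
# gt_analyzer.compute("templates", **job_kwargs)
# noise_levels = get_noise_levels(recording, **job_kwargs)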
@@ -7,6 +7,7 @@
 import pytest
 from pathlib import Path
 import os
+from spikeinterface.core.job_tools import fix_job_kwargs

 import numpy as np

@@ -21,7 +22,9 @@
 ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))


-def make_dataset():
+def make_dataset(job_kwargs={}):
+
+    job_kwargs = fix_job_kwargs(job_kwargs)
     recording, gt_sorting = generate_ground_truth_recording(
         durations=[60.0],
         sampling_frequency=30000.0,
@@ -39,10 +42,10 @@ def make_dataset():
         seed=2205,
     )

-    gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory")
+    gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory", **job_kwargs)
     gt_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
     # analyzer.compute("waveforms")
-    gt_analyzer.compute("templates")
+    gt_analyzer.compute("templates", **job_kwargs)
     gt_analyzer.compute("noise_levels")

     return recording, gt_sorting, gt_analyzer
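Because the new parameter defaults to an empty dict, existing callers that pass nothing keep working. A sketch of both call styles follows; the behaviour of fix_job_kwargs({}) (falling back to the library's default job settings) is an assumption, not stated in the diff.

recording, gt_sorting, gt_analyzer = make_dataset()                # old call sites, assumed default job settings
recording, gt_sorting, gt_analyzer = make_dataset(dict(n_jobs=1))  # explicit single-process run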
@@ -16,7 +16,7 @@ def test_benchmark_clustering(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")

-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)

     num_spikes = gt_sorting.to_spike_vector().size
     spike_indices = np.arange(0, num_spikes, 5)
src/spikeinterface/benchmark/tests/test_benchmark_matching.py (4 changes: 2 additions & 2 deletions)
@@ -21,13 +21,13 @@ def test_benchmark_matching(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")

-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)

     # templates sparse
     gt_templates = compute_gt_templates(
         recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs
     )
-    noise_levels = get_noise_levels(recording)
+    noise_levels = get_noise_levels(recording, **job_kwargs)
     sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25)
     gt_templates = gt_templates.to_sparse(sparsity)

@@ -15,7 +15,7 @@ def test_benchmark_merging(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")

-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)

     # create study
     study_folder = cache_folder / "study_clustering"
@@ -22,7 +22,7 @@ def test_benchmark_motion_interpolation(create_cache_folder):
     cache_folder = create_cache_folder
     job_kwargs = dict(n_jobs=0.8, chunk_duration="1s")

-    data = make_drifting_dataset()
+    data = make_drifting_dataset(job_kwargs)

     datasets = {
         "data_static": (data["static_rec"], data["sorting"]),
@@ -15,7 +15,7 @@ def test_benchmark_peak_detection(create_cache_folder):
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")

     # recording, gt_sorting = make_dataset()
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)

     # create study
     study_folder = cache_folder / "study_peak_detection"
@@ -27,8 +27,9 @@

         recording, gt_sorting = datasets[dataset]

-        sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
-        sorting_analyzer.compute(["random_spikes", "templates"])
+        sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
+        sorting_analyzer.compute("random_spikes")
+        sorting_analyzer.compute("templates", **job_kwargs)
         extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
         spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
         peaks[dataset] = spikes
@@ -15,7 +15,7 @@ def test_benchmark_peak_localization(create_cache_folder):
     job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms")

     # recording, gt_sorting = make_dataset()
-    recording, gt_sorting, gt_analyzer = make_dataset()
+    recording, gt_sorting, gt_analyzer = make_dataset(job_kwargs)

     # create study
     study_folder = cache_folder / "study_peak_localization"
