From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 01/70] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 79ca022883baeb27b98fc18c1a59543d08523a27 Mon Sep 17 00:00:00 2001 From: m-beau Date: Tue, 6 Jan 2026 17:54:42 +0100 Subject: [PATCH 02/70] original commit - good times --- .../postprocessing/good_periods_per_unit.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 src/spikeinterface/postprocessing/good_periods_per_unit.py diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py new file mode 100644 index 0000000000..b30c24e4f5 --- /dev/null +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import importlib.util +import warnings + +import numpy as np +from itertools import chain + +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension + +numba_spec = importlib.util.find_spec("numba") +if numba_spec is not None: + HAVE_NUMBA = True +else: + HAVE_NUMBA = False + + +class ComputeGoodTimeChunks(AnalyzerExtension): + """Compute good time chunks. + + Parameters + ---------- + method : "false_positives_and_negatives" | "user_defined" | "combined" + + + Returns + ------- + # dict or array depending on output mode + good_periods_per_unit : numpy.ndarray + (n_periods, 4) array with columns: segment_id, unit_id, start_time, end_time + """ + + extension_name = "good_periods_per_unit" + depend_on = [] + need_recording = False + use_nodepipeline = False + need_job_kwargs = False + + ## todo: add fp fn parameters (flat kwargs) + def _set_params(self, method: str = "false_positives_and_negatives", user_defined_periods=None): + if method in ["false_positives_and_negatives", "combined"]: + if not self.sorting_analyzer.has_extension("amplitude_scalings"): + raise ValueError( + "ComputeGoodTimeChunks with method 'false_positives_and_negatives' requires 'amplitude_scalings' extension." + ) + elif method == "user_defined": + assert user_defined_periods is not None, "user_defined_periods must be provided for method 'user_defined'" + if method == "combined": + warnings.warn("ComputeGoodTimeChunks was called with method 'combined', yet user_defined_periods are not passed. 
Falling back to using false positives and negatives only.") + method = "false_positives_and_negatives" + + params = dict(method=method, user_defined_periods=user_defined_periods) + + return params + + def _select_extension_data(self, unit_ids): + new_extension_data = self.data + return new_extension_data + + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + new_extension_data = self.data + return new_extension_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_extension_data = self.data + return new_extension_data + + def _run(self, verbose=False): + method = self.params["method"] + flat_args = 0 #TODO extract from method_kwargs + + self.data["isi_histograms"] = isi_histograms + self.data["bins"] = bins + + def _get_data(self): + return self.data["isi_histograms"], self.data["bins"] + + +register_result_extension(ComputeISIHistograms) +compute_isi_histograms = ComputeISIHistograms.function_factory() \ No newline at end of file From 22501da4764e10273586ce0f2abfdfaab3f27254 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 16:56:01 +0000 Subject: [PATCH 03/70] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/good_periods_per_unit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index b30c24e4f5..a820db8a43 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -21,7 +21,7 @@ class ComputeGoodTimeChunks(AnalyzerExtension): Parameters ---------- method : "false_positives_and_negatives" | "user_defined" | "combined" - + Returns ------- @@ -46,7 +46,9 @@ def _set_params(self, method: str = "false_positives_and_negatives", user_define elif method == "user_defined": assert user_defined_periods is not None, "user_defined_periods must be provided for method 'user_defined'" if method == "combined": - warnings.warn("ComputeGoodTimeChunks was called with method 'combined', yet user_defined_periods are not passed. Falling back to using false positives and negatives only.") + warnings.warn( + "ComputeGoodTimeChunks was called with method 'combined', yet user_defined_periods are not passed. Falling back to using false positives and negatives only." 
+ ) method = "false_positives_and_negatives" params = dict(method=method, user_defined_periods=user_defined_periods) @@ -69,7 +71,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, def _run(self, verbose=False): method = self.params["method"] - flat_args = 0 #TODO extract from method_kwargs + flat_args = 0 # TODO extract from method_kwargs self.data["isi_histograms"] = isi_histograms self.data["bins"] = bins @@ -79,4 +81,4 @@ def _get_data(self): register_result_extension(ComputeISIHistograms) -compute_isi_histograms = ComputeISIHistograms.function_factory() \ No newline at end of file +compute_isi_histograms = ComputeISIHistograms.function_factory() From 7ca3d35302929b8584e4827b6fc8b9dfc6454dae Mon Sep 17 00:00:00 2001 From: m-beau Date: Wed, 7 Jan 2026 16:00:32 +0100 Subject: [PATCH 04/70] good times - progress --- .../postprocessing/good_periods_per_unit.py | 221 +++++++++++++++--- 1 file changed, 192 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index b30c24e4f5..38cbc7a7c2 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -4,7 +4,7 @@ import warnings import numpy as np -from itertools import chain +from typing import Optional, Literal from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -14,20 +14,45 @@ else: HAVE_NUMBA = False +class ComputeGoodPeriodsPerUnit(AnalyzerExtension): + """Compute good time periods per unit based on quality metrics. -class ComputeGoodTimeChunks(AnalyzerExtension): - """Compute good time chunks. - - Parameters + Paraneters ---------- - method : "false_positives_and_negatives" | "user_defined" | "combined" - + method : {"false_positives_and_negatives", "user_defined", "combined"} + Strategy for identifying good periods for each unit. If "false_positives_and_negatives", uses + amplitude cutoff (false negative spike rate) and refractory period violations (false positive spike rate) + to estimate good periods (as periods with fn_rate 0 or subperiod_size_relative > 0,\ + "Either subperiod_size_absolute or subperiod_size_relative must be positive." + assert isinstance(subperiod_size_relative, (int)),\ + "subperiod_size_relative must be an integer." 
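# Illustrative sketch (hypothetical values): `user_defined_periods` is given in samples,
# either as (n_periods, 3) rows of [unit, good_period_start, good_period_end], e.g.
#     [[0, 10_000, 50_000],
#      [3, 20_000, 80_000]]
# or as (n_periods, 4) rows of [unit, segment_index, good_period_start, good_period_end], e.g.
#     [[0, 1, 10_000, 50_000]]
# The validation below converts the 3-column form to 4 columns, assuming segment index 0.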
+ + # user_defined_periods format + if user_defined_periods is not None: + try: + user_defined_periods = np.asarray(user_defined_periods) + except Exception as e: + raise ValueError(("user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " + "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array")) from e + + if user_defined_periods.ndim != 2 or user_defined_periods.shape[1] not in (3, 4): + raise ValueError("user_defined_periods must be of shape (n_periods, 3) [unit, good_period_start, good_period_end] or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end]") + + if not np.issubdtype(user_defined_periods.dtype, np.integer): + # Try converting to check if they're integer-valued floats + if not np.allclose(user_defined_periods, user_defined_periods.astype(int)): + raise ValueError("All values in user_defined_periods must be integers, in samples.") + user_defined_periods = user_defined_periods.astype(int) + + if user_defined_periods.shape[1] == 3: + # add segment index 0 as column 1 if missing + user_defined_periods = np.hstack((user_defined_periods[:, 0:1], + np.zeros((user_defined_periods.shape[0], 1), dtype=int), + user_defined_periods[:, 1:3])) + + params = dict(method=method, + subperiod_size_absolute=subperiod_size_absolute, + subperiod_size_relative=subperiod_size_relative, + subperiod_size_mode=subperiod_size_mode, + violations_ms=violations_ms, + fp_threshold=fp_threshold, + fn_threshold=fn_threshold, + minimum_n_spikes=minimum_n_spikes, + user_defined_periods=user_defined_periods) + return params def _select_extension_data(self, unit_ids): @@ -68,15 +145,101 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, return new_extension_data def _run(self, verbose=False): - method = self.params["method"] - flat_args = 0 #TODO extract from method_kwargs - self.data["isi_histograms"] = isi_histograms - self.data["bins"] = bins + if self.params["method"] == "user_defined": + # directly use user defined periods + self.data["good_periods_per_unit"] = self.params["user_defined_periods"] + + if self.params["method"] in ["false_positives_and_negatives", "combined"]: + # ndarray: (n_periods, 3) with columns: segment_id, start_sample, end_sample + period_bounds = compute_period_bounds(self, + self.params["subperiod_size_absolute"], + self.params["subperiod_size_relative"], + self.params["subperiod_size_mode"]) + + ## Compute fp and fn for all periods + + # fp computed from refractory period violations + # dict: unit_id -> array of shape (n_periods) + periods_fp_per_unit = compute_fp_rates(self, period_bounds, + self.params["violations_ms"]) + + # fn computed from amplitude clippings + # dict: unit_id -> array of shape (n_periods) + periods_fn_per_unit = compute_fn_rates(self, period_bounds) + + ## Combine fp and fn results with thresholds to define good periods + + + ## Eventually combine with user defined periods if provided + + self.data["period_bounds"] = period_bounds + self.data["periods_fp_per_unit"] = periods_fp_per_unit + self.data["periods_fn_per_unit"] = periods_fn_per_unit + self.data["good_periods_per_unit"] = None # (n_good_periods, 4) with (unit, segment, start, end) to be implemented def _get_data(self): return self.data["isi_histograms"], self.data["bins"] -register_result_extension(ComputeISIHistograms) -compute_isi_histograms = ComputeISIHistograms.function_factory() \ No newline at end of file +# 
register_result_extension(ComputeISIHistograms) +# compute_isi_histograms = ComputeISIHistograms.function_factory() + +def compute_period_bounds(self, + subperiod_size_absolute: float = 10, + subperiod_size_relative: int = 1000, + subperiod_size_mode: str = "absolute") -> np.ndarray: + + sorting = self.sorting_analyzer.sorting + fs = sorting.get_sampling_frequency() + + if subperiod_size_mode == "absolute": + period_size_samples = margin_size_samples = np.round(subperiod_size_absolute * fs).astype(int) + else: # relative + period_size_samples = margin_size_samples = 0 # to be implemented based on firing rates + + all_period_bounds = np.empty((0, 3)) + for segment_i in range(sorting.get_num_segments()): + n_samples = sorting.get_num_samples(segment_i) # int: samples + n_periods = n_samples // period_size_samples + 1 + + # list of sliding [start, end] in samples + # for period size of 10s and margin size of 10s: [0, 30], [10, 40], [20, 50], ... + period_bounds = [(segment_i, + i * period_size_samples, + i * period_size_samples + 2 * margin_size_samples, + ) + for i in range(n_periods)] + all_period_bounds = np.vstack(all_period_bounds, period_bounds) if len(all_period_bounds) > 0 else np.array(period_bounds) + + return all_period_bounds + +def compute_fp_rates(self, + period_bounds: list, + violations_ms: float = 0.8) -> dict: + units = self.sorting_analyzer.sorting.unit_ids + n_periods = period_bounds.shape[0] + + fp_violations = {} + for unit in units: + fp_violations[unit] = np.zeros((n_periods,), dtype=float) + for i, (segment_i, start, end) in enumerate(period_bounds): + fp_rate = 0 # refractory period violations for this period + fp_violations[unit][i] = fp_rate + pass + + return fp_violations + +def compute_fn_rates(self, period_bounds: list) -> dict: + units = self.sorting_analyzer.sorting.unit_ids + n_periods = period_bounds.shape[0] + + fn_violations = {} + for unit in units: + fn_violations[unit] = np.zeros((n_periods,), dtype=float) + for i, (segment_i, start, end) in enumerate(period_bounds): + fn_rate = 0 # clipped amplitude AUC ratio for this period + fn_violations[unit][i] = fn_rate + pass + + return fn_violations \ No newline at end of file From ab0e8dc2b14975b788fd02d1def133fca35009a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 7 Jan 2026 15:07:58 +0000 Subject: [PATCH 05/70] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/good_periods_per_unit.py | 171 ++++++++++-------- 1 file changed, 97 insertions(+), 74 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 38cbc7a7c2..ddde15d010 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -14,6 +14,7 @@ else: HAVE_NUMBA = False + class ComputeGoodPeriodsPerUnit(AnalyzerExtension): """Compute good time periods per unit based on quality metrics. 
@@ -61,24 +62,25 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False - def _set_params(self, - method: str = "false_positives_and_negatives", - subperiod_size_absolute: float = 10.0, - subperiod_size_relative: int = 1000, - subperiod_size_mode: str = "absolute", - violations_ms: float = 0.8, - fp_threshold: float = 0.05, - fn_threshold: float = 0.05, - minimum_n_spikes: int = 100, - user_defined_periods: Optional[object] = None): - + def _set_params( + self, + method: str = "false_positives_and_negatives", + subperiod_size_absolute: float = 10.0, + subperiod_size_relative: int = 1000, + subperiod_size_mode: str = "absolute", + violations_ms: float = 0.8, + fp_threshold: float = 0.05, + fn_threshold: float = 0.05, + minimum_n_spikes: int = 100, + user_defined_periods: Optional[object] = None, + ): + # method - assert method in ("false_positives_and_negatives", "user_defined", "combined"),\ - f"Invalid method: {method}" - + assert method in ("false_positives_and_negatives", "user_defined", "combined"), f"Invalid method: {method}" + if method == "user_defined" and user_defined_periods is None: raise ValueError("user_defined_periods required for 'user_defined' method") - + if method == "combined" and user_defined_periods is None: warnings.warn("Combined method without user_defined_periods, falling back") method = "false_positives_and_negatives" @@ -86,48 +88,59 @@ def _set_params(self, if params.method in ["false_positives_and_negatives", "combined"]: if not self.sorting_analyzer.has_extension("amplitude_scalings"): raise ValueError("Requires 'amplitude_scalings' extension; please compute it first.") - + # subperiods - assert subperiod_size_mode in ("absolute", "relative"),\ - f"Invalid subperiod_size_mode: {subperiod_size_mode}" - assert subperiod_size_absolute > 0 or subperiod_size_relative > 0,\ - "Either subperiod_size_absolute or subperiod_size_relative must be positive." - assert isinstance(subperiod_size_relative, (int)),\ - "subperiod_size_relative must be an integer." + assert subperiod_size_mode in ("absolute", "relative"), f"Invalid subperiod_size_mode: {subperiod_size_mode}" + assert ( + subperiod_size_absolute > 0 or subperiod_size_relative > 0 + ), "Either subperiod_size_absolute or subperiod_size_relative must be positive." + assert isinstance(subperiod_size_relative, (int)), "subperiod_size_relative must be an integer." 
# user_defined_periods format if user_defined_periods is not None: try: user_defined_periods = np.asarray(user_defined_periods) except Exception as e: - raise ValueError(("user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " - "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array")) from e - + raise ValueError( + ( + "user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " + "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array" + ) + ) from e + if user_defined_periods.ndim != 2 or user_defined_periods.shape[1] not in (3, 4): - raise ValueError("user_defined_periods must be of shape (n_periods, 3) [unit, good_period_start, good_period_end] or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end]") - + raise ValueError( + "user_defined_periods must be of shape (n_periods, 3) [unit, good_period_start, good_period_end] or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end]" + ) + if not np.issubdtype(user_defined_periods.dtype, np.integer): # Try converting to check if they're integer-valued floats if not np.allclose(user_defined_periods, user_defined_periods.astype(int)): raise ValueError("All values in user_defined_periods must be integers, in samples.") user_defined_periods = user_defined_periods.astype(int) - + if user_defined_periods.shape[1] == 3: # add segment index 0 as column 1 if missing - user_defined_periods = np.hstack((user_defined_periods[:, 0:1], - np.zeros((user_defined_periods.shape[0], 1), dtype=int), - user_defined_periods[:, 1:3])) - - params = dict(method=method, - subperiod_size_absolute=subperiod_size_absolute, - subperiod_size_relative=subperiod_size_relative, - subperiod_size_mode=subperiod_size_mode, - violations_ms=violations_ms, - fp_threshold=fp_threshold, - fn_threshold=fn_threshold, - minimum_n_spikes=minimum_n_spikes, - user_defined_periods=user_defined_periods) - + user_defined_periods = np.hstack( + ( + user_defined_periods[:, 0:1], + np.zeros((user_defined_periods.shape[0], 1), dtype=int), + user_defined_periods[:, 1:3], + ) + ) + + params = dict( + method=method, + subperiod_size_absolute=subperiod_size_absolute, + subperiod_size_relative=subperiod_size_relative, + subperiod_size_mode=subperiod_size_mode, + violations_ms=violations_ms, + fp_threshold=fp_threshold, + fn_threshold=fn_threshold, + minimum_n_spikes=minimum_n_spikes, + user_defined_periods=user_defined_periods, + ) + return params def _select_extension_data(self, unit_ids): @@ -152,17 +165,18 @@ def _run(self, verbose=False): if self.params["method"] in ["false_positives_and_negatives", "combined"]: # ndarray: (n_periods, 3) with columns: segment_id, start_sample, end_sample - period_bounds = compute_period_bounds(self, - self.params["subperiod_size_absolute"], - self.params["subperiod_size_relative"], - self.params["subperiod_size_mode"]) - + period_bounds = compute_period_bounds( + self, + self.params["subperiod_size_absolute"], + self.params["subperiod_size_relative"], + self.params["subperiod_size_mode"], + ) + ## Compute fp and fn for all periods # fp computed from refractory period violations # dict: unit_id -> array of shape (n_periods) - periods_fp_per_unit = compute_fp_rates(self, period_bounds, - self.params["violations_ms"]) + periods_fp_per_unit = compute_fp_rates(self, period_bounds, self.params["violations_ms"]) # fn computed from amplitude 
clippings # dict: unit_id -> array of shape (n_periods) @@ -170,13 +184,14 @@ def _run(self, verbose=False): ## Combine fp and fn results with thresholds to define good periods - ## Eventually combine with user defined periods if provided self.data["period_bounds"] = period_bounds self.data["periods_fp_per_unit"] = periods_fp_per_unit self.data["periods_fn_per_unit"] = periods_fn_per_unit - self.data["good_periods_per_unit"] = None # (n_good_periods, 4) with (unit, segment, start, end) to be implemented + self.data["good_periods_per_unit"] = ( + None # (n_good_periods, 4) with (unit, segment, start, end) to be implemented + ) def _get_data(self): return self.data["isi_histograms"], self.data["bins"] @@ -185,11 +200,14 @@ def _get_data(self): # register_result_extension(ComputeISIHistograms) # compute_isi_histograms = ComputeISIHistograms.function_factory() -def compute_period_bounds(self, - subperiod_size_absolute: float = 10, - subperiod_size_relative: int = 1000, - subperiod_size_mode: str = "absolute") -> np.ndarray: - + +def compute_period_bounds( + self, + subperiod_size_absolute: float = 10, + subperiod_size_relative: int = 1000, + subperiod_size_mode: str = "absolute", +) -> np.ndarray: + sorting = self.sorting_analyzer.sorting fs = sorting.get_sampling_frequency() @@ -200,23 +218,27 @@ def compute_period_bounds(self, all_period_bounds = np.empty((0, 3)) for segment_i in range(sorting.get_num_segments()): - n_samples = sorting.get_num_samples(segment_i) # int: samples + n_samples = sorting.get_num_samples(segment_i) # int: samples n_periods = n_samples // period_size_samples + 1 # list of sliding [start, end] in samples # for period size of 10s and margin size of 10s: [0, 30], [10, 40], [20, 50], ... - period_bounds = [(segment_i, - i * period_size_samples, - i * period_size_samples + 2 * margin_size_samples, - ) - for i in range(n_periods)] - all_period_bounds = np.vstack(all_period_bounds, period_bounds) if len(all_period_bounds) > 0 else np.array(period_bounds) - + period_bounds = [ + ( + segment_i, + i * period_size_samples, + i * period_size_samples + 2 * margin_size_samples, + ) + for i in range(n_periods) + ] + all_period_bounds = ( + np.vstack(all_period_bounds, period_bounds) if len(all_period_bounds) > 0 else np.array(period_bounds) + ) + return all_period_bounds -def compute_fp_rates(self, - period_bounds: list, - violations_ms: float = 0.8) -> dict: + +def compute_fp_rates(self, period_bounds: list, violations_ms: float = 0.8) -> dict: units = self.sorting_analyzer.sorting.unit_ids n_periods = period_bounds.shape[0] @@ -224,12 +246,13 @@ def compute_fp_rates(self, for unit in units: fp_violations[unit] = np.zeros((n_periods,), dtype=float) for i, (segment_i, start, end) in enumerate(period_bounds): - fp_rate = 0 # refractory period violations for this period - fp_violations[unit][i] = fp_rate - pass + fp_rate = 0 # refractory period violations for this period + fp_violations[unit][i] = fp_rate + pass return fp_violations + def compute_fn_rates(self, period_bounds: list) -> dict: units = self.sorting_analyzer.sorting.unit_ids n_periods = period_bounds.shape[0] @@ -238,8 +261,8 @@ def compute_fn_rates(self, period_bounds: list) -> dict: for unit in units: fn_violations[unit] = np.zeros((n_periods,), dtype=float) for i, (segment_i, start, end) in enumerate(period_bounds): - fn_rate = 0 # clipped amplitude AUC ratio for this period - fn_violations[unit][i] = fn_rate - pass + fn_rate = 0 # clipped amplitude AUC ratio for this period + fn_violations[unit][i] = fn_rate + pass - 
return fn_violations \ No newline at end of file + return fn_violations From 7279b6753f30aff0bfe485b8ee884e56b3068822 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 7 Jan 2026 16:26:47 +0100 Subject: [PATCH 06/70] wip --- .../core/analyzer_extension_core.py | 34 +++++++- src/spikeinterface/core/basesorting.py | 20 +++++ src/spikeinterface/core/node_pipeline.py | 11 ++- src/spikeinterface/core/sorting_tools.py | 77 +++++++++++++++++ .../core/tests/test_basesorting.py | 64 ++++++++++++-- .../metrics/quality/misc_metrics.py | 85 +++++++++++++++---- .../tests/test_interpolate_bad_channels.py | 2 +- 7 files changed, 268 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 74ef52e258..5e46f20d22 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -19,6 +19,7 @@ from .template import Templates from .sorting_tools import random_spikes_selection from .job_tools import fix_job_kwargs, split_job_kwargs +from .node_pipeline import base_period_dtype class ComputeRandomSpikes(AnalyzerExtension): @@ -1331,6 +1332,21 @@ class BaseSpikeVectorExtension(AnalyzerExtension): need_backward_compatibility_on_load = False nodepipeline_variables = [] # to be defined in subclass + def __init__(self, sorting_analyzer): + super().__init__(sorting_analyzer) + self._segment_slices = None + + @property + def segment_slices(self): + if self._segment_slices is None: + segment_slices = [] + spikes = self.sorting_analyzer.sorting.to_spike_vector() + for segment_index in range(self.sorting_analyzer.get_num_segments()): + i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) + segment_slices.append(slice(i0, i1)) + self._segment_slices = segment_slices + return self._segment_slices + def _set_params(self, **kwargs): params = kwargs.copy() return params @@ -1369,7 +1385,7 @@ def _run(self, verbose=False, **job_kwargs): for d, name in zip(data, data_names): self.data[name] = d - def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True): + def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True): """ Return extension data. If the extension computes more than one `nodepipeline_variables`, the `return_data_name` is used to specify which one to return. @@ -1383,13 +1399,15 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, return_data_name : str | None, default: None The name of the data to return. If None and multiple `nodepipeline_variables` are computed, the first one is returned. + periods : array of unit_period dtype, default: None + Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) to slice output data copy : bool, default: True Whether to return a copy of the data (only for outputs="numpy") Returns ------- numpy.ndarray | dict - The + The requested data in numpy or by unit format. 
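        Examples
        --------
        A minimal sketch (``ext`` and ``periods`` are assumed to exist; ``periods``
        is a structured array of ``unit_period_dtype`` built by the caller)::

            data_in_periods = ext.get_data(periods=periods)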
""" from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -1404,6 +1422,18 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + if periods is not None: + # TODO: slice this properly with unit_indices + required = np.dtype(base_period_dtype).names + if not required.issubset(periods.dtype.names): + raise ValueError(f"Period must have the following fields: {required}") + # slice data according to period + segment_slices = self.segment_slices + all_data_segment = all_data[segment_slices[periods["segment_index"]]] + start = periods["start_sample_index"] + end = periods["end_sample_index"] + all_data = all_data_segment[start:end] + if outputs == "numpy": if copy: return all_data.copy() # return a copy to avoid modification diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 98159fb646..b6440f8e2b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -626,6 +626,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def select_periods(self, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + periods : numpy.array of unit_period_dtype + Period (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. + """ + from spikeinterface.core.sorting_tools import select_sorting_periods + + return select_sorting_periods(self, periods) + def split_by(self, property="group", outputs="dict"): """ Splits object based on a certain property (e.g. "group") diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 71654a67b4..f6bf3cb31f 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -22,11 +22,20 @@ ("segment_index", "int64"), ] - spike_peak_dtype = base_peak_dtype + [ ("unit_index", "int64"), ] +base_period_dtype = [ + ("start_sample_index", "int64"), + ("end_sample_index", "int64"), + ("segment_index", "int64"), +] + +unit_period_dtype = base_period_dtype + [ + ("unit_index", "int64"), +] + class PipelineNode: diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..9a9a3670ef 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -228,6 +228,83 @@ def random_spikes_selection( return random_spikes_indices +def select_sorting_periods_mask(sorting: BaseSorting, periods): + """ + Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + sorting : BaseSorting + The sorting object. + periods : numpy.array of unit_period_dtype + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + numpy.array + A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods. 
+ """ + spike_vector = sorting.to_spike_vector() + spike_vector_list = sorting.to_spike_vector(concatenated=False) + keep_mask = np.zeros(len(spike_vector), dtype=bool) + all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True) + for segment_index in range(sorting.get_num_segments()): + global_indices_segment = all_global_indices[segment_index] + # filter periods by segment + periods_in_segment = periods[periods["segment_index"] == segment_index] + for unit_index, unit_id in enumerate(sorting.unit_ids): + # filter by unit index + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + global_indices = global_indices_segment[unit_id] + spiketrains = spike_vector[global_indices]["sample_index"] + if len(periods_for_unit) > 0: + for period in periods_for_unit: + mask = (spiketrains >= period["start_sample_index"]) & (spiketrains < period["end_sample_index"]) + keep_mask[global_indices[mask]] = True + return keep_mask + + +def select_sorting_periods(sorting: BaseSorting, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + S + periods : numpy.array of unit_period_dtype + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. + """ + from spikeinterface.core.numpyextractors import NumpySorting + from spikeinterface.core.node_pipeline import unit_period_dtype + + if periods is not None: + if not isinstance(periods, np.ndarray): + periods = np.array([periods], dtype=unit_period_dtype) + required = set(np.dtype(unit_period_dtype).names) + if not required.issubset(periods.dtype.names): + raise ValueError(f"Period must have the following fields: {required}") + + spike_vector = sorting.to_spike_vector() + keep_mask = select_sorting_periods_mask(sorting, periods) + sliced_spike_vector = spike_vector[keep_mask] + + sorting = NumpySorting( + sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids + ) + sorting.copy_metadata(sorting) + return sorting + else: + return sorting + + ### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 54befd40ec..ada35a57e9 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -3,9 +3,7 @@ but check only for BaseRecording general methods. 
""" -import shutil -from pathlib import Path - +import time import numpy as np import pytest from numpy.testing import assert_raises @@ -17,15 +15,15 @@ SharedMemorySorting, NpzFolderSorting, NumpyFolderSorting, + generate_ground_truth_recording, + generate_sorting, create_sorting_npz, generate_sorting, load, ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal -from spikeinterface.core.generate import generate_sorting - -from spikeinterface.core import generate_recording, generate_ground_truth_recording +from spikeinterface.core.node_pipeline import unit_period_dtype def test_BaseSorting(create_cache_folder): @@ -226,7 +224,61 @@ def test_time_slice(): ) +def test_select_periods(): + sampling_frequency = 10_000.0 + duration = 1_000 + num_samples = int(sampling_frequency * duration) + num_units = 1000 + sorting = generate_sorting( + durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units + ) + + rng = np.random.default_rng() + + # number of random periods + n_periods = 10_000 + # generate random periods + segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods) + start_samples = rng.integers(0, num_samples, n_periods) + durations = rng.integers(100, 100_000, n_periods) + end_samples = start_samples + durations + valid_periods = end_samples < num_samples + segment_indices = segment_indices[valid_periods] + start_samples = start_samples[valid_periods] + end_samples = end_samples[valid_periods] + unit_index = rng.integers(0, num_units - 1, len(segment_indices)) + + periods = np.zeros(len(segment_indices), dtype=unit_period_dtype) + periods["segment_index"] = segment_indices + periods["start_sample_index"] = start_samples + periods["end_sample_index"] = end_samples + periods["unit_index"] = unit_index + + t_start = time.perf_counter() + sliced_sorting = sorting.select_periods(periods=periods) + t_stop = time.perf_counter() + elapsed = t_stop - t_start + print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods") + + # Check that all spikes in the sliced sorting are within the periods + for segment_index in range(sorting.get_num_segments()): + for unit_index, unit_id in enumerate(sorting.unit_ids): + spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + spikes_in_periods = np.array([], dtype=spiketrain.dtype) + periods_in_segment = periods[periods["segment_index"] == segment_index] + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + for period in periods_for_unit: + start_sample = period["start_sample_index"] + end_sample = period["end_sample_index"] + spikes_in_period = spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)] + spikes_in_periods = np.concatenate((spikes_in_periods, spikes_in_period)) + if not len(spikes_in_periods) == len(spiketrain_sliced): + print(f"Mismatch in number of spikes!: {len(spikes_in_periods)} vs {len(spiketrain_sliced)}") + + if __name__ == "__main__": test_BaseSorting() test_npy_sorting() test_empty_sorting() + test_select_periods() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..028b2eeca5 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -19,12 +19,13 @@ from 
spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels +from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting from spikeinterface.core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) +from spikeinterface.core.node_pipeline import base_period_dtype from ..spiketrain.metrics import NumSpikes, FiringRate @@ -35,7 +36,9 @@ HAVE_NUMBA = False -def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): +def compute_presence_ratios( + sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None +): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -51,6 +54,9 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -63,6 +69,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -182,7 +189,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): """ Calculate Inter-Spike Interval (ISI) violations. @@ -204,6 +211,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -235,6 +245,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_period(sorting, periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -280,7 +291,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None ): """ Calculate the number of refractory period violations. 
@@ -300,6 +311,9 @@ def compute_refrac_period_violations( censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -332,6 +346,8 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) + fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) num_segments = sorting_analyzer.get_num_segments() @@ -392,6 +408,7 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, + periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -417,6 +434,9 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -431,6 +451,8 @@ def compute_sliding_rp_violations( """ duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) + if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -486,7 +508,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -504,6 +526,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. References ---------- @@ -520,6 +545,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -556,7 +582,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -571,6 +597,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -584,6 +613,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) + if unit_ids is None: unit_ids = sorting.unit_ids @@ -635,6 +666,7 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", + periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -658,6 +690,8 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -683,7 +717,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(period=period) # precompute segment slice segment_slices = [] @@ -752,6 +786,7 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, + periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -770,6 +805,9 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -805,7 +843,7 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -837,7 +875,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): """ Compute median of the amplitude distributions (in absolute value). @@ -847,6 +885,9 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -865,7 +906,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, period=period) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -882,7 +923,9 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): +def compute_noise_cutoffs( + sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None +): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -906,6 +949,9 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -934,7 +980,7 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -972,6 +1018,7 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, + periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1006,6 +1053,9 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
@@ -1032,8 +1082,7 @@ def compute_drift_metrics( unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data() - # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") + spike_locations = spike_locations_ext.get_data(period=period) spikes = sorting.to_spike_vector() spike_locations_by_unit = {} for unit_id in unit_ids: @@ -1145,12 +1194,14 @@ class Drift(BaseMetric): depend_on = ["spike_locations"] +# TODO def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, + periods=None, **kwargs, ): """ @@ -1173,6 +1224,9 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. @@ -1189,6 +1243,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting + sorting = sorting.select_period(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1201,7 +1256,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(period=period) if not HAVE_NUMBA: warnings.warn( diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index ab7ae9e7b5..75e41620f4 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -130,7 +130,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan recording._properties["contact_vector"][idx][1] = x[idx] # generate random bad channel locations - bad_channel_indexes = rng.choice(num_channels, rng.randint(1, int(num_channels / 5)), replace=False) + bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) bad_channel_ids = recording.channel_ids[bad_channel_indexes] # Run SI and IBL interpolation and check against eachother From 1962f212f2dcd68a56275d123a554b7757558143 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 7 Jan 2026 17:22:37 +0100 Subject: [PATCH 07/70] Fix test for base sorting and propagate to basevector extension --- .../core/analyzer_extension_core.py | 18 ++++++------------ .../core/tests/test_basesorting.py | 19 ++++++++++++------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 5e46f20d22..9b93807b8c 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -17,9 +17,8 @@ from .waveform_tools import extract_waveforms_to_single_buffer, 
estimate_templates_with_accumulator from .recording_tools import get_noise_levels from .template import Templates -from .sorting_tools import random_spikes_selection +from .sorting_tools import random_spikes_selection, select_sorting_periods_mask from .job_tools import fix_job_kwargs, split_job_kwargs -from .node_pipeline import base_period_dtype class ComputeRandomSpikes(AnalyzerExtension): @@ -1423,16 +1422,11 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, all_data = self.data[return_data_name] if periods is not None: - # TODO: slice this properly with unit_indices - required = np.dtype(base_period_dtype).names - if not required.issubset(periods.dtype.names): - raise ValueError(f"Period must have the following fields: {required}") - # slice data according to period - segment_slices = self.segment_slices - all_data_segment = all_data[segment_slices[periods["segment_index"]]] - start = periods["start_sample_index"] - end = periods["end_sample_index"] - all_data = all_data_segment[start:end] + keep_mask = select_sorting_periods_mask( + self.sorting_analyzer.sorting, + periods, + ) + all_data = all_data[keep_mask] if outputs == "numpy": if copy: diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index ada35a57e9..18f632ed34 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -253,6 +253,7 @@ def test_select_periods(): periods["start_sample_index"] = start_samples periods["end_sample_index"] = end_samples periods["unit_index"] = unit_index + periods = np.sort(periods, order=["segment_index", "start_sample_index"]) t_start = time.perf_counter() sliced_sorting = sorting.select_periods(periods=periods) @@ -262,19 +263,23 @@ def test_select_periods(): # Check that all spikes in the sliced sorting are within the periods for segment_index in range(sorting.get_num_segments()): + periods_in_segment = periods[periods["segment_index"] == segment_index] for unit_index, unit_id in enumerate(sorting.unit_ids): spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) - spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) - spikes_in_periods = np.array([], dtype=spiketrain.dtype) - periods_in_segment = periods[periods["segment_index"] == segment_index] + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + spiketrain_in_periods = [] for period in periods_for_unit: start_sample = period["start_sample_index"] end_sample = period["end_sample_index"] - spikes_in_period = spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)] - spikes_in_periods = np.concatenate((spikes_in_periods, spikes_in_period)) - if not len(spikes_in_periods) == len(spiketrain_sliced): - print(f"Mismatch in number of spikes!: {len(spikes_in_periods)} vs {len(spiketrain_sliced)}") + spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)]) + if len(spiketrain_in_periods) == 0: + spiketrain_in_periods = np.array([], dtype=spiketrain.dtype) + else: + spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods)) + + spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + assert len(spiketrain_in_periods) == len(spiketrain_sliced) if __name__ == "__main__": From 7fbe1604d96072a49844173f5e7ca7f3d598a1a1 Mon Sep 17 00:00:00 2001 From: m-beau Date: Wed, 7 
Jan 2026 17:26:08 +0100 Subject: [PATCH 08/70] wip --- .../postprocessing/good_periods_per_unit.py | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index ddde15d010..47e09f2437 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -7,6 +7,8 @@ from typing import Optional, Literal from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.core.node_pipeline import unit_period_dtype +from spikeinterface.metrics.quality import compute_refrac_period_violations, compute_amplitude_cutoffs numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -164,7 +166,7 @@ def _run(self, verbose=False): self.data["good_periods_per_unit"] = self.params["user_defined_periods"] if self.params["method"] in ["false_positives_and_negatives", "combined"]: - # ndarray: (n_periods, 3) with columns: segment_id, start_sample, end_sample + # dict: unit_name -> (n_periods, 3) with columns: segment_id, start_sample, end_sample period_bounds = compute_period_bounds( self, self.params["subperiod_size_absolute"], @@ -210,47 +212,48 @@ def compute_period_bounds( sorting = self.sorting_analyzer.sorting fs = sorting.get_sampling_frequency() + units = sorting.unit_ids + n_units = len(units) if subperiod_size_mode == "absolute": period_size_samples = margin_size_samples = np.round(subperiod_size_absolute * fs).astype(int) else: # relative - period_size_samples = margin_size_samples = 0 # to be implemented based on firing rates + pass # to be implemented based on firing rates - all_period_bounds = np.empty((0, 3)) - for segment_i in range(sorting.get_num_segments()): - n_samples = sorting.get_num_samples(segment_i) # int: samples + all_period_bounds = np.array([], dtype=unit_period_dtype) + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting.get_num_samples(segment_index) # int: samples n_periods = n_samples // period_size_samples + 1 + intervals = [ + [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] for i in range(n_periods) + ] # list of sliding [start, end] in samples # for period size of 10s and margin size of 10s: [0, 30], [10, 40], [20, 50], ... 
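# --- Editor's illustrative sketch (not part of the committed patch) -----------------
# How the sliding subperiod windows described in the comment above can be laid out as
# a structured array. The field layout of `unit_period_dtype` is an assumption here,
# inferred from the fields the patch assigns; the real dtype is imported from
# spikeinterface.core.node_pipeline.
import numpy as np

unit_period_dtype = [  # stand-in layout (assumption, see note above)
    ("segment_index", "int64"),
    ("start_sample_index", "int64"),
    ("end_sample_index", "int64"),
    ("unit_index", "int64"),
]

fs = 30_000
n_samples = 90 * fs                    # one 90 s segment
period_size = margin_size = 10 * fs    # 10 s subperiod and 10 s margin, in samples
n_periods = n_samples // period_size + 1

windows = np.zeros(n_periods, dtype=unit_period_dtype)
windows["segment_index"] = 0
windows["start_sample_index"] = np.arange(n_periods) * period_size
windows["end_sample_index"] = windows["start_sample_index"] + 2 * margin_size
# consecutive windows overlap by one margin, sliding over the whole segment
# -------------------------------------------------------------------------------------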
- period_bounds = [ - ( - segment_i, - i * period_size_samples, - i * period_size_samples + 2 * margin_size_samples, - ) - for i in range(n_periods) - ] - all_period_bounds = ( - np.vstack(all_period_bounds, period_bounds) if len(all_period_bounds) > 0 else np.array(period_bounds) - ) + period_bounds = np.zeros((n_periods * n_units,), dtype=unit_period_dtype) + for int_index, int in enumerate(intervals): + periods_per_units = np.zeros((n_units,), dtype=unit_period_dtype) + periods_per_units["segment_index"] = segment_index + periods_per_units["start_sample_index"] = int[0] + periods_per_units["end_sample_index"] = int[1] + periods_per_units["unit_index"] = np.arange(n_units) + period_bounds[int_index * n_units : (int_index + 1) * n_units] = periods_per_units + all_period_bounds = np.concatenate(all_period_bounds, period_bounds) return all_period_bounds -def compute_fp_rates(self, period_bounds: list, violations_ms: float = 0.8) -> dict: - units = self.sorting_analyzer.sorting.unit_ids - n_periods = period_bounds.shape[0] +def compute_fp_rates(self, period_bounds: np.ndarray, violations_ms: float = 0.8) -> dict: - fp_violations = {} - for unit in units: - fp_violations[unit] = np.zeros((n_periods,), dtype=float) - for i, (segment_i, start, end) in enumerate(period_bounds): - fp_rate = 0 # refractory period violations for this period - fp_violations[unit][i] = fp_rate - pass + isi_violations = compute_refrac_period_violations( + self.sorting_analyzer.sorting, + refractory_period_ms=violations_ms, + periods=period_bounds, + ) + + fp_rates = isi_violations["rp_contamination"] # dict: unit_id -> array of shape (n_subperiods) - return fp_violations + return fp_rates def compute_fn_rates(self, period_bounds: list) -> dict: From 528c82b7951db9d509030b7fc10e3796fb69347b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 08:33:58 +0100 Subject: [PATCH 09/70] Fix tests in quailty metrics --- .../metrics/quality/misc_metrics.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 028b2eeca5..4a7ef04554 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -69,7 +69,7 @@ def compute_presence_ratios( To do so, spike trains across segments are concatenated to mimic a continuous segment. 
""" sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -245,7 +245,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(sorting, periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -346,7 +346,7 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) @@ -451,7 +451,7 @@ def compute_sliding_rp_violations( """ duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids @@ -545,7 +545,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -613,7 +613,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -717,7 +717,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - amps = sorting_analyzer.get_extension(amplitude_extension).get_data(period=period) + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice segment_slices = [] @@ -843,7 +843,7 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -906,7 +906,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -980,7 +980,7 @@ def compute_noise_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, period=period) + amplitudes_by_units = 
extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -1082,7 +1082,7 @@ def compute_drift_metrics( unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data(period=period) + spike_locations = spike_locations_ext.get_data(periods=periods) spikes = sorting.to_spike_vector() spike_locations_by_unit = {} for unit_id in unit_ids: @@ -1243,7 +1243,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting - sorting = sorting.select_period(periods=periods) + sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1256,7 +1256,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(period=period) + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) if not HAVE_NUMBA: warnings.warn( From fccdbe349979e7b0a0f9154879d41c67ad88f68c Mon Sep 17 00:00:00 2001 From: m-beau Date: Thu, 8 Jan 2026 12:05:09 +0100 Subject: [PATCH 10/70] finished implementing good periods --- .../postprocessing/good_periods_per_unit.py | 243 ++++++++++++++---- 1 file changed, 186 insertions(+), 57 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 47e09f2437..f75af10ddd 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -9,6 +9,7 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import unit_period_dtype from spikeinterface.metrics.quality import compute_refrac_period_violations, compute_amplitude_cutoffs +from spikeinterface.metrics.spiketrain import compute_firing_rates numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -46,6 +47,8 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): Maximum false negative rate to mark period as good. minimum_n_spikes : int, default=100 Minimum spikes required in period for analysis. + minimum_valid_period_duration : float, default=180 + Minimum duration that detected good periods must have to be kept, in seconds. user_defined_periods : array-like or None, default=None In SAMPLES, user-specified (unit, good_period_start, good_period_end) or (unit, segment_index, good_period_start, good_period_end) time pairs. Required if method="user_defined" or "combined". @@ -74,6 +77,7 @@ def _set_params( fp_threshold: float = 0.05, fn_threshold: float = 0.05, minimum_n_spikes: int = 100, + minimum_valid_period_duration: float = 180, user_defined_periods: Optional[object] = None, ): @@ -98,7 +102,7 @@ def _set_params( ), "Either subperiod_size_absolute or subperiod_size_relative must be positive." assert isinstance(subperiod_size_relative, (int)), "subperiod_size_relative must be an integer." 
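# --- Editor's illustrative usage sketch (not part of the committed patch) -----------
# How the extension documented above could be invoked once registered. Assumes
# `sorting_analyzer` is a SortingAnalyzer whose "amplitude_scalings" extension has
# already been computed, as _set_params requires for the fp/fn-based methods.
ext = sorting_analyzer.compute(
    "good_periods_per_unit",
    method="false_positives_and_negatives",
    subperiod_size_absolute=10,          # seconds per sliding subperiod
    violations_ms=0.8,                   # refractory period used for the fp estimate
    fp_threshold=0.05,                   # max tolerated refractory-period contamination
    fn_threshold=0.05,                   # max tolerated fraction of missed spikes
    minimum_n_spikes=100,
    minimum_valid_period_duration=180,   # seconds; shorter good periods are dropped
)
good_periods = ext.get_data()
# -------------------------------------------------------------------------------------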
- # user_defined_periods format + # user_defined_periods formatting if user_defined_periods is not None: try: user_defined_periods = np.asarray(user_defined_periods) @@ -130,6 +134,22 @@ def _set_params( user_defined_periods[:, 1:3], ) ) + # Cast user defined periods to unit_period_dtype + user_defined_periods_typed = np.zeros(user_defined_periods.shape[0], dtype=unit_period_dtype) + user_defined_periods_typed["unit_index"] = user_defined_periods[:, 0] + user_defined_periods_typed["segment_index"] = user_defined_periods[:, 1] + user_defined_periods_typed["start_sample_index"] = user_defined_periods[:, 2] + user_defined_periods_typed["end_sample_index"] = user_defined_periods[:, 3] + user_defined_periods = user_defined_periods_typed + + # assert that user-defined periods are not too short + fs = self.sorting_analyzer.sorting.get_sampling_frequency() + durations = user_defined_periods["end_sample_index"] - user_defined_periods["start_sample_index"] + min_duration_samples = int(minimum_valid_period_duration * fs) + if np.any(durations < min_duration_samples): + raise ValueError( + f"All user-defined periods must be at least {minimum_valid_period_duration} seconds long." + ) params = dict( method=method, @@ -140,6 +160,7 @@ def _set_params( fp_threshold=fp_threshold, fn_threshold=fn_threshold, minimum_n_spikes=minimum_n_spikes, + minimum_valid_period_duration=minimum_valid_period_duration, user_defined_periods=user_defined_periods, ) @@ -165,9 +186,9 @@ def _run(self, verbose=False): # directly use user defined periods self.data["good_periods_per_unit"] = self.params["user_defined_periods"] - if self.params["method"] in ["false_positives_and_negatives", "combined"]: - # dict: unit_name -> (n_periods, 3) with columns: segment_id, start_sample, end_sample - period_bounds = compute_period_bounds( + elif self.params["method"] in ["false_positives_and_negatives", "combined"]: + # dict: unit_name -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields + subperiods_per_unit = compute_subperiods( self, self.params["subperiod_size_absolute"], self.params["subperiod_size_relative"], @@ -177,22 +198,56 @@ def _run(self, verbose=False): ## Compute fp and fn for all periods # fp computed from refractory period violations - # dict: unit_id -> array of shape (n_periods) - periods_fp_per_unit = compute_fp_rates(self, period_bounds, self.params["violations_ms"]) + # dict: unit_id -> array of shape (n_subperiods) + periods_fp_per_unit = compute_fp_rates(self, subperiods_per_unit, self.params["violations_ms"]) # fn computed from amplitude clippings - # dict: unit_id -> array of shape (n_periods) - periods_fn_per_unit = compute_fn_rates(self, period_bounds) + # dict: unit_id -> array of shape (n_subperiods) + periods_fn_per_unit = compute_fn_rates(self, subperiods_per_unit) ## Combine fp and fn results with thresholds to define good periods - - ## Eventually combine with user defined periods if provided - - self.data["period_bounds"] = period_bounds + # get n spikes per unit to set the fp or fn rates to 1 if not enough spikes + minimum_valid_period_duration = self.params["minimum_valid_period_duration"] + fs = self.sorting_analyzer.sorting.get_sampling_frequency() + min_valid_period_samples = int(minimum_valid_period_duration * fs) + + n_spikes_per_unit = self.sorting_analyzer.count_num_spikes_per_unit() + good_periods_per_unit = np.array([], dtype=unit_period_dtype) + for unit_name, subperiods in subperiods_per_unit.items(): + n_spikes = n_spikes_per_unit[unit_name] + if 
n_spikes < self.params["minimum_n_spikes"]: + periods_fp_per_unit[unit_name] = np.ones_like(periods_fp_per_unit[unit_name]) + periods_fn_per_unit[unit_name] = np.ones_like(periods_fn_per_unit[unit_name]) + + fp_rates = periods_fp_per_unit[unit_name] + fn_rates = periods_fn_per_unit[unit_name] + + good_periods_mask = (fp_rates < self.params["fp_threshold"]) & (fn_rates < self.params["fn_threshold"]) + good_subperiods = subperiods[good_periods_mask] + good_segments = np.unique(good_subperiods["segment_index"]) + for segment_index in good_segments: + segment_mask = good_subperiods["segment_index"] == segment_index + good_segment_subperiods = good_subperiods[segment_mask] + good_segment_periods = merge_overlapping_periods(good_segment_subperiods) + good_periods_per_unit = np.concatenate((good_periods_per_unit, good_segment_periods), axis=0) + + ## Remove good periods that are too short + durations = good_periods_per_unit[:, 1] - good_periods_per_unit[:, 0] + valid_mask = durations >= min_valid_period_samples + good_periods_per_unit = good_periods_per_unit[valid_mask] + + ## Eventually combine with user-defined periods if provided + if self.params["method"] == "combined": + user_defined_periods = self.params["user_defined_periods"] + all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) + good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(all_periods) + + ## Store data + self.data["subperiods_per_unit"] = subperiods_per_unit self.data["periods_fp_per_unit"] = periods_fp_per_unit self.data["periods_fn_per_unit"] = periods_fn_per_unit self.data["good_periods_per_unit"] = ( - None # (n_good_periods, 4) with (unit, segment, start, end) to be implemented + good_periods_per_unit # (n_good_periods, 4) with (unit, segment, start, end) to be implemented ) def _get_data(self): @@ -203,69 +258,143 @@ def _get_data(self): # compute_isi_histograms = ComputeISIHistograms.function_factory() -def compute_period_bounds( +def compute_subperiods( self, subperiod_size_absolute: float = 10, subperiod_size_relative: int = 1000, subperiod_size_mode: str = "absolute", -) -> np.ndarray: +) -> dict: sorting = self.sorting_analyzer.sorting fs = sorting.get_sampling_frequency() - units = sorting.unit_ids - n_units = len(units) + unit_names = sorting.unit_ids if subperiod_size_mode == "absolute": - period_size_samples = margin_size_samples = np.round(subperiod_size_absolute * fs).astype(int) + period_sizes_samples = {u: np.round(subperiod_size_absolute * fs).astype(int) for u in unit_names} else: # relative - pass # to be implemented based on firing rates + mean_firing_rates = compute_firing_rates(self.sorting_analyzer, unit_names) + period_sizes_samples = { + u: np.round((subperiod_size_relative / mean_firing_rates[u]) * fs).astype(int) for u in unit_names + } + margin_sizes_samples = period_sizes_samples + + all_subperiods = {} + for unit_name in unit_names: + period_size_samples = period_sizes_samples[unit_name] + margin_size_samples = margin_sizes_samples[unit_name] + + all_subperiods[unit_name] = [] + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting.get_num_samples(segment_index) # int: samples + n_subperiods = n_samples // period_size_samples + 1 + starts_ends = np.array( + [ + [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] + for i in range(n_subperiods) + ] + ) + for start, end in starts_ends: + subperiod = np.zeros((1,), dtype=unit_period_dtype) + subperiod["segment_index"] = segment_index + 
subperiod["start_sample_index"] = start + subperiod["end_sample_index"] = end + subperiod["unit_index"] = unit_name + all_subperiods[unit_name].append(subperiod) + + return all_subperiods + + +def compute_fp_rates(self, subperiods_per_unit: dict, violations_ms: float = 0.8) -> dict: + + fp_rates = {} + for unit_name, subperiods in subperiods_per_unit.items(): + fp_rates[unit_name] = [] + for subperiod in subperiods: + isi_violations = compute_refrac_period_violations( + self.sorting_analyzer.sorting, + unit_ids=[unit_name], + refractory_period_ms=violations_ms, + periods=subperiod, + ) + fp_rates[unit_name].append( + isi_violations["rp_contamination"][unit_name] + ) # contamination for this subperiod + + return fp_rates - all_period_bounds = np.array([], dtype=unit_period_dtype) - for segment_index in range(sorting.get_num_segments()): - n_samples = sorting.get_num_samples(segment_index) # int: samples - n_periods = n_samples // period_size_samples + 1 - intervals = [ - [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] for i in range(n_periods) - ] - # list of sliding [start, end] in samples - # for period size of 10s and margin size of 10s: [0, 30], [10, 40], [20, 50], ... - period_bounds = np.zeros((n_periods * n_units,), dtype=unit_period_dtype) - for int_index, int in enumerate(intervals): - periods_per_units = np.zeros((n_units,), dtype=unit_period_dtype) - periods_per_units["segment_index"] = segment_index - periods_per_units["start_sample_index"] = int[0] - periods_per_units["end_sample_index"] = int[1] - periods_per_units["unit_index"] = np.arange(n_units) - period_bounds[int_index * n_units : (int_index + 1) * n_units] = periods_per_units - all_period_bounds = np.concatenate(all_period_bounds, period_bounds) +def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: - return all_period_bounds + fn_rates = {} + for unit_name, subperiods in subperiods_per_unit.items(): + fn_rates[unit_name] = [] + for subperiod in subperiods: + all_fraction_missing = compute_amplitude_cutoffs( + self.sorting_analyzer.sorting, + unit_ids=[unit_name], + num_histogram_bins=500, + histogram_smoothing_value=3, + amplitudes_bins_min_ratio=5, + periods=subperiod, + ) + fn_rates[unit_name].append(all_fraction_missing[unit_name]) # missed spikes for this subperiod + return fn_rates -def compute_fp_rates(self, period_bounds: np.ndarray, violations_ms: float = 0.8) -> dict: - isi_violations = compute_refrac_period_violations( - self.sorting_analyzer.sorting, - refractory_period_ms=violations_ms, - periods=period_bounds, - ) +def merge_overlapping_periods(subperiods): - fp_rates = isi_violations["rp_contamination"] # dict: unit_id -> array of shape (n_subperiods) + segment_indices = np.unique(subperiods["segment_index"]) + assert len(segment_indices) == 1, "Subperiods must belong to the same segment to be merged." + segment_index = segment_indices[0] + unit_indices = np.unique(subperiods["unit_index"]) + assert len(unit_indices) == 1, "Subperiods must belong to the same unit to be merged." 
+ unit_index = unit_indices[0] + + # Sort subperiods by start time for interval merging + sort_idx = np.argsort(subperiods["start_sample_index"]) + sorted_subperiods = subperiods[sort_idx] + + # Merge overlapping/adjacent intervals + merged_starts = [sorted_subperiods[0]["start_sample_index"]] + merged_ends = [sorted_subperiods[0]["end_sample_index"]] + + for i in range(1, len(sorted_subperiods)): + current_start = sorted_subperiods[i]["start_sample_index"] + current_end = sorted_subperiods[i]["end_sample_index"] + + # Merge if overlapping or contiguous (end >= start) + if current_start <= merged_ends[-1]: + merged_ends[-1] = max(merged_ends[-1], current_end) + else: + merged_starts.append(current_start) + merged_ends.append(current_end) + + # Construct output array + n_periods = len(merged_starts) + merged_periods = np.zeros(n_periods, dtype=unit_period_dtype) + merged_periods["segment_index"] = segment_index + merged_periods["start_sample_index"] = merged_starts + merged_periods["end_sample_index"] = merged_ends + merged_periods["unit_index"] = unit_index + + return merged_periods - return fp_rates +def merge_overlapping_periods_across_units_and_segments(periods): -def compute_fn_rates(self, period_bounds: list) -> dict: - units = self.sorting_analyzer.sorting.unit_ids - n_periods = period_bounds.shape[0] + units = np.unique(periods["unit_index"]) + segments = np.unique(periods["segment_index"]) - fn_violations = {} - for unit in units: - fn_violations[unit] = np.zeros((n_periods,), dtype=float) - for i, (segment_i, start, end) in enumerate(period_bounds): - fn_rate = 0 # clipped amplitude AUC ratio for this period - fn_violations[unit][i] = fn_rate - pass + merged_periods = np.array([], dtype=unit_period_dtype) + for unit_index in units: + for segment_index in segments: + masked_periods = periods[ + (periods["unit_index"] == unit_index) & (periods["segment_index"] == segment_index) + ] + if len(masked_periods) == 0: + continue + _merged_periods = merge_overlapping_periods(masked_periods) + merged_periods = np.concatenate((merged_periods, _merged_periods), axis=0) - return fn_violations + return merged_periods From f36c7fc9d30308f0ab254c2e7028c75d52eebb53 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 12:44:55 +0100 Subject: [PATCH 11/70] Some fixes --- src/spikeinterface/metrics/quality/__init__.py | 1 + src/spikeinterface/postprocessing/__init__.py | 5 +++++ .../postprocessing/good_periods_per_unit.py | 16 +++++++--------- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/metrics/quality/__init__.py b/src/spikeinterface/metrics/quality/__init__.py index 1edcd9221f..f91ed6eefc 100644 --- a/src/spikeinterface/metrics/quality/__init__.py +++ b/src/spikeinterface/metrics/quality/__init__.py @@ -20,4 +20,5 @@ compute_sliding_rp_violations, compute_sd_ratio, compute_synchrony_metrics, + compute_refrac_period_violations, ) diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index dca9711ccd..078157188e 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -44,3 +44,8 @@ ComputeTemplateMetrics, compute_template_metrics, ) + +from .good_periods_per_unit import ( + ComputeGoodPeriodsPerUnit, + compute_good_periods_per_unit, +) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index f75af10ddd..785ffa87ca 100644 --- 
a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -91,7 +91,7 @@ def _set_params( warnings.warn("Combined method without user_defined_periods, falling back") method = "false_positives_and_negatives" - if params.method in ["false_positives_and_negatives", "combined"]: + if method in ["false_positives_and_negatives", "combined"]: if not self.sorting_analyzer.has_extension("amplitude_scalings"): raise ValueError("Requires 'amplitude_scalings' extension; please compute it first.") @@ -254,8 +254,8 @@ def _get_data(self): return self.data["isi_histograms"], self.data["bins"] -# register_result_extension(ComputeISIHistograms) -# compute_isi_histograms = ComputeISIHistograms.function_factory() +register_result_extension(ComputeGoodPeriodsPerUnit) +compute_good_periods_per_unit = ComputeGoodPeriodsPerUnit.function_factory() def compute_subperiods( @@ -285,7 +285,7 @@ def compute_subperiods( all_subperiods[unit_name] = [] for segment_index in range(sorting.get_num_segments()): - n_samples = sorting.get_num_samples(segment_index) # int: samples + n_samples = self.sorting_analyzer.get_num_samples(segment_index) # int: samples n_subperiods = n_samples // period_size_samples + 1 starts_ends = np.array( [ @@ -311,14 +311,12 @@ def compute_fp_rates(self, subperiods_per_unit: dict, violations_ms: float = 0.8 fp_rates[unit_name] = [] for subperiod in subperiods: isi_violations = compute_refrac_period_violations( - self.sorting_analyzer.sorting, + self.sorting_analyzer, unit_ids=[unit_name], refractory_period_ms=violations_ms, periods=subperiod, ) - fp_rates[unit_name].append( - isi_violations["rp_contamination"][unit_name] - ) # contamination for this subperiod + fp_rates[unit_name].append(isi_violations.rp_contamination[unit_name]) # contamination for this subperiod return fp_rates @@ -330,7 +328,7 @@ def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: fn_rates[unit_name] = [] for subperiod in subperiods: all_fraction_missing = compute_amplitude_cutoffs( - self.sorting_analyzer.sorting, + self.sorting_analyzer, unit_ids=[unit_name], num_histogram_bins=500, histogram_smoothing_value=3, From 775dda710adc4c9b4a7eddb7e5dea99d8d9df884 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 12:46:23 +0100 Subject: [PATCH 12/70] Fix retrieval of spikevector features --- src/spikeinterface/core/analyzer_extension_core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 9b93807b8c..804418a2ff 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1421,6 +1421,7 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + keep_mask = None if periods is not None: keep_mask = select_sorting_periods_mask( self.sorting_analyzer.sorting, @@ -1436,6 +1437,8 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + if keep_mask is not None: + spike_vector = spike_vector[keep_mask] spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) data_by_units = {} for segment_index in 
range(self.sorting_analyzer.sorting.get_num_segments()): From 15df754997c6df5bf4ed3ca15a4f0339edacedbb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 8 Jan 2026 16:37:16 +0100 Subject: [PATCH 13/70] Fix tests, saving and loading --- .../core/analyzer_extension_core.py | 11 +- src/spikeinterface/core/sortinganalyzer.py | 1 + .../metrics/quality/misc_metrics.py | 3 +- .../postprocessing/good_periods_per_unit.py | 243 ++++++++++-------- .../tests/common_extension_tests.py | 14 +- 5 files changed, 147 insertions(+), 125 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 804418a2ff..71d0e1b5f0 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1421,13 +1421,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] - keep_mask = None if periods is not None: keep_mask = select_sorting_periods_mask( self.sorting_analyzer.sorting, periods, ) all_data = all_data[keep_mask] + sorting = self.sorting_analyzer.sorting.select_periods(periods) + else: + keep_mask = None + sorting = self.sorting_analyzer.sorting if outputs == "numpy": if copy: @@ -1436,12 +1439,10 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, return all_data elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - if keep_mask is not None: - spike_vector = spike_vector[keep_mask] + spike_vector = sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) data_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + for segment_index in range(sorting.get_num_segments()): data_by_units[segment_index] = {} for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1870c24e7a..0c3b6b9615 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2633,6 +2633,7 @@ def _save_data(self): extension_group.create_dataset( name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) + extension_group[ext_data_name].attrs["dict"] = True elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4a7ef04554..4d02274e80 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -350,7 +350,7 @@ def compute_refrac_period_violations( fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) - num_segments = sorting_analyzer.get_num_segments() + num_segments = sorting.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) @@ -849,7 +849,6 @@ def compute_amplitude_cutoffs( amplitudes = amplitudes_by_units[unit_id] if invert_amplitudes: amplitudes = -amplitudes - all_fraction_missing[unit_id] = amplitude_cutoff( amplitudes, num_histogram_bins, 
histogram_smoothing_value, amplitudes_bins_min_ratio ) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 785ffa87ca..b9246f510d 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -143,7 +143,7 @@ def _set_params( user_defined_periods = user_defined_periods_typed # assert that user-defined periods are not too short - fs = self.sorting_analyzer.sorting.get_sampling_frequency() + fs = self.sorting_analyzer.sampling_frequency durations = user_defined_periods["end_sample_index"] - user_defined_periods["start_sample_index"] min_duration_samples = int(minimum_valid_period_duration * fs) if np.any(durations < min_duration_samples): @@ -187,43 +187,43 @@ def _run(self, verbose=False): self.data["good_periods_per_unit"] = self.params["user_defined_periods"] elif self.params["method"] in ["false_positives_and_negatives", "combined"]: - # dict: unit_name -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields - subperiods_per_unit = compute_subperiods( - self, + # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields + subperiods_per_unit = self.compute_subperiods( self.params["subperiod_size_absolute"], self.params["subperiod_size_relative"], self.params["subperiod_size_mode"], ) - ## Compute fp and fn for all periods + # Compute fp and fn for all periods # fp computed from refractory period violations # dict: unit_id -> array of shape (n_subperiods) - periods_fp_per_unit = compute_fp_rates(self, subperiods_per_unit, self.params["violations_ms"]) + periods_fp_per_unit = self.compute_fp_rates(subperiods_per_unit, self.params["violations_ms"]) # fn computed from amplitude clippings # dict: unit_id -> array of shape (n_subperiods) - periods_fn_per_unit = compute_fn_rates(self, subperiods_per_unit) - - ## Combine fp and fn results with thresholds to define good periods + periods_fn_per_unit = self.compute_fn_rates(subperiods_per_unit) + # Combine fp and fn results with thresholds to define good periods # get n spikes per unit to set the fp or fn rates to 1 if not enough spikes minimum_valid_period_duration = self.params["minimum_valid_period_duration"] - fs = self.sorting_analyzer.sorting.get_sampling_frequency() + fs = self.sorting_analyzer.sampling_frequency min_valid_period_samples = int(minimum_valid_period_duration * fs) - n_spikes_per_unit = self.sorting_analyzer.count_num_spikes_per_unit() + n_spikes_per_unit = self.sorting_analyzer.sorting.count_num_spikes_per_unit() good_periods_per_unit = np.array([], dtype=unit_period_dtype) - for unit_name, subperiods in subperiods_per_unit.items(): - n_spikes = n_spikes_per_unit[unit_name] + for unit_id, subperiods in subperiods_per_unit.items(): + n_spikes = n_spikes_per_unit[unit_id] if n_spikes < self.params["minimum_n_spikes"]: - periods_fp_per_unit[unit_name] = np.ones_like(periods_fp_per_unit[unit_name]) - periods_fn_per_unit[unit_name] = np.ones_like(periods_fn_per_unit[unit_name]) + periods_fp_per_unit[unit_id] = [1] * len(periods_fp_per_unit[unit_id]) + periods_fn_per_unit[unit_id] = [1] * len(periods_fn_per_unit[unit_id]) - fp_rates = periods_fp_per_unit[unit_name] - fn_rates = periods_fn_per_unit[unit_name] + fp_rates = periods_fp_per_unit[unit_id] + fn_rates = periods_fn_per_unit[unit_id] - good_periods_mask = (fp_rates < self.params["fp_threshold"]) & (fn_rates < 
self.params["fn_threshold"]) - good_subperiods = subperiods[good_periods_mask] + good_periods_mask = (np.array(fp_rates) < self.params["fp_threshold"]) & ( + np.array(fn_rates) < self.params["fn_threshold"] + ) + good_subperiods = np.array(subperiods)[good_periods_mask] good_segments = np.unique(good_subperiods["segment_index"]) for segment_index in good_segments: segment_mask = good_subperiods["segment_index"] == segment_index @@ -231,113 +231,124 @@ def _run(self, verbose=False): good_segment_periods = merge_overlapping_periods(good_segment_subperiods) good_periods_per_unit = np.concatenate((good_periods_per_unit, good_segment_periods), axis=0) - ## Remove good periods that are too short - durations = good_periods_per_unit[:, 1] - good_periods_per_unit[:, 0] - valid_mask = durations >= min_valid_period_samples + # Remove good periods that are too short + duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] + valid_mask = duration_samples >= min_valid_period_samples good_periods_per_unit = good_periods_per_unit[valid_mask] - ## Eventually combine with user-defined periods if provided + # Eventually combine with user-defined periods if provided if self.params["method"] == "combined": user_defined_periods = self.params["user_defined_periods"] all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(all_periods) - ## Store data - self.data["subperiods_per_unit"] = subperiods_per_unit + # Convert subperiods per unit in period_centers_s + period_centers_s = [] + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + period_centers_dict = {} + for unit_id in self.sorting_analyzer.unit_ids: + periods_unit = subperiods_per_unit[unit_id] + periods_segment = periods_unit[periods_unit["segment_index"] == segment_index] + centers = list(0.5 * (periods_segment["start_sample_index"] + periods_segment["end_sample_index"])) + period_centers_dict[unit_id] = centers + period_centers_s.append(period_centers_dict) + + # Store data: here we have to make sure every dict is JSON serializable, so everything is lists + self.data["period_centers_s"] = period_centers_s self.data["periods_fp_per_unit"] = periods_fp_per_unit self.data["periods_fn_per_unit"] = periods_fn_per_unit - self.data["good_periods_per_unit"] = ( - good_periods_per_unit # (n_good_periods, 4) with (unit, segment, start, end) to be implemented - ) + self.data["good_periods_per_unit"] = good_periods_per_unit def _get_data(self): - return self.data["isi_histograms"], self.data["bins"] - - -register_result_extension(ComputeGoodPeriodsPerUnit) -compute_good_periods_per_unit = ComputeGoodPeriodsPerUnit.function_factory() - - -def compute_subperiods( - self, - subperiod_size_absolute: float = 10, - subperiod_size_relative: int = 1000, - subperiod_size_mode: str = "absolute", -) -> dict: - - sorting = self.sorting_analyzer.sorting - fs = sorting.get_sampling_frequency() - unit_names = sorting.unit_ids - - if subperiod_size_mode == "absolute": - period_sizes_samples = {u: np.round(subperiod_size_absolute * fs).astype(int) for u in unit_names} - else: # relative - mean_firing_rates = compute_firing_rates(self.sorting_analyzer, unit_names) - period_sizes_samples = { - u: np.round((subperiod_size_relative / mean_firing_rates[u]) * fs).astype(int) for u in unit_names - } - margin_sizes_samples = period_sizes_samples - - all_subperiods = {} - for unit_name in unit_names: - 
period_size_samples = period_sizes_samples[unit_name] - margin_size_samples = margin_sizes_samples[unit_name] - - all_subperiods[unit_name] = [] - for segment_index in range(sorting.get_num_segments()): - n_samples = self.sorting_analyzer.get_num_samples(segment_index) # int: samples - n_subperiods = n_samples // period_size_samples + 1 - starts_ends = np.array( - [ - [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] - for i in range(n_subperiods) - ] - ) - for start, end in starts_ends: - subperiod = np.zeros((1,), dtype=unit_period_dtype) - subperiod["segment_index"] = segment_index - subperiod["start_sample_index"] = start - subperiod["end_sample_index"] = end - subperiod["unit_index"] = unit_name - all_subperiods[unit_name].append(subperiod) - - return all_subperiods - - -def compute_fp_rates(self, subperiods_per_unit: dict, violations_ms: float = 0.8) -> dict: - - fp_rates = {} - for unit_name, subperiods in subperiods_per_unit.items(): - fp_rates[unit_name] = [] - for subperiod in subperiods: - isi_violations = compute_refrac_period_violations( - self.sorting_analyzer, - unit_ids=[unit_name], - refractory_period_ms=violations_ms, - periods=subperiod, - ) - fp_rates[unit_name].append(isi_violations.rp_contamination[unit_name]) # contamination for this subperiod - - return fp_rates - + return self.data["good_periods_per_unit"] -def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: - - fn_rates = {} - for unit_name, subperiods in subperiods_per_unit.items(): - fn_rates[unit_name] = [] - for subperiod in subperiods: - all_fraction_missing = compute_amplitude_cutoffs( - self.sorting_analyzer, - unit_ids=[unit_name], - num_histogram_bins=500, - histogram_smoothing_value=3, - amplitudes_bins_min_ratio=5, - periods=subperiod, - ) - fn_rates[unit_name].append(all_fraction_missing[unit_name]) # missed spikes for this subperiod - - return fn_rates + def compute_subperiods( + self, + subperiod_size_absolute: float = 10, + subperiod_size_relative: int = 1000, + subperiod_size_mode: str = "absolute", + ) -> dict: + """ + Computes subperiods per unit based on specified size mode. + + Returns + ------- + all_subperiods : dict + Dictionary mapping unit IDs to lists of subperiods (arrays of dtype unit_period_dtype). 
+ """ + sorting = self.sorting_analyzer.sorting + fs = sorting.sampling_frequency + unit_ids = sorting.unit_ids + + if subperiod_size_mode == "absolute": + period_sizes_samples = {u: np.round(subperiod_size_absolute * fs).astype(int) for u in unit_ids} + else: # relative + mean_firing_rates = compute_firing_rates(self.sorting_analyzer, unit_ids) + period_sizes_samples = { + u: np.round((subperiod_size_relative / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids + } + margin_sizes_samples = period_sizes_samples + + all_subperiods = {} + for unit_index, unit_id in enumerate(unit_ids): + period_size_samples = period_sizes_samples[unit_id] + margin_size_samples = margin_sizes_samples[unit_id] + + all_subperiods[unit_id] = [] + for segment_index in range(sorting.get_num_segments()): + n_samples = self.sorting_analyzer.get_num_samples(segment_index) # int: samples + n_subperiods = n_samples // period_size_samples + 1 + starts_ends = np.array( + [ + [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] + for i in range(n_subperiods) + ] + ) + for start, end in starts_ends: + subperiod = np.zeros((1,), dtype=unit_period_dtype) + subperiod["segment_index"] = segment_index + subperiod["start_sample_index"] = start + subperiod["end_sample_index"] = end + subperiod["unit_index"] = unit_index + all_subperiods[unit_id].append(subperiod) + all_subperiods[unit_id] = np.array(all_subperiods[unit_id]) + return all_subperiods + + def compute_fp_rates(self, subperiods_per_unit: dict, violations_ms: float = 0.8) -> dict: + """ + Computes false positive rates (RP violations) for each subperiod per unit. + """ + fp_rates = {} + for unit_id, subperiods in subperiods_per_unit.items(): + fp_rates[unit_id] = [] + for subperiod in subperiods: + isi_violations = compute_refrac_period_violations( + self.sorting_analyzer, + unit_ids=[unit_id], + refractory_period_ms=violations_ms, + periods=subperiod, + ) + fp_rates[unit_id].append(isi_violations.rp_contamination[unit_id]) # contamination for this subperiod + return fp_rates + + def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: + """ + Computes false negative rates (amplitude cutoffs) for each subperiod per unit. 
+ """ + fn_rates = {} + for unit_id, subperiods in subperiods_per_unit.items(): + fn_rates[unit_id] = [] + for subperiod in subperiods: + all_fraction_missing = compute_amplitude_cutoffs( + self.sorting_analyzer, + unit_ids=[unit_id], + num_histogram_bins=50, + histogram_smoothing_value=3, + amplitudes_bins_min_ratio=3, + periods=subperiod, + ) + fn_rates[unit_id].append(all_fraction_missing[unit_id]) # missed spikes for this subperiod + return fn_rates def merge_overlapping_periods(subperiods): @@ -396,3 +407,7 @@ def merge_overlapping_periods_across_units_and_segments(periods): merged_periods = np.concatenate((merged_periods, _merged_periods), axis=0) return merged_periods + + +register_result_extension(ComputeGoodPeriodsPerUnit) +compute_good_periods_per_unit = ComputeGoodPeriodsPerUnit.function_factory() diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 2207b98da6..46bf8e2235 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -95,7 +95,7 @@ def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=Non return sorting_analyzer - def _prepare_sorting_analyzer(self, format, sparse, extension_class): + def _prepare_sorting_analyzer(self, format, sparse, extension_class, extra_dependencies=None): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None sorting_analyzer = self.get_sorting_analyzer( @@ -107,6 +107,10 @@ def _prepare_sorting_analyzer(self, format, sparse, extension_class): if "|" in dependency_name: dependency_name = dependency_name.split("|")[0] sorting_analyzer.compute(dependency_name) + if extra_dependencies is not None: + for dependency_name in extra_dependencies: + print("Computing extra dependency:", dependency_name) + sorting_analyzer.compute(dependency_name) return sorting_analyzer @@ -126,7 +130,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 main_data = ext.get_data() - assert len(main_data) > 0 + assert main_data is not None ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None @@ -160,7 +164,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): else: continue - def run_extension_tests(self, extension_class, params): + def run_extension_tests(self, extension_class, params, extra_dependencies=None): """ Convenience function to perform all checks on the extension of interest with the passed parameters. 
Will perform tests @@ -169,5 +173,7 @@ def run_extension_tests(self, extension_class, params): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) - sorting_analyzer = self._prepare_sorting_analyzer(format, sparse, extension_class) + sorting_analyzer = self._prepare_sorting_analyzer( + format, sparse, extension_class, extra_dependencies=extra_dependencies + ) self._check_one(sorting_analyzer, extension_class, params) From 40e34176a2fe54d63b93b5fda16883040a81fd3f Mon Sep 17 00:00:00 2001 From: m-beau Date: Thu, 8 Jan 2026 16:37:20 +0100 Subject: [PATCH 14/70] started working on get_data method for good periods --- .../postprocessing/good_periods_per_unit.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 785ffa87ca..d652fd90e3 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -250,8 +250,51 @@ def _run(self, verbose=False): good_periods_per_unit # (n_good_periods, 4) with (unit, segment, start, end) to be implemented ) - def _get_data(self): - return self.data["isi_histograms"], self.data["bins"] + def _get_data(self, outputs: str = "by_unit", return_subperiods_metadata: bool = False): + """ + Return extension data. If the extension computes more than one `nodepipeline_variables`, + the `return_data_name` is used to specify which one to return. + + Parameters + ---------- + outputs : "numpy" | "by_unit", default: "numpy" + How to return the data, by default "numpy" + return_subperiods_metadata: bool, default: False + Whether to also return metadata of subperiods used to compute the good periods + as dictionnaries per unit: + - subperiods_per_unit: unit_name -> list of n_subperiods subperiods (each subperiod is an array of dtype unit_period_dtype with 4 fields) + - periods_fp_per_unit: unit_name -> array of n_subperiods, false positive rates (refractory period violations) per subperiod + - periods_fn_per_unit: unit_name -> array of n_subperiods, false negative rates (amplitude cutoffs) per subperiod + + Returns + ------- + numpy.ndarray | dict | tuple + The periods in numpy or dictionnary by unit format, + or a tuple that contains the former as well as metadata of subperiods if return_subperiods_metadata is True. 
+ """ + + good_periods = self.data["good_periods_per_unit"] + + # list of dictionnaries; one dictionnary per segment + if outputs == "by_unit": + unit_ids = self.sorting_analyzer.unit_ids + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) + data_by_units = {} + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + data_by_units[segment_index] = {} + for unit_id in unit_ids: + inds = spike_indices[segment_index][unit_id] + data_by_units[segment_index][unit_id] = all_data[inds] + + return data_by_units + + return ( + self.data["subperiods_per_unit"], + self.data["periods_fp_per_unit"], + self.data["periods_fn_per_unit"], + good_periods, + ) register_result_extension(ComputeGoodPeriodsPerUnit) From 81d745e72d27c390d98284c491672ab33d8fbc9c Mon Sep 17 00:00:00 2001 From: m-beau Date: Thu, 8 Jan 2026 17:46:34 +0100 Subject: [PATCH 15/70] done refactoring self.data serializable format and get_data method --- .../postprocessing/good_periods_per_unit.py | 182 +++++++++++------- 1 file changed, 115 insertions(+), 67 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index e010e86662..33832f0ad7 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -187,6 +187,7 @@ def _run(self, verbose=False): self.data["good_periods_per_unit"] = self.params["user_defined_periods"] elif self.params["method"] in ["false_positives_and_negatives", "combined"]: + # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields subperiods_per_unit = self.compute_subperiods( self.params["subperiod_size_absolute"], @@ -194,47 +195,19 @@ def _run(self, verbose=False): self.params["subperiod_size_mode"], ) - # Compute fp and fn for all periods - - # fp computed from refractory period violations - # dict: unit_id -> array of shape (n_subperiods) - periods_fp_per_unit = self.compute_fp_rates(subperiods_per_unit, self.params["violations_ms"]) + # Compute fp and fn for all periods. 
+ # fp computed from refractory period violations; fn computed from amplitude clippings + subperiods_fp_per_unit = self.compute_fp_rates(subperiods_per_unit, self.params["violations_ms"]) + subperiods_fn_per_unit = self.compute_fn_rates(subperiods_per_unit) - # fn computed from amplitude clippings - # dict: unit_id -> array of shape (n_subperiods) - periods_fn_per_unit = self.compute_fn_rates(subperiods_per_unit) # Combine fp and fn results with thresholds to define good periods # get n spikes per unit to set the fp or fn rates to 1 if not enough spikes - minimum_valid_period_duration = self.params["minimum_valid_period_duration"] - fs = self.sorting_analyzer.sampling_frequency - min_valid_period_samples = int(minimum_valid_period_duration * fs) - - n_spikes_per_unit = self.sorting_analyzer.sorting.count_num_spikes_per_unit() - good_periods_per_unit = np.array([], dtype=unit_period_dtype) - for unit_id, subperiods in subperiods_per_unit.items(): - n_spikes = n_spikes_per_unit[unit_id] - if n_spikes < self.params["minimum_n_spikes"]: - periods_fp_per_unit[unit_id] = [1] * len(periods_fp_per_unit[unit_id]) - periods_fn_per_unit[unit_id] = [1] * len(periods_fn_per_unit[unit_id]) - - fp_rates = periods_fp_per_unit[unit_id] - fn_rates = periods_fn_per_unit[unit_id] - - good_periods_mask = (np.array(fp_rates) < self.params["fp_threshold"]) & ( - np.array(fn_rates) < self.params["fn_threshold"] - ) - good_subperiods = np.array(subperiods)[good_periods_mask] - good_segments = np.unique(good_subperiods["segment_index"]) - for segment_index in good_segments: - segment_mask = good_subperiods["segment_index"] == segment_index - good_segment_subperiods = good_subperiods[segment_mask] - good_segment_periods = merge_overlapping_periods(good_segment_subperiods) - good_periods_per_unit = np.concatenate((good_periods_per_unit, good_segment_periods), axis=0) + good_periods_per_unit = self.compute_good_periods_from_fp_fn( + subperiods_per_unit, subperiods_fp_per_unit, subperiods_fn_per_unit + ) # Remove good periods that are too short - duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] - valid_mask = duration_samples >= min_valid_period_samples - good_periods_per_unit = good_periods_per_unit[valid_mask] + good_periods_per_unit = self.filter_out_short_periods(good_periods_per_unit) # Eventually combine with user-defined periods if provided if self.params["method"] == "combined": @@ -242,21 +215,18 @@ def _run(self, verbose=False): all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(all_periods) - # Convert subperiods per unit in period_centers_s - period_centers = [] - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - period_centers_dict = {} - for unit_id in self.sorting_analyzer.unit_ids: - periods_unit = subperiods_per_unit[unit_id] - periods_segment = periods_unit[periods_unit["segment_index"] == segment_index] - centers = list(0.5 * (periods_segment["start_sample_index"] + periods_segment["end_sample_index"])) - period_centers_dict[unit_id] = centers - period_centers.append(period_centers_dict) + ## Convert datastructures in spikeinterface-friendly serializable formats + # periods_fp_per_unit, periods_fn_per_unit: convert to (n_segments) list of unit -> values dicts + ( + subperiod_centers_per_segment_per_unit, + subperiods_fp_per_segment_per_unit, + subperiods_fn_per_segment_per_unit, + ) = 
self.reformat_subperiod_data(subperiods_per_unit, subperiods_fp_per_unit, subperiods_fn_per_unit) # Store data: here we have to make sure every dict is JSON serializable, so everything is lists - self.data["period_centers"] = period_centers - self.data["periods_fp_per_unit"] = periods_fp_per_unit - self.data["periods_fn_per_unit"] = periods_fn_per_unit + self.data["period_centers_per_unit"] = subperiod_centers_per_segment_per_unit + self.data["periods_fp_per_unit"] = subperiods_fp_per_segment_per_unit + self.data["periods_fn_per_unit"] = subperiods_fn_per_segment_per_unit self.data["good_periods_per_unit"] = good_periods_per_unit def _get_data(self, outputs: str = "by_unit", return_subperiods_metadata: bool = False): @@ -285,25 +255,33 @@ def _get_data(self, outputs: str = "by_unit", return_subperiods_metadata: bool = good_periods = self.data["good_periods_per_unit"] # list of dictionnaries; one dictionnary per segment - if outputs == "by_unit": - unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - data_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - data_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - data_by_units[segment_index][unit_id] = all_data[inds] - - return data_by_units + if outputs == "numpy": + good_periods = self.data["good_periods_per_unit"] + else: + # by_unit + unit_ids = np.unique(self.data["good_periods_per_unit"]["unit_index"]) + segments = np.unique(self.data["good_periods_per_unit"]["segment_index"]) + good_periods = [] + for segment_index in range(segments): + segment_mask = good_periods["segment_index"] == segment_index + periods_dict = {} + for unit_index in unit_ids: + periods_dict[unit_index] = [] + unit_mask = good_periods["unit_index"] == unit_index + good_periods_unit_segment = good_periods[segment_mask & unit_mask] + for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]: + periods_dict[unit_index].append((start, end)) + good_periods.append(periods_dict) + + if return_subperiods_metadata: + return ( + self.data["period_centers_per_unit"], + self.data["periods_fp_per_unit"], + self.data["periods_fn_per_unit"], + good_periods, + ) - return ( - self.data["subperiods_per_unit"], - self.data["periods_fp_per_unit"], - self.data["periods_fn_per_unit"], - good_periods, - ) + return good_periods def compute_subperiods( self, @@ -393,6 +371,76 @@ def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: fn_rates[unit_id].append(all_fraction_missing[unit_id]) # missed spikes for this subperiod return fn_rates + def compute_good_periods_from_fp_fn(self, subperiods_per_unit, subperiods_fp_per_unit, subperiods_fn_per_unit): + n_spikes_per_unit = self.sorting_analyzer.sorting.count_num_spikes_per_unit() + good_periods_per_unit = np.array([], dtype=unit_period_dtype) + for unit_id, subperiods in subperiods_per_unit.items(): + n_spikes = n_spikes_per_unit[unit_id] + if n_spikes < self.params["minimum_n_spikes"]: + subperiods_fp_per_unit[unit_id] = [1] * len(subperiods_fp_per_unit[unit_id]) + subperiods_fn_per_unit[unit_id] = [1] * len(subperiods_fn_per_unit[unit_id]) + + fp_rates = subperiods_fp_per_unit[unit_id] + fn_rates = subperiods_fn_per_unit[unit_id] + + good_periods_mask = (np.array(fp_rates) < self.params["fp_threshold"]) & ( + np.array(fn_rates) < 
self.params["fn_threshold"] + ) + good_subperiods = np.array(subperiods)[good_periods_mask] + good_segments = np.unique(good_subperiods["segment_index"]) + for segment_index in good_segments: + segment_mask = good_subperiods["segment_index"] == segment_index + good_segment_subperiods = good_subperiods[segment_mask] + good_segment_periods = merge_overlapping_periods(good_segment_subperiods) + good_periods_per_unit = np.concatenate((good_periods_per_unit, good_segment_periods), axis=0) + return good_periods_per_unit + + def filter_out_short_periods(self, good_periods_per_unit): + fs = self.sorting_analyzer.sampling_frequency + minimum_valid_period_duration = self.params["minimum_valid_period_duration"] + min_valid_period_samples = int(minimum_valid_period_duration * fs) + duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] + valid_mask = duration_samples >= min_valid_period_samples + return good_periods_per_unit[valid_mask] + + def reformat_subperiod_data(self, subperiods_per_unit, subperiods_fp_per_unit, subperiods_fn_per_unit): + n_segments = self.sorting_analyzer.sorting.get_num_segments() + subperiod_centers_per_segment_per_unit = [] + subperiods_fp_per_segment_per_unit = [] + subperiods_fn_per_segment_per_unit = [] + for segment_index in range(n_segments): + period_centers_dict = {} + fp_dict = {} + fn_dict = {} + for unit_id in self.sorting_analyzer.unit_ids: + periods_unit = subperiods_per_unit[unit_id] + periods_segment = periods_unit[periods_unit["segment_index"] == segment_index] + + centers = list(0.5 * (periods_segment["start_sample_index"] + periods_segment["end_sample_index"])) + fp_values = [ + subperiods_fp_per_unit[unit_id][i] + for i in range(len(periods_unit)) + if periods_unit[i]["segment_index"] == segment_index + ] + fn_values = [ + subperiods_fn_per_unit[unit_id][i] + for i in range(len(periods_unit)) + if periods_unit[i]["segment_index"] == segment_index + ] + + period_centers_dict[unit_id] = centers + fp_dict[unit_id] = fp_values + fn_dict[unit_id] = fn_values + + subperiod_centers_per_segment_per_unit.append(period_centers_dict) + subperiods_fp_per_segment_per_unit.append(fp_dict) + subperiods_fn_per_segment_per_unit.append(fn_dict) + return ( + subperiod_centers_per_segment_per_unit, + subperiods_fp_per_segment_per_unit, + subperiods_fn_per_segment_per_unit, + ) + def merge_overlapping_periods(subperiods): From 93a53cab72aae1d770bfad8e71a45b183e496fd2 Mon Sep 17 00:00:00 2001 From: m-beau Date: Thu, 8 Jan 2026 18:45:01 +0100 Subject: [PATCH 16/70] credits --- src/spikeinterface/postprocessing/good_periods_per_unit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 33832f0ad7..6541bf3514 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -58,7 +58,7 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): good_periods_per_unit : numpy.ndarray, int (n_periods, 4) array with columns: unit_id, segment_id, start_time, end_time (times in samples) - Implementation: Maxime Beau + Implementation: Maxime Beau. Derived from NeuroPyxelos, inspired by bommbcell. """ extension_name = "good_periods_per_unit" From 493d215185bbeafa31b30207e0f9bc12a6268337 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 9 Jan 2026 11:38:03 +0100 Subject: [PATCH 17/70] Make good_periods blazing fast! 
--- .../metrics/quality/misc_metrics.py | 78 +++- .../postprocessing/good_periods_per_unit.py | 429 +++++++++++------- 2 files changed, 324 insertions(+), 183 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4d02274e80..a434119b54 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -19,7 +19,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting +from spikeinterface.core import SortingAnalyzer, BaseSorting, get_noise_levels from spikeinterface.core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, @@ -336,8 +336,6 @@ def compute_refrac_period_violations( ---------- Based on metrics described in [Llobet]_ """ - from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes - res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) if not HAVE_NUMBA: @@ -348,16 +346,21 @@ def compute_refrac_period_violations( sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) - fs = sorting_analyzer.sampling_frequency - num_units = len(sorting_analyzer.unit_ids) + # TODO: in case of periods, should we use total samples only in provided periods? + total_samples = sorting_analyzer.get_total_samples() + + from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes + + fs = sorting.sampling_frequency + num_units = len(sorting.unit_ids) num_segments = sorting.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids + unit_ids = sorting.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) + num_spikes = compute_num_spikes(sorting) t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -368,22 +371,25 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_analyzer.get_total_samples() - nb_violations = {} rp_contamination = {} for unit_index, unit_id in enumerate(sorting.unit_ids): if unit_id not in unit_ids: continue - - nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] - N = num_spikes[unit_id] - if N == 0: - rp_contamination[unit_id] = np.nan + if isinstance(total_samples, dict): + total_samples_unit = total_samples[unit_id] else: - D = 1 - n_v * (T - 2 * N * t_c) / (N**2 * (t_r - t_c)) - rp_contamination[unit_id] = 1 - math.sqrt(D) if D >= 0 else 1.0 + total_samples_unit = total_samples + + nb_violations[unit_id] = nb_rp_violations[unit_index] + rp_contamination[unit_id] = _compute_rp_contamination_one_unit( + nb_rp_violations[unit_index], + num_spikes[unit_id], + total_samples_unit, + t_c, + t_r, + ) return res(rp_contamination, nb_violations) @@ -1623,6 +1629,46 @@ def slidingRP_violations( return min_cont_with_90_confidence +def _compute_rp_contamination_one_unit( + n_v, + n_spikes, + total_samples, + t_c, + t_r, +): + """ + Compute the refractory period contamination for one unit. + + Parameters + ---------- + n_v : int + Number of refractory period violations. + n_spikes : int + Number of spikes for the unit. 
+ total_samples : int + Total number of samples in the recording. + t_c : int + Censored period in samples. + t_r : int + Refractory period in samples. + + Returns + ------- + rp_contamination : float + The refractory period contamination for the unit. + """ + if n_spikes <= 1: + return np.nan + + denom = 1 - n_v * (total_samples - 2 * n_spikes * t_c) / (n_spikes**2 * (t_r - t_c)) + if denom < 0: + return 1.0 + + rp_contamination = 1 - math.sqrt(denom) + + return rp_contamination + + def _compute_violations(obs_viol, firing_rate, spike_count, ref_period_dur, contamination_prop): contamination_rate = firing_rate * contamination_prop expected_viol = contamination_rate * ref_period_dur * 2 * spike_count diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index e010e86662..07e134acb1 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -4,11 +4,21 @@ import warnings import numpy as np -from typing import Optional, Literal +from typing import Optional + +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import unit_period_dtype -from spikeinterface.metrics.quality import compute_refrac_period_violations, compute_amplitude_cutoffs +from spikeinterface.core.job_tools import process_worker_initializer, process_function_wrapper +from spikeinterface.metrics.quality.misc_metrics import ( + amplitude_cutoff, + _compute_nb_violations_numba, + _compute_rp_contamination_one_unit, +) from spikeinterface.metrics.spiketrain import compute_firing_rates numba_spec = importlib.util.find_spec("numba") @@ -29,29 +39,42 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): to estimate good periods (as periods with fn_rate 0 or subperiod_size_relative > 0 - ), "Either subperiod_size_absolute or subperiod_size_relative must be positive." - assert isinstance(subperiod_size_relative, (int)), "subperiod_size_relative must be an integer." + period_duration_s_absolute > 0 or period_target_num_spikes > 0 + ), "Either period_duration_s_absolute or period_target_num_spikes must be positive." + assert isinstance(period_target_num_spikes, (int)), "period_target_num_spikes must be an integer." 
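As a standalone illustration of the contamination estimate factored out above into
_compute_rp_contamination_one_unit (the same Llobet-style formula, rewritten here as a
self-contained sketch with made-up example numbers, not code from this patch):

    import math

    def rp_contamination(n_v, n_spikes, total_samples, t_c, t_r):
        # contamination inferred from the number of refractory-period
        # violations n_v observed over total_samples of recording
        if n_spikes <= 1:
            return float("nan")
        D = 1 - n_v * (total_samples - 2 * n_spikes * t_c) / (n_spikes**2 * (t_r - t_c))
        return 1 - math.sqrt(D) if D >= 0 else 1.0

    fs = 30_000
    t_r = int(round(0.8e-3 * fs))  # 24-sample refractory period
    t_c = 0                        # no censored period
    # 10 minutes of data, 10,000 spikes, 20 violations -> ~0.078 (about 8 % contamination)
    print(rp_contamination(n_v=20, n_spikes=10_000, total_samples=600 * fs, t_c=t_c, t_r=t_r))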
# user_defined_periods formatting if user_defined_periods is not None: @@ -153,15 +179,19 @@ def _set_params( params = dict( method=method, - subperiod_size_absolute=subperiod_size_absolute, - subperiod_size_relative=subperiod_size_relative, - subperiod_size_mode=subperiod_size_mode, - violations_ms=violations_ms, + period_duration_s_absolute=period_duration_s_absolute, + period_target_num_spikes=period_target_num_spikes, + period_mode=period_mode, fp_threshold=fp_threshold, fn_threshold=fn_threshold, minimum_n_spikes=minimum_n_spikes, minimum_valid_period_duration=minimum_valid_period_duration, user_defined_periods=user_defined_periods, + refractory_period_ms=refractory_period_ms, + censored_period_ms=censored_period_ms, + num_histogram_bins=num_histogram_bins, + histogram_smoothing_value=histogram_smoothing_value, + amplitudes_bins_min_ratio=amplitudes_bins_min_ratio, ) return params @@ -181,6 +211,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, return new_extension_data def _run(self, verbose=False): + from spikeinterface import get_global_job_kwargs if self.params["method"] == "user_defined": # directly use user defined periods @@ -188,128 +219,160 @@ def _run(self, verbose=False): elif self.params["method"] in ["false_positives_and_negatives", "combined"]: # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields - subperiods_per_unit = self.compute_subperiods( - self.params["subperiod_size_absolute"], - self.params["subperiod_size_relative"], - self.params["subperiod_size_mode"], + all_periods = self.compute_subperiods( + self.params["period_duration_s_absolute"], + self.params["period_target_num_spikes"], + self.params["period_mode"], ) - # Compute fp and fn for all periods - - # fp computed from refractory period violations - # dict: unit_id -> array of shape (n_subperiods) - periods_fp_per_unit = self.compute_fp_rates(subperiods_per_unit, self.params["violations_ms"]) + job_kwargs = get_global_job_kwargs() + n_jobs = job_kwargs["n_jobs"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] + mp_context = job_kwargs["mp_context"] - # fn computed from amplitude clippings - # dict: unit_id -> array of shape (n_subperiods) - periods_fn_per_unit = self.compute_fn_rates(subperiods_per_unit) - # Combine fp and fn results with thresholds to define good periods - # get n spikes per unit to set the fp or fn rates to 1 if not enough spikes - minimum_valid_period_duration = self.params["minimum_valid_period_duration"] - fs = self.sorting_analyzer.sampling_frequency - min_valid_period_samples = int(minimum_valid_period_duration * fs) - - n_spikes_per_unit = self.sorting_analyzer.sorting.count_num_spikes_per_unit() - good_periods_per_unit = np.array([], dtype=unit_period_dtype) - for unit_id, subperiods in subperiods_per_unit.items(): - n_spikes = n_spikes_per_unit[unit_id] - if n_spikes < self.params["minimum_n_spikes"]: - periods_fp_per_unit[unit_id] = [1] * len(periods_fp_per_unit[unit_id]) - periods_fn_per_unit[unit_id] = [1] * len(periods_fn_per_unit[unit_id]) - - fp_rates = periods_fp_per_unit[unit_id] - fn_rates = periods_fn_per_unit[unit_id] + # Compute fp and fn for all periods + # Process units in parallel + amp_scalings = self.sorting_analyzer.get_extension("amplitude_scalings") + all_amplitudes_by_unit = amp_scalings.get_data(outputs="by_unit", concatenated=False) + + init_args = (self.sorting_analyzer.sorting, all_amplitudes_by_unit, 
self.params, max_threads_per_worker) + + # Each item is one computation of fp and fn for one period and one unit + items = [(period,) for period in all_periods] + job_name = f"computing false positives and negatives" + + # parallel + with ProcessPoolExecutor( + max_workers=n_jobs, + initializer=process_worker_initializer, + mp_context=mp.get_context(mp_context), + initargs=( + fp_fn_worker_func, + fp_fn_worker_init, + init_args, + max_threads_per_worker, + False, # no worker index + None, + None, + ), + ) as executor: + results = executor.map(fp_fn_worker_func_wrapper, items) + + if progress_bar: + results = tqdm(results, desc=f"{job_name} (workers: {n_jobs} processes)", total=len(items)) + + all_fps = np.zeros(len(all_periods)) + all_fns = np.zeros(len(all_periods)) + for i, (fp, fn) in enumerate(results): + all_fps[i] = fp + all_fns[i] = fn + + # set NaNs to 1 (they will be exluded anyways) + all_fps[np.isnan(all_fps)] = 1.0 + all_fns[np.isnan(all_fns)] = 1.0 + + # split values by segment and units + # fps and fns are lists of segments with dicts unit_id -> array of shape (n_subperiods) + fps = [] + fns = [] + for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + fp_in_segment = {} + fn_in_segment = {} + segment_mask = all_periods["segment_index"] == segment_index + periods_segment = all_periods[segment_mask] + fps_segment = all_fps[segment_mask] + fns_segment = all_fns[segment_mask] + for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + unit_mask = periods_segment["unit_index"] == unit_index + fp_in_segment[unit_id] = fps_segment[unit_mask] + fn_in_segment[unit_id] = fns_segment[unit_mask] + fps.append(fp_in_segment) + fns.append(fn_in_segment) + + good_period_mask = (all_fps < self.params["fp_threshold"]) & (all_fns < self.params["fn_threshold"]) + good_periods = all_periods[good_period_mask] + + # Sort good periods on segment_index, unit_index, start_sample_index + sort_idx = np.lexsort( + (good_periods["start_sample_index"], good_periods["unit_index"], good_periods["segment_index"]) + ) + good_periods_per_unit = good_periods[sort_idx] - good_periods_mask = (np.array(fp_rates) < self.params["fp_threshold"]) & ( - np.array(fn_rates) < self.params["fn_threshold"] - ) - good_subperiods = np.array(subperiods)[good_periods_mask] - good_segments = np.unique(good_subperiods["segment_index"]) - for segment_index in good_segments: - segment_mask = good_subperiods["segment_index"] == segment_index - good_segment_subperiods = good_subperiods[segment_mask] - good_segment_periods = merge_overlapping_periods(good_segment_subperiods) - good_periods_per_unit = np.concatenate((good_periods_per_unit, good_segment_periods), axis=0) + # Combine with user-defined periods if provided + if self.params["method"] == "combined": + user_defined_periods = self.params["user_defined_periods"] + all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) + good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(good_periods_per_unit) # Remove good periods that are too short + minimum_valid_period_duration = self.params["minimum_valid_period_duration"] + min_valid_period_samples = int(minimum_valid_period_duration * self.sorting_analyzer.sampling_frequency) duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] valid_mask = duration_samples >= min_valid_period_samples good_periods_per_unit = good_periods_per_unit[valid_mask] - # Eventually combine with user-defined periods if 
provided - if self.params["method"] == "combined": - user_defined_periods = self.params["user_defined_periods"] - all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) - good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(all_periods) - # Convert subperiods per unit in period_centers_s period_centers = [] for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + periods_segment = all_periods[all_periods["segment_index"] == segment_index] period_centers_dict = {} - for unit_id in self.sorting_analyzer.unit_ids: - periods_unit = subperiods_per_unit[unit_id] - periods_segment = periods_unit[periods_unit["segment_index"] == segment_index] - centers = list(0.5 * (periods_segment["start_sample_index"] + periods_segment["end_sample_index"])) + for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] + centers = list(0.5 * (periods_unit["start_sample_index"] + periods_unit["end_sample_index"])) period_centers_dict[unit_id] = centers period_centers.append(period_centers_dict) # Store data: here we have to make sure every dict is JSON serializable, so everything is lists self.data["period_centers"] = period_centers - self.data["periods_fp_per_unit"] = periods_fp_per_unit - self.data["periods_fn_per_unit"] = periods_fn_per_unit + self.data["periods_fp_per_unit"] = fps + self.data["periods_fn_per_unit"] = fns self.data["good_periods_per_unit"] = good_periods_per_unit - def _get_data(self, outputs: str = "by_unit", return_subperiods_metadata: bool = False): + def _get_data(self, outputs: str = "by_unit"): """ Return extension data. If the extension computes more than one `nodepipeline_variables`, the `return_data_name` is used to specify which one to return. Parameters ---------- - outputs : "numpy" | "by_unit", default: "numpy" - How to return the data, by default "numpy" - return_subperiods_metadata: bool, default: False - Whether to also return metadata of subperiods used to compute the good periods - as dictionnaries per unit: - - subperiods_per_unit: unit_name -> list of n_subperiods subperiods (each subperiod is an array of dtype unit_period_dtype with 4 fields) - - periods_fp_per_unit: unit_name -> array of n_subperiods, false positive rates (refractory period violations) per subperiod - - periods_fn_per_unit: unit_name -> array of n_subperiods, false negative rates (amplitude cutoffs) per subperiod + outputs : "numpy" | "by_unit", default: "by_unit" + How to return the data. Returns ------- - numpy.ndarray | dict | tuple - The periods in numpy or dictionnary by unit format, - or a tuple that contains the former as well as metadata of subperiods if return_subperiods_metadata is True. + numpy.ndarray | list + The periods in numpy or dictionary by unit format, depending on `outputs`. + If "numpy", returns an array of dtype unit_period_dtype with columns: + unit_index, segment_index, start_sample_index, end_sample_index. + If "by_unit", returns a list (per segment) of dictionaries mapping unit IDs to lists of + (start_sample_index, end_sample_index) tuples. 
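A hedged sketch of how a caller might consume the two output modes described above; the helper is
illustrative and assumes `ext` is the already-computed extension object (for example obtained from
the analyzer's get_extension), nothing here is part of the patch itself:

    def summarize_good_periods(ext):
        # one dict per segment: unit -> list of (start_sample, end_sample) tuples
        periods_by_unit = ext.get_data(outputs="by_unit")
        for segment_index, periods_dict in enumerate(periods_by_unit):
            for unit, periods in periods_dict.items():
                total = sum(end - start for start, end in periods)
                print(f"segment {segment_index}, unit {unit}: {total} good samples")

    # the flat structured array is handier for boolean masking:
    #     arr = ext.get_data(outputs="numpy")
    #     arr["unit_index"], arr["segment_index"], arr["start_sample_index"], arr["end_sample_index"]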
""" - - good_periods = self.data["good_periods_per_unit"] - - # list of dictionnaries; one dictionnary per segment - if outputs == "by_unit": + if outputs == "numpy": + good_periods = self.data["good_periods_per_unit"] + else: + # by_unit unit_ids = self.sorting_analyzer.unit_ids - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) - data_by_units = {} - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): - data_by_units[segment_index] = {} - for unit_id in unit_ids: - inds = spike_indices[segment_index][unit_id] - data_by_units[segment_index][unit_id] = all_data[inds] - - return data_by_units - - return ( - self.data["subperiods_per_unit"], - self.data["periods_fp_per_unit"], - self.data["periods_fn_per_unit"], - good_periods, - ) + good_periods = [] + good_periods_array = self.data["good_periods_per_unit"] + for segment_index in range(self.sorting_analyzer.get_num_segments()): + segment_mask = good_periods_array["segment_index"] == segment_index + periods_dict = {} + for unit_index in unit_ids: + periods_dict[unit_index] = [] + unit_mask = good_periods_array["unit_index"] == unit_index + good_periods_unit_segment = good_periods_array[segment_mask & unit_mask] + for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]: + periods_dict[unit_index].append((start, end)) + good_periods.append(periods_dict) + + return good_periods def compute_subperiods( self, - subperiod_size_absolute: float = 10, - subperiod_size_relative: int = 1000, - subperiod_size_mode: str = "absolute", + period_duration_s_absolute: float = 10, + period_target_num_spikes: int = 1000, + period_mode: str = "absolute", ) -> dict: """ Computes subperiods per unit based on specified size mode. 
@@ -323,21 +386,20 @@ def compute_subperiods( fs = sorting.sampling_frequency unit_ids = sorting.unit_ids - if subperiod_size_mode == "absolute": - period_sizes_samples = {u: np.round(subperiod_size_absolute * fs).astype(int) for u in unit_ids} + if period_mode == "absolute": + period_sizes_samples = {u: np.round(period_duration_s_absolute * fs).astype(int) for u in unit_ids} else: # relative mean_firing_rates = compute_firing_rates(self.sorting_analyzer, unit_ids) period_sizes_samples = { - u: np.round((subperiod_size_relative / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids + u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids } margin_sizes_samples = period_sizes_samples - all_subperiods = {} + all_subperiods = np.array([], dtype=unit_period_dtype) for unit_index, unit_id in enumerate(unit_ids): period_size_samples = period_sizes_samples[unit_id] margin_size_samples = margin_sizes_samples[unit_id] - all_subperiods[unit_id] = [] for segment_index in range(sorting.get_num_segments()): n_samples = self.sorting_analyzer.get_num_samples(segment_index) # int: samples n_subperiods = n_samples // period_size_samples + 1 @@ -347,52 +409,17 @@ def compute_subperiods( for i in range(n_subperiods) ] ) - for start, end in starts_ends: + periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) + for i, (start, end) in enumerate(starts_ends): subperiod = np.zeros((1,), dtype=unit_period_dtype) subperiod["segment_index"] = segment_index subperiod["start_sample_index"] = start subperiod["end_sample_index"] = end subperiod["unit_index"] = unit_index - all_subperiods[unit_id].append(subperiod) - all_subperiods[unit_id] = np.array(all_subperiods[unit_id]) + periods_for_unit[i] = subperiod + all_subperiods = np.concatenate((all_subperiods, periods_for_unit), axis=0) return all_subperiods - def compute_fp_rates(self, subperiods_per_unit: dict, violations_ms: float = 0.8) -> dict: - """ - Computes false positive rates (RP violations) for each subperiod per unit. - """ - fp_rates = {} - for unit_id, subperiods in subperiods_per_unit.items(): - fp_rates[unit_id] = [] - for subperiod in subperiods: - isi_violations = compute_refrac_period_violations( - self.sorting_analyzer, - unit_ids=[unit_id], - refractory_period_ms=violations_ms, - periods=subperiod, - ) - fp_rates[unit_id].append(isi_violations.rp_contamination[unit_id]) # contamination for this subperiod - return fp_rates - - def compute_fn_rates(self, subperiods_per_unit: dict) -> dict: - """ - Computes false negative rates (amplitude cutoffs) for each subperiod per unit. 
- """ - fn_rates = {} - for unit_id, subperiods in subperiods_per_unit.items(): - fn_rates[unit_id] = [] - for subperiod in subperiods: - all_fraction_missing = compute_amplitude_cutoffs( - self.sorting_analyzer, - unit_ids=[unit_id], - num_histogram_bins=50, - histogram_smoothing_value=3, - amplitudes_bins_min_ratio=3, - periods=subperiod, - ) - fn_rates[unit_id].append(all_fraction_missing[unit_id]) # missed spikes for this subperiod - return fn_rates - def merge_overlapping_periods(subperiods): @@ -434,16 +461,13 @@ def merge_overlapping_periods(subperiods): def merge_overlapping_periods_across_units_and_segments(periods): - - units = np.unique(periods["unit_index"]) segments = np.unique(periods["segment_index"]) - + units = np.unique(periods["unit_index"]) merged_periods = np.array([], dtype=unit_period_dtype) - for unit_index in units: - for segment_index in segments: - masked_periods = periods[ - (periods["unit_index"] == unit_index) & (periods["segment_index"] == segment_index) - ] + for segment_index in segments: + periods_per_segment = periods[periods["segment_index"] == segment_index] + for unit_index in units: + masked_periods = periods_per_segment[(periods_per_segment["unit_index"] == unit_index)] if len(masked_periods) == 0: continue _merged_periods = merge_overlapping_periods(masked_periods) @@ -454,3 +478,74 @@ def merge_overlapping_periods_across_units_and_segments(periods): register_result_extension(ComputeGoodPeriodsPerUnit) compute_good_periods_per_unit = ComputeGoodPeriodsPerUnit.function_factory() + + +global worker_ctx + + +def fp_fn_worker_init(sorting, all_amplitudes_by_unit, params, max_threads_per_worker): + global worker_ctx + worker_ctx = {} + + # cache spike vector and spiketrains + sorting.precompute_spike_trains() + + worker_ctx["sorting"] = sorting + worker_ctx["all_amplitudes_by_unit"] = all_amplitudes_by_unit + worker_ctx["params"] = params + worker_ctx["max_threads_per_worker"] = max_threads_per_worker + + +def fp_fn_worker_func(period, sorting, all_amplitudes_by_unit, params): + """ + Low level computation of false positives and false negatives for one period and one unit. + """ + # period is of dtype unit_period_dtype: 0: segment_index, 1: start_sample_index, 2: end_sample_index, 3: unit_index + period_sample = period[0] + segment_index = period_sample["segment_index"] + start_sample_index = period_sample["start_sample_index"] + end_sample_index = period_sample["end_sample_index"] + unit_index = period_sample["unit_index"] + unit_id = sorting.unit_ids[unit_index] + + amplitudes_unit = all_amplitudes_by_unit[segment_index][unit_id] + spiketrain = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + + mask = (spiketrain >= start_sample_index) & (spiketrain < end_sample_index) + total_samples_in_period = end_sample_index - start_sample_index + spiketrain_period = spiketrain[mask] + amplitudes_period = amplitudes_unit[mask] + + # compute fp (rp_violations). 
See _compute_refrac_period_violations in quality metrics + fs = sorting.sampling_frequency + t_c = int(round(params["censored_period_ms"] * fs * 1e-3)) + t_r = int(round(params["refractory_period_ms"] * fs * 1e-3)) + n_v = _compute_nb_violations_numba(spiketrain_period, t_r) + fp = _compute_rp_contamination_one_unit( + n_v, + len(spiketrain_period), + total_samples_in_period, + t_c, + t_r, + ) + + # compute fn (amplitude_cutoffs) + fn = amplitude_cutoff( + amplitudes_period, + params["num_histogram_bins"], + params["histogram_smoothing_value"], + params["amplitudes_bins_min_ratio"], + ) + return fp, fn + + +def fp_fn_worker_func_wrapper(period): + global worker_ctx + with threadpool_limits(limits=worker_ctx["max_threads_per_worker"]): + fp, fn = fp_fn_worker_func( + period, + worker_ctx["sorting"], + worker_ctx["all_amplitudes_by_unit"], + worker_ctx["params"], + ) + return fp, fn From a1fb16724b8525e17a40dd203351675bd97db9fa Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 9 Jan 2026 11:39:45 +0100 Subject: [PATCH 18/70] Add credits --- src/spikeinterface/postprocessing/good_periods_per_unit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 07e134acb1..93f1a82912 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -81,7 +81,9 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): good_periods_per_unit : numpy.ndarray, int (n_periods, 4) array with columns: unit_id, segment_id, start_time, end_time (times in samples) - Implementation: Maxime Beau + Notes + ----- + Implementation by Maxime Beau and Alessio Buccino, inspired by NeuroPyxels and Bombcell. 
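The false-negative side of the worker above relies on the standard amplitude-cutoff idea: the
amplitude histogram of a well-isolated unit is roughly symmetric, so the mass clipped off below the
detection threshold can be estimated by mirroring the upper tail. A simplified, self-contained
sketch of that estimate (not the library's amplitude_cutoff function itself):

    import numpy as np
    from scipy.ndimage import gaussian_filter1d

    def amplitude_cutoff_sketch(amplitudes, num_bins=100, smoothing=3):
        pdf, bin_edges = np.histogram(amplitudes, bins=num_bins, density=True)
        pdf = gaussian_filter1d(pdf, smoothing)
        bin_size = np.mean(np.diff(bin_edges))
        peak = np.argmax(pdf)
        # point above the peak where the density falls back to the height of the
        # lowest (clipped) bin: the mirror image of the truncated low tail
        mirror = np.argmin(np.abs(pdf[peak:] - pdf[0])) + peak
        return min(np.sum(pdf[mirror:]) * bin_size, 0.5)

    rng = np.random.default_rng(0)
    amps = rng.normal(100, 20, 20_000)
    clipped = amps[amps > 80]  # simulate spikes lost below a detection threshold of 80
    print(round(amplitude_cutoff_sketch(clipped), 3))
    # ~0.2, in the ballpark of the ~0.16 of spikes that were actually cut off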
""" extension_name = "good_periods_per_unit" From f6752ac85777798f8f11b68e8e86e3c8c46adfe7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 9 Jan 2026 11:57:28 +0100 Subject: [PATCH 19/70] Fix tests --- .../postprocessing/good_periods_per_unit.py | 24 +++++++------------ 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index fd09b47741..74c9a39722 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -13,12 +13,6 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import unit_period_dtype -from spikeinterface.core.job_tools import process_worker_initializer, process_function_wrapper -from spikeinterface.metrics.quality.misc_metrics import ( - amplitude_cutoff, - _compute_nb_violations_numba, - _compute_rp_contamination_one_unit, -) from spikeinterface.metrics.spiketrain import compute_firing_rates numba_spec = importlib.util.find_spec("numba") @@ -248,17 +242,9 @@ def _run(self, verbose=False): # parallel with ProcessPoolExecutor( max_workers=n_jobs, - initializer=process_worker_initializer, + initializer=fp_fn_worker_init, mp_context=mp.get_context(mp_context), - initargs=( - fp_fn_worker_func, - fp_fn_worker_init, - init_args, - max_threads_per_worker, - False, # no worker index - None, - None, - ), + initargs=init_args, ) as executor: results = executor.map(fp_fn_worker_func_wrapper, items) @@ -503,6 +489,12 @@ def fp_fn_worker_func(period, sorting, all_amplitudes_by_unit, params): """ Low level computation of false positives and false negatives for one period and one unit. """ + from spikeinterface.metrics.quality.misc_metrics import ( + amplitude_cutoff, + _compute_nb_violations_numba, + _compute_rp_contamination_one_unit, + ) + # period is of dtype unit_period_dtype: 0: segment_index, 1: start_sample_index, 2: end_sample_index, 3: unit_index period_sample = period[0] segment_index = period_sample["segment_index"] From a25182651864d610b0e49ba6a8aeb77c4296ab7b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 9 Jan 2026 12:39:19 +0100 Subject: [PATCH 20/70] oups --- src/spikeinterface/metrics/quality/misc_metrics.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index a434119b54..4f1b113866 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -349,8 +349,6 @@ def compute_refrac_period_violations( # TODO: in case of periods, should we use total samples only in provided periods? 
total_samples = sorting_analyzer.get_total_samples() - from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes - fs = sorting.sampling_frequency num_units = len(sorting.unit_ids) num_segments = sorting.get_num_segments() @@ -360,7 +358,7 @@ def compute_refrac_period_violations( if unit_ids is None: unit_ids = sorting.unit_ids - num_spikes = compute_num_spikes(sorting) + num_spikes = sorting.count_num_spikes_per_unit() t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) From 983d255b8b6fcbc20ddce94694f215f85447b359 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 9 Jan 2026 16:26:19 +0100 Subject: [PATCH 21/70] Sam's review + implement select/merge/split data --- .../metrics/quality/misc_metrics.py | 11 +- .../postprocessing/good_periods_per_unit.py | 364 +++++++++++++----- 2 files changed, 268 insertions(+), 107 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4f1b113866..93f46a2cbb 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -346,8 +346,15 @@ def compute_refrac_period_violations( sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) - # TODO: in case of periods, should we use total samples only in provided periods? - total_samples = sorting_analyzer.get_total_samples() + if periods is None: + total_samples = sorting_analyzer.get_total_samples() + else: + total_samples = {} + for unit_id in sorting.unit_ids: + total_samples[unit_id] = 0 + for period in periods: + unit_id = sorting.unit_ids[period["unit_index"]] + total_samples[unit_id] += period["end_sample_index"] - period["start_sample_index"] fs = sorting.sampling_frequency num_units = len(sorting.unit_ids) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/good_periods_per_unit.py index 74c9a39722..334161afa6 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/good_periods_per_unit.py @@ -5,8 +5,9 @@ import numpy as np from typing import Optional +from copy import deepcopy -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor import multiprocessing as mp from threadpoolctl import threadpool_limits from tqdm.auto import tqdm @@ -51,9 +52,11 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): Minimum spikes required in period for analysis. minimum_valid_period_duration : float, default: 180 Minimum duration that detected good periods must have to be kept, in seconds. - user_defined_periods : array-like or None, default: None - In SAMPLES, user-specified (unit, good_period_start, good_period_end) or (unit, segment_index, good_period_start, good_period_end) time pairs. - Required if method="user_defined" or "combined". + user_defined_periods : array of unit_period_dtype or shape (num_periods, 3) or (num_periods, 4) or None, default: None + Periods of unit_period_dtype (segment_index, start_sample_index, end_sample_index, unit_index) + or numpy array of shape (num_periods, 3) [unit_index, start_sample, end_sample] + or (num_periods, 4) [unit_index, segment_index, start_sample, end_sample] + in samples, over which to compute the metric. refractory_period_ms : float, default: 0.8 Refractory period duration for violation detection (ms). 
censored_period_ms : float, default: 0.0 @@ -66,9 +69,6 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -125,6 +125,7 @@ def _set_params( assert isinstance(period_target_num_spikes, (int)), "period_target_num_spikes must be an integer." # user_defined_periods formatting + self.user_defined_periods = None if user_defined_periods is not None: try: user_defined_periods = np.asarray(user_defined_periods) @@ -134,35 +135,31 @@ def _set_params( "user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array" ) - ) from e - - if user_defined_periods.ndim != 2 or user_defined_periods.shape[1] not in (3, 4): - raise ValueError( - "user_defined_periods must be of shape (n_periods, 3) [unit, good_period_start, good_period_end] or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end]" ) - if not np.issubdtype(user_defined_periods.dtype, np.integer): - # Try converting to check if they're integer-valued floats - if not np.allclose(user_defined_periods, user_defined_periods.astype(int)): - raise ValueError("All values in user_defined_periods must be integers, in samples.") - user_defined_periods = user_defined_periods.astype(int) + if user_defined_periods.dtype != np.dtype(unit_period_dtype): + if user_defined_periods.ndim != 2 or user_defined_periods.shape[1] not in (3, 4): + raise ValueError( + "user_defined_periods must be of shape (n_periods, 3) [unit_index, good_period_start, good_period_end] or (n_periods, 4) [unit_index, segment_index, good_period_start, good_period_end]" + ) - if user_defined_periods.shape[1] == 3: - # add segment index 0 as column 1 if missing - user_defined_periods = np.hstack( - ( - user_defined_periods[:, 0:1], - np.zeros((user_defined_periods.shape[0], 1), dtype=int), - user_defined_periods[:, 1:3], + if not np.issubdtype(user_defined_periods.dtype, np.integer): + # Try converting to check if they're integer-valued floats + if not np.allclose(user_defined_periods, user_defined_periods.astype(int)): + raise ValueError("All values in user_defined_periods must be integers, in samples.") + user_defined_periods = user_defined_periods.astype(int) + + if user_defined_periods.shape[1] == 3: + # add segment index 0 as column 1 if missing + user_defined_periods = np.hstack( + ( + user_defined_periods[:, 0:1], + np.zeros((user_defined_periods.shape[0], 1), dtype=int), + user_defined_periods[:, 1:3], + ) ) - ) - # Cast user defined periods to unit_period_dtype - user_defined_periods_typed = np.zeros(user_defined_periods.shape[0], dtype=unit_period_dtype) - user_defined_periods_typed["unit_index"] = user_defined_periods[:, 0] - user_defined_periods_typed["segment_index"] = user_defined_periods[:, 1] - user_defined_periods_typed["start_sample_index"] = user_defined_periods[:, 2] - user_defined_periods_typed["end_sample_index"] = user_defined_periods[:, 3] - user_defined_periods = user_defined_periods_typed + # Cast user defined periods to unit_period_dtype + user_defined_periods = 
np.frombuffer(user_defined_periods, dtype=unit_period_dtype) # assert that user-defined periods are not too short fs = self.sorting_analyzer.sampling_frequency @@ -172,6 +169,7 @@ def _set_params( raise ValueError( f"All user-defined periods must be at least {minimum_valid_period_duration} seconds long." ) + self.user_defined_periods = user_defined_periods params = dict( method=method, @@ -182,7 +180,6 @@ def _set_params( fn_threshold=fn_threshold, minimum_n_spikes=minimum_n_spikes, minimum_valid_period_duration=minimum_valid_period_duration, - user_defined_periods=user_defined_periods, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms, num_histogram_bins=num_histogram_bins, @@ -193,33 +190,186 @@ def _set_params( return params def _select_extension_data(self, unit_ids): - new_extension_data = self.data + new_extension_data = {} + good_periods = self.data["good_periods_per_unit"] + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + mask = np.isin(good_periods["unit_index"], unit_indices) + new_extension_data["good_periods_per_unit"] = good_periods[mask] return new_extension_data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs ): - new_extension_data = self.data + new_extension_data = {} + good_periods = self.data["good_periods_per_unit"] + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for merged units + recompute = True + else: + # in case of user-defined periods, just merge periods + recompute = False + + if recompute: + new_periods_centers = deepcopy(self.data.get("period_centers")) + new_periods_fp_per_unit = deepcopy(self.data.get("periods_fp_per_unit")) + new_periods_fn_per_unit = deepcopy(self.data.get("periods_fn_per_unit")) + # remove data of merged units + merged_unit_indices = [] + for unit_ids in merge_unit_groups: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + merged_unit_indices.append(unit_indices) + for unit_id in unit_ids: + for segment_index in range(self.sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].pop(unit_id, None) + new_periods_fp_per_unit[segment_index].pop(unit_id, None) + new_periods_fn_per_unit[segment_index].pop(unit_id, None) + + # remove periods of merged units + good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], np.array(merged_unit_indices))] + # recompute for merged units + good_periods_merged, period_centers, fps, fns = self._compute_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + new_good_periods = np.concatenate((good_periods_valid, good_periods_merged), axis=0) + + # update period centers, fps, fns + for segment_index in range(new_sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].update(period_centers[segment_index]) + new_periods_fp_per_unit[segment_index].update(fps[segment_index]) + new_periods_fn_per_unit[segment_index].update(fns[segment_index]) + + new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["period_centers"] = period_centers + new_extension_data["periods_fp_per_unit"] = fps + new_extension_data["periods_fn_per_unit"] = fns + else: + # just merge periods + merged_periods = np.array([], dtype=unit_period_dtype) + merged_unit_indices = [] + for unit_ids in merge_unit_groups: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + merged_unit_indices.append(unit_indices) + # 
get periods of all units to be merged + masked_periods = good_periods[np.isin(good_periods["unit_index"], unit_indices)] + if len(masked_periods) == 0: + continue + # merge periods + _merged_periods = merge_overlapping_periods_across_units_and_segments(masked_periods) + merged_periods = np.concatenate((merged_periods, _merged_periods)) + + # get periods of unmerged units + unmerged_mask = ~np.isin(good_periods["unit_index"], np.concatenate(merged_unit_indices)) + unmerged_periods = good_periods[unmerged_mask] + + new_good_periods = np.concatenate((unmerged_periods, merged_periods)) + new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + return new_extension_data def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): - new_extension_data = self.data + new_extension_data = {} + good_periods = self.data["good_periods_per_unit"] + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for split units + recompute = True + else: + # in case of user-defined periods, we can only duplicate valid periods for the split + recompute = False + + if recompute: + new_periods_centers = deepcopy(self.data.get("period_centers")) + new_periods_fp_per_unit = deepcopy(self.data.get("periods_fp_per_unit")) + new_periods_fn_per_unit = deepcopy(self.data.get("periods_fn_per_unit")) + # remove data of split units + split_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(split_units) + for unit_id in split_units: + for segment_index in range(self.sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].pop(unit_id, None) + new_periods_fp_per_unit[segment_index].pop(unit_id, None) + new_periods_fn_per_unit[segment_index].pop(unit_id, None) + + # remove periods of split units + good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], split_unit_indices)] + # recompute for split units + good_periods_split, period_centers, fps, fns = self._compute_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + new_good_periods = np.concatenate((good_periods_valid, good_periods_split)) + # update period centers, fps, fns + for segment_index in range(new_sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].update(period_centers[segment_index]) + new_periods_fp_per_unit[segment_index].update(fps[segment_index]) + new_periods_fn_per_unit[segment_index].update(fns[segment_index]) + + new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["period_centers"] = period_centers + new_extension_data["periods_fp_per_unit"] = fps + new_extension_data["periods_fn_per_unit"] = fns + else: + # just duplicate periods to the split units + split_periods = [] + split_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(split_units) + for split_unit_id, new_unit_ids in zip(split_units, new_unit_ids): + unit_index = self.sorting_analyzer.sorting.id_to_index(split_unit_id) + new_unit_indices = new_sorting_analyzer.sorting.ids_to_indices(new_unit_ids) + split_unit_indices.append(unit_index) + # get periods of all units to be merged + masked_periods = good_periods[good_periods["unit_index"] == unit_index] + for new_unit_index in new_unit_indices: + _split_periods = masked_periods.copy() + _split_periods["unit_index"] = new_unit_index + split_periods = np.concatenate((split_periods, _split_periods), axis=0) + if len(masked_periods) == 0: + continue + # merge periods + _split_periods = 
merge_overlapping_periods_across_units_and_segments(masked_periods) + split_periods.append(_split_periods) + split_periods = np.concatenate(split_periods, axis=0) + # get periods of unmerged units + unsplit_mask = ~np.isin(good_periods["unit_index"], np.array(split_unit_indices)) + unsplit_periods = good_periods[unsplit_mask] + + new_good_periods = np.concatenate((unsplit_periods, split_periods), axis=0) + new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + return new_extension_data - def _run(self, verbose=False): + def _run(self, unit_ids=None, verbose=False): + good_periods_per_unit, period_centers, fps, fns = self._compute_periods( + self.sorting_analyzer, + unit_ids=unit_ids, + ) + self.data["good_periods_per_unit"] = good_periods_per_unit + if period_centers is not None: + self.data["period_centers"] = period_centers + if fps is not None: + self.data["periods_fp_per_unit"] = fps + if fns is not None: + self.data["periods_fn_per_unit"] = fns + + def _compute_periods( + self, + sorting_analyzer, + unit_ids=None, + ): from spikeinterface import get_global_job_kwargs if self.params["method"] == "user_defined": + # directly use user defined periods - self.data["good_periods_per_unit"] = self.params["user_defined_periods"] + return self.user_defined_periods, None, None, None elif self.params["method"] in ["false_positives_and_negatives", "combined"]: # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields - all_periods = self.compute_subperiods( + all_periods = compute_subperiods( + sorting_analyzer, self.params["period_duration_s_absolute"], self.params["period_target_num_spikes"], self.params["period_mode"], + unit_ids=unit_ids, ) job_kwargs = get_global_job_kwargs() @@ -230,10 +380,10 @@ def _run(self, verbose=False): # Compute fp and fn for all periods # Process units in parallel - amp_scalings = self.sorting_analyzer.get_extension("amplitude_scalings") + amp_scalings = sorting_analyzer.get_extension("amplitude_scalings") all_amplitudes_by_unit = amp_scalings.get_data(outputs="by_unit", concatenated=False) - init_args = (self.sorting_analyzer.sorting, all_amplitudes_by_unit, self.params, max_threads_per_worker) + init_args = (sorting_analyzer.sorting, all_amplitudes_by_unit, self.params, max_threads_per_worker) # Each item is one computation of fp and fn for one period and one unit items = [(period,) for period in all_periods] @@ -265,14 +415,14 @@ def _run(self, verbose=False): # fps and fns are lists of segments with dicts unit_id -> array of shape (n_subperiods) fps = [] fns = [] - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + for segment_index in range(sorting_analyzer.sorting.get_num_segments()): fp_in_segment = {} fn_in_segment = {} segment_mask = all_periods["segment_index"] == segment_index periods_segment = all_periods[segment_mask] fps_segment = all_fps[segment_mask] fns_segment = all_fns[segment_mask] - for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): unit_mask = periods_segment["unit_index"] == unit_index fp_in_segment[unit_id] = fps_segment[unit_mask] fn_in_segment[unit_id] = fns_segment[unit_mask] @@ -283,10 +433,7 @@ def _run(self, verbose=False): good_periods = all_periods[good_period_mask] # Sort good periods on segment_index, unit_index, start_sample_index - sort_idx = np.lexsort( - (good_periods["start_sample_index"], good_periods["unit_index"], 
good_periods["segment_index"]) - ) - good_periods_per_unit = good_periods[sort_idx] + good_periods_per_unit = self._sort_periods(good_periods) # Combine with user-defined periods if provided if self.params["method"] == "combined": @@ -296,27 +443,24 @@ def _run(self, verbose=False): # Remove good periods that are too short minimum_valid_period_duration = self.params["minimum_valid_period_duration"] - min_valid_period_samples = int(minimum_valid_period_duration * self.sorting_analyzer.sampling_frequency) + min_valid_period_samples = int(minimum_valid_period_duration * sorting_analyzer.sampling_frequency) duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] valid_mask = duration_samples >= min_valid_period_samples good_periods_per_unit = good_periods_per_unit[valid_mask] # Convert subperiods per unit in period_centers_s period_centers = [] - for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): + for segment_index in range(sorting_analyzer.sorting.get_num_segments()): periods_segment = all_periods[all_periods["segment_index"] == segment_index] period_centers_dict = {} - for unit_index, unit_id in enumerate(self.sorting_analyzer.unit_ids): + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] centers = list(0.5 * (periods_unit["start_sample_index"] + periods_unit["end_sample_index"])) period_centers_dict[unit_id] = centers period_centers.append(period_centers_dict) # Store data: here we have to make sure every dict is JSON serializable, so everything is lists - self.data["period_centers"] = period_centers - self.data["periods_fp_per_unit"] = fps - self.data["periods_fn_per_unit"] = fns - self.data["good_periods_per_unit"] = good_periods_per_unit + return good_periods_per_unit, period_centers, fps, fns def _get_data(self, outputs: str = "by_unit"): """ @@ -338,7 +482,7 @@ def _get_data(self, outputs: str = "by_unit"): (start_sample_index, end_sample_index) tuples. """ if outputs == "numpy": - good_periods = self.data["good_periods_per_unit"] + good_periods = self.data["good_periods_per_unit"].copy() else: # by_unit unit_ids = self.sorting_analyzer.unit_ids @@ -357,57 +501,65 @@ def _get_data(self, outputs: str = "by_unit"): return good_periods - def compute_subperiods( - self, - period_duration_s_absolute: float = 10, - period_target_num_spikes: int = 1000, - period_mode: str = "absolute", - ) -> dict: - """ - Computes subperiods per unit based on specified size mode. + def _sort_periods(self, periods): + sort_idx = np.lexsort((periods["start_sample_index"], periods["unit_index"], periods["segment_index"])) + sorted_periods = periods[sort_idx] + return sorted_periods - Returns - ------- - all_subperiods : dict - Dictionary mapping unit IDs to lists of subperiods (arrays of dtype unit_period_dtype). - """ - sorting = self.sorting_analyzer.sorting - fs = sorting.sampling_frequency + +def compute_subperiods( + sorting_analyzer, + period_duration_s_absolute: float = 10, + period_target_num_spikes: int = 1000, + period_mode: str = "absolute", + unit_ids: Optional[list] = None, +) -> dict: + """ + Computes subperiods per unit based on specified size mode. + + Returns + ------- + all_subperiods : dict + Dictionary mapping unit IDs to lists of subperiods (arrays of dtype unit_period_dtype). 
+ """ + sorting = sorting_analyzer.sorting + fs = sorting.sampling_frequency + if unit_ids is None: unit_ids = sorting.unit_ids - if period_mode == "absolute": - period_sizes_samples = {u: np.round(period_duration_s_absolute * fs).astype(int) for u in unit_ids} - else: # relative - mean_firing_rates = compute_firing_rates(self.sorting_analyzer, unit_ids) - period_sizes_samples = { - u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids - } - margin_sizes_samples = period_sizes_samples - - all_subperiods = np.array([], dtype=unit_period_dtype) - for unit_index, unit_id in enumerate(unit_ids): - period_size_samples = period_sizes_samples[unit_id] - margin_size_samples = margin_sizes_samples[unit_id] - - for segment_index in range(sorting.get_num_segments()): - n_samples = self.sorting_analyzer.get_num_samples(segment_index) # int: samples - n_subperiods = n_samples // period_size_samples + 1 - starts_ends = np.array( - [ - [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] - for i in range(n_subperiods) - ] - ) - periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) - for i, (start, end) in enumerate(starts_ends): - subperiod = np.zeros((1,), dtype=unit_period_dtype) - subperiod["segment_index"] = segment_index - subperiod["start_sample_index"] = start - subperiod["end_sample_index"] = end - subperiod["unit_index"] = unit_index - periods_for_unit[i] = subperiod - all_subperiods = np.concatenate((all_subperiods, periods_for_unit), axis=0) - return all_subperiods + if period_mode == "absolute": + period_sizes_samples = {u: np.round(period_duration_s_absolute * fs).astype(int) for u in unit_ids} + else: # relative + mean_firing_rates = compute_firing_rates(sorting_analyzer, unit_ids) + period_sizes_samples = { + u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids + } + margin_sizes_samples = period_sizes_samples + + all_subperiods = [] + for unit_index, unit_id in enumerate(unit_ids): + period_size_samples = period_sizes_samples[unit_id] + margin_size_samples = margin_sizes_samples[unit_id] + + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + n_subperiods = n_samples // period_size_samples + 1 + starts_ends = np.array( + [ + [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] + for i in range(n_subperiods) + ] + ) + periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) + for i, (start, end) in enumerate(starts_ends): + subperiod = np.zeros((1,), dtype=unit_period_dtype) + subperiod["segment_index"] = segment_index + subperiod["start_sample_index"] = start + subperiod["end_sample_index"] = end + subperiod["unit_index"] = unit_index + periods_for_unit[i] = subperiod + all_subperiods.append(periods_for_unit) + return np.concatenate(all_subperiods) def merge_overlapping_periods(subperiods): @@ -452,7 +604,7 @@ def merge_overlapping_periods(subperiods): def merge_overlapping_periods_across_units_and_segments(periods): segments = np.unique(periods["segment_index"]) units = np.unique(periods["unit_index"]) - merged_periods = np.array([], dtype=unit_period_dtype) + merged_periods = [] for segment_index in segments: periods_per_segment = periods[periods["segment_index"] == segment_index] for unit_index in units: @@ -460,7 +612,9 @@ def merge_overlapping_periods_across_units_and_segments(periods): if len(masked_periods) == 0: continue 
_merged_periods = merge_overlapping_periods(masked_periods) - merged_periods = np.concatenate((merged_periods, _merged_periods), axis=0) + merged_periods.append(_merged_periods) + if len(merged_periods) == 0: + merged_periods = np.array([], dtype=unit_period_dtype) return merged_periods From c5dbb9369f1583b1e0cab21fd9dc4df9ebd2c7e3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 12 Jan 2026 12:12:57 +0100 Subject: [PATCH 22/70] Rename to valid_unit_periods and wip widgets --- src/spikeinterface/postprocessing/__init__.py | 6 +- .../tests/test_valid_unit_periods.py | 19 ++ ...iods_per_unit.py => valid_unit_periods.py} | 48 ++--- .../widgets/unit_valid_periods.py | 186 ++++++++++++++++++ src/spikeinterface/widgets/widget_list.py | 3 + 5 files changed, 236 insertions(+), 26 deletions(-) create mode 100644 src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py rename src/spikeinterface/postprocessing/{good_periods_per_unit.py => valid_unit_periods.py} (94%) create mode 100644 src/spikeinterface/widgets/unit_valid_periods.py diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index 078157188e..555c9a5d3b 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -45,7 +45,7 @@ compute_template_metrics, ) -from .good_periods_per_unit import ( - ComputeGoodPeriodsPerUnit, - compute_good_periods_per_unit, +from .valid_unit_periods import ( + ComputeValidUnitPeriods, + compute_valid_unit_periods, ) diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py new file mode 100644 index 0000000000..bf80bc2d30 --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -0,0 +1,19 @@ +import pytest + +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeValidUnitPeriods + + +class TestComputeValidUnitPeriods(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize( + "params", + [ + dict(period_mode="absolute"), + dict(period_mode="relative"), + ], + ) + def test_extension(self, params): + self.run_extension_tests( + ComputeValidUnitPeriods, params, extra_dependencies=["templates", "amplitude_scalings"] + ) diff --git a/src/spikeinterface/postprocessing/good_periods_per_unit.py b/src/spikeinterface/postprocessing/valid_unit_periods.py similarity index 94% rename from src/spikeinterface/postprocessing/good_periods_per_unit.py rename to src/spikeinterface/postprocessing/valid_unit_periods.py index 334161afa6..4cdfc0b3b6 100644 --- a/src/spikeinterface/postprocessing/good_periods_per_unit.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -23,7 +23,7 @@ HAVE_NUMBA = False -class ComputeGoodPeriodsPerUnit(AnalyzerExtension): +class ComputeValidUnitPeriods(AnalyzerExtension): """Compute good time periods per unit based on quality metrics. Paraneters @@ -72,7 +72,7 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): Returns ------- - good_periods_per_unit : numpy.ndarray, int + valid_unit_periods : numpy.ndarray, int (n_periods, 4) array with columns: unit_id, segment_id, start_time, end_time (times in samples) Notes @@ -80,7 +80,7 @@ class ComputeGoodPeriodsPerUnit(AnalyzerExtension): Implementation by Maxime Beau and Alessio Buccino, inspired by NeuroPyxels and Bombcell. 
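With the rename in place, end-to-end usage would look roughly like the sketch below. It is an
assumption-laden example (generated data, the dependency chain taken from the new test file, and
default parameters otherwise), not documented API from this patch series:

    import spikeinterface.full as si

    recording, sorting = si.generate_ground_truth_recording(durations=[400.0], seed=0)
    analyzer = si.create_sorting_analyzer(sorting, recording)
    analyzer.compute(["random_spikes", "templates", "amplitude_scalings"])
    analyzer.compute("valid_unit_periods", period_mode="relative")
    valid_periods = analyzer.get_extension("valid_unit_periods").get_data(outputs="by_unit")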
""" - extension_name = "good_periods_per_unit" + extension_name = "valid_unit_periods" depend_on = [] need_recording = False use_nodepipeline = False @@ -191,17 +191,17 @@ def _set_params( def _select_extension_data(self, unit_ids): new_extension_data = {} - good_periods = self.data["good_periods_per_unit"] + good_periods = self.data["valid_unit_periods"] unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) mask = np.isin(good_periods["unit_index"], unit_indices) - new_extension_data["good_periods_per_unit"] = good_periods[mask] + new_extension_data["valid_unit_periods"] = good_periods[mask] return new_extension_data def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs ): new_extension_data = {} - good_periods = self.data["good_periods_per_unit"] + good_periods = self.data["valid_unit_periods"] if self.params["method"] in ("false_positives_and_negatives", "combined"): # need to recompute for merged units recompute = True @@ -239,7 +239,7 @@ def _merge_extension_data( new_periods_fp_per_unit[segment_index].update(fps[segment_index]) new_periods_fn_per_unit[segment_index].update(fns[segment_index]) - new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) new_extension_data["period_centers"] = period_centers new_extension_data["periods_fp_per_unit"] = fps new_extension_data["periods_fn_per_unit"] = fns @@ -263,13 +263,13 @@ def _merge_extension_data( unmerged_periods = good_periods[unmerged_mask] new_good_periods = np.concatenate((unmerged_periods, merged_periods)) - new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) return new_extension_data def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): new_extension_data = {} - good_periods = self.data["good_periods_per_unit"] + good_periods = self.data["valid_unit_periods"] if self.params["method"] in ("false_positives_and_negatives", "combined"): # need to recompute for split units recompute = True @@ -303,7 +303,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, new_periods_fp_per_unit[segment_index].update(fps[segment_index]) new_periods_fn_per_unit[segment_index].update(fns[segment_index]) - new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) new_extension_data["period_centers"] = period_centers new_extension_data["periods_fp_per_unit"] = fps new_extension_data["periods_fn_per_unit"] = fns @@ -332,16 +332,16 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, unsplit_periods = good_periods[unsplit_mask] new_good_periods = np.concatenate((unsplit_periods, split_periods), axis=0) - new_extension_data["good_periods_per_unit"] = self._sort_periods(new_good_periods) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) return new_extension_data def _run(self, unit_ids=None, verbose=False): - good_periods_per_unit, period_centers, fps, fns = self._compute_periods( + valid_unit_periods, period_centers, fps, fns = self._compute_periods( self.sorting_analyzer, unit_ids=unit_ids, ) - self.data["good_periods_per_unit"] = good_periods_per_unit + self.data["valid_unit_periods"] = 
valid_unit_periods if period_centers is not None: self.data["period_centers"] = period_centers if fps is not None: @@ -433,20 +433,20 @@ def _compute_periods( good_periods = all_periods[good_period_mask] # Sort good periods on segment_index, unit_index, start_sample_index - good_periods_per_unit = self._sort_periods(good_periods) + valid_unit_periods = self._sort_periods(good_periods) # Combine with user-defined periods if provided if self.params["method"] == "combined": user_defined_periods = self.params["user_defined_periods"] - all_periods = np.concatenate((good_periods_per_unit, user_defined_periods), axis=0) - good_periods_per_unit = merge_overlapping_periods_across_units_and_segments(good_periods_per_unit) + all_periods = np.concatenate((valid_unit_periods, user_defined_periods), axis=0) + valid_unit_periods = merge_overlapping_periods_across_units_and_segments(valid_unit_periods) # Remove good periods that are too short minimum_valid_period_duration = self.params["minimum_valid_period_duration"] min_valid_period_samples = int(minimum_valid_period_duration * sorting_analyzer.sampling_frequency) - duration_samples = good_periods_per_unit["end_sample_index"] - good_periods_per_unit["start_sample_index"] + duration_samples = valid_unit_periods["end_sample_index"] - valid_unit_periods["start_sample_index"] valid_mask = duration_samples >= min_valid_period_samples - good_periods_per_unit = good_periods_per_unit[valid_mask] + valid_unit_periods = valid_unit_periods[valid_mask] # Convert subperiods per unit in period_centers_s period_centers = [] @@ -460,7 +460,7 @@ def _compute_periods( period_centers.append(period_centers_dict) # Store data: here we have to make sure every dict is JSON serializable, so everything is lists - return good_periods_per_unit, period_centers, fps, fns + return valid_unit_periods, period_centers, fps, fns def _get_data(self, outputs: str = "by_unit"): """ @@ -482,12 +482,12 @@ def _get_data(self, outputs: str = "by_unit"): (start_sample_index, end_sample_index) tuples. 
""" if outputs == "numpy": - good_periods = self.data["good_periods_per_unit"].copy() + good_periods = self.data["valid_unit_periods"].copy() else: # by_unit unit_ids = self.sorting_analyzer.unit_ids good_periods = [] - good_periods_array = self.data["good_periods_per_unit"] + good_periods_array = self.data["valid_unit_periods"] for segment_index in range(self.sorting_analyzer.get_num_segments()): segment_mask = good_periods_array["segment_index"] == segment_index periods_dict = {} @@ -615,12 +615,14 @@ def merge_overlapping_periods_across_units_and_segments(periods): merged_periods.append(_merged_periods) if len(merged_periods) == 0: merged_periods = np.array([], dtype=unit_period_dtype) + else: + merged_periods = np.concatenate(merged_periods, axis=0) return merged_periods -register_result_extension(ComputeGoodPeriodsPerUnit) -compute_good_periods_per_unit = ComputeGoodPeriodsPerUnit.function_factory() +register_result_extension(ComputeValidUnitPeriods) +compute_valid_unit_periods = ComputeValidUnitPeriods.function_factory() global worker_ctx diff --git a/src/spikeinterface/widgets/unit_valid_periods.py b/src/spikeinterface/widgets/unit_valid_periods.py new file mode 100644 index 0000000000..f911fab3e7 --- /dev/null +++ b/src/spikeinterface/widgets/unit_valid_periods.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import numpy as np +from warnings import warn + +from spikeinterface.core import SortingAnalyzer +from .base import BaseWidget, to_attr + + +class UnitValidPeriodsWidget(BaseWidget): + """ + Plots the valid periods for units based on valid periods extension. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer | None, default: None + The sorting analyzer + segment_index : None or int, default: None + The segment index. If None, uses first segment. + unit_ids : list | None, default: None + List of unit ids to plot. If None, all units are plotted. + show_only_units_with_good_periods : bool, default: True + If True, only units with good periods are shown. 
+ """ + + def __init__( + self, + sorting_analyzer: SortingAnalyzer | None = None, + segment_index: int | None = None, + unit_ids: list | None = None, + show_only_units_with_good_periods: bool = True, + backend: str | None = None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "valid_unit_periods") + good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + if good_periods_ext.params["method"] == "user_defined": + raise ValueError("UnitValidPeriodsWidget cannot be used with 'user_defined' good periods.") + + if segment_index is None: + nseg = sorting_analyzer.get_num_segments() + if nseg != 1: + raise ValueError("You must provide segment_index=...") + else: + segment_index = 0 + + good_periods = good_periods_ext.get_data(outputs="numpy") + if show_only_units_with_good_periods: + good_unit_ids = sorting_analyzer.unit_ids[np.unique(good_periods["unit_index"])] + else: + good_unit_ids = sorting_analyzer.unit_ids + if unit_ids is not None: + good_unit_ids = [u for u in unit_ids if u in good_unit_ids] + + data_plot = dict( + sorting_analyzer=sorting_analyzer, + segment_index=segment_index, + unit_ids=good_unit_ids, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + num_units = len(dp.unit_ids) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + assert np.asarray(axes).shape == (3, num_units), "Axes shape does not match number of units" + else: + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_units, 2 * 3) + backend_kwargs["num_axes"] = num_units * 3 + backend_kwargs["ncols"] = num_units + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting_analyzer = dp.sorting_analyzer + sampling_frequency = sorting_analyzer.sampling_frequency + segment_index = dp.segment_index + good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + fp_threshold = good_periods_ext.params["fp_threshold"] + fn_threshold = good_periods_ext.params["fn_threshold"] + fp_per_unit = good_periods_ext.data["periods_fp_per_unit"][segment_index] + fn_per_unit = good_periods_ext.data["periods_fn_per_unit"][segment_index] + period_centers = good_periods_ext.data["period_centers"][segment_index] + good_periods = good_periods_ext.get_data(outputs="numpy") + good_periods = good_periods[good_periods["segment_index"] == segment_index] + + amp_scalings_ext = sorting_analyzer.get_extension("amplitude_scalings") + amp_scalings_by_unit = amp_scalings_ext.get_data(outputs="by_unit")[segment_index] + + for unit_id in dp.unit_ids: + fp = fp_per_unit[unit_id] + fn = fn_per_unit[unit_id] + unit_index = list(sorting_analyzer.unit_ids).index(unit_id) + + axs = self.axes[:, unit_index] + # for simplicity we don't use timestamps here + spiketrain = ( + sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index) / sampling_frequency + ) + center_bins_s = np.array(period_centers[unit_id]) / sampling_frequency + + axs[0].plot(center_bins_s, fp, ls="", marker="o", color="r") + axs[0].axhline(fp_threshold, color="gray", ls="--") + axs[1].plot(center_bins_s, fn, ls="", marker="o") + axs[1].axhline(fn_threshold, color="gray", ls="--") + axs[2].plot(spiketrain, amp_scalings_by_unit[unit_id], ls="", marker="o", color="gray", alpha=0.5) + # 
plot valid periods + valid_period_for_units = good_periods[good_periods["unit_index"] == unit_index] + for valid_period in valid_period_for_units: + start_time = valid_period["start_sample_index"] / sorting_analyzer.sampling_frequency + end_time = valid_period["end_sample_index"] / sorting_analyzer.sampling_frequency + axs[2].axvspan(start_time, end_time, alpha=0.3, color="g") + + axs[0].set_xlabel("") + axs[1].set_xlabel("") + axs[2].set_xlabel("Time (s)") + axs[0].set_ylabel("FP Rate") + axs[1].set_ylabel("FN Rate") + axs[2].set_ylabel("Amplitude Scaling") + axs[0].set_title(f"Unit {unit_id}") + + # TODO: fix update + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + # Create figure without axes - let plot_matplotlib create them + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=self.unit_selector, + pane_widths=ratios + [0], + ) + + # a first update + self.axes = None + self._update_ipywidget() + + self.unit_selector.observe(self._update_ipywidget, names="value", type="change") + + if backend_kwargs["display"]: + display(self.widget) + + def _update_ipywidget(self, change=None): + if self.axes is None: + self.figure.clear() + else: + for ax in self.axes.flatten(): + ax.clear() + + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + + backend_kwargs = dict(figure=self.figure, axes=self.axes, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6edba67c96..5078843006 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -37,6 +37,7 @@ from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget +from .unit_valid_periods import UnitValidPeriodsWidget widget_list = [ AgreementMatrixWidget, @@ -48,6 +49,7 @@ CrossCorrelogramsWidget, DriftingTemplatesWidget, DriftRasterMapWidget, + UnitGoodPeriodsWidget, ISIDistributionWidget, LocationsWidget, MotionWidget, @@ -128,6 +130,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_drifting_templates = DriftingTemplatesWidget plot_drift_raster_map = DriftRasterMapWidget +plot_valid_unit_periods = UnitValidPeriodsWidget plot_isi_distribution = ISIDistributionWidget plot_locations = LocationsWidget plot_motion = MotionWidget From f382f89877215c24bbff4b15d122cc1803529242 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 12 Jan 2026 12:14:41 +0100 Subject: [PATCH 23/70] Fix imports --- src/spikeinterface/widgets/widget_list.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 5078843006..8281666dd8 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -49,7 +49,7 @@ CrossCorrelogramsWidget, DriftingTemplatesWidget, DriftRasterMapWidget, - UnitGoodPeriodsWidget, + UnitValidPeriodsWidget, ISIDistributionWidget, LocationsWidget, MotionWidget, From ad50845a72af9cc7146439f585427f4b971ec0a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 12 Jan 2026 15:56:31 +0100 Subject: [PATCH 24/70] Add widget and extend params --- .../postprocessing/amplitude_scalings.py | 6 +- .../postprocessing/valid_unit_periods.py | 58 ++++++++++------ src/spikeinterface/widgets/rasters.py | 4 -- .../widgets/tests/test_widgets.py | 21 ++++++ .../widgets/unit_valid_periods.py | 67 ++++++++++++------- src/spikeinterface/widgets/widget_list.py | 6 +- 6 files changed, 109 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8f3ffe0617..473798fe7c 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -127,7 +127,11 @@ def _get_pipeline_nodes(self): sparsity = self.params["sparsity"] else: if self.params["max_dense_channels"] is not None: - assert recording.get_num_channels() <= self.params["max_dense_channels"], "" + assert recording.get_num_channels() <= self.params["max_dense_channels"], ( + "Sparsity must be provided when the number of channels is " + f"greater than {self.params['max_dense_channels']}. Alternatively, set max_dense_channels to None " + "to compute amplitude scalings using dense waveforms." + ) sparsity = ChannelSparsity.create_dense(self.sorting_analyzer) sparsity_mask = sparsity.mask diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 4cdfc0b3b6..4f0ea701af 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -12,6 +12,7 @@ from threadpoolctl import threadpool_limits from tqdm.auto import tqdm +from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from spikeinterface.core.node_pipeline import unit_period_dtype from spikeinterface.metrics.spiketrain import compute_firing_rates @@ -42,8 +43,14 @@ class ComputeValidUnitPeriods(AnalyzerExtension): period_target_num_spikes : int | None, default: 300 Alternative to period_size_absolute, different for each unit: mean number of spikes that should be present in each estimation period. For neurons firing at 10 Hz, this would correspond to periods of 10s (100 spikes / 10 Hz = 10s). - period_size_mode: {"absolute", "relative"}, default: "absolute" - Whether to use absolute (in seconds) or relative (in mean number of spikes) period sizes. + period_mode: {"absolute", "relative"}, default: "absolute" + Whether to use absolute (in seconds) or relative (in target number of spikes) period sizes. + relative_margin_size : float, default: 1.0 + The margin to the left and the right for each period, expressed as a multiple of the period size. + For example, a value of 1.0 means that the margin size is equal to the period size: for a period of 10s, + each value will be computed using 30s of data (10s + 10s margin on each side). 
+ min_num_periods_relative : int, default: 5 + Minimum number of periods per unit, when using period_mode "relative". fp_threshold : float, default: 0.05 Maximum false positive rate to mark period as good. fn_threshold : float, default: 0.05 @@ -70,10 +77,6 @@ class ComputeValidUnitPeriods(AnalyzerExtension): If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - Returns - ------- - valid_unit_periods : numpy.ndarray, int - (n_periods, 4) array with columns: unit_id, segment_id, start_time, end_time (times in samples) Notes ----- @@ -83,6 +86,7 @@ class ComputeValidUnitPeriods(AnalyzerExtension): extension_name = "valid_unit_periods" depend_on = [] need_recording = False + need_job_kwargs = True use_nodepipeline = False need_job_kwargs = False @@ -92,10 +96,12 @@ def _set_params( period_duration_s_absolute: float = 30.0, period_target_num_spikes: int = 300, period_mode: str = "absolute", + relative_margin_size: float = 1.0, fp_threshold: float = 0.1, fn_threshold: float = 0.1, minimum_n_spikes: int = 100, minimum_valid_period_duration: float = 180, + min_num_periods_relative: int = 5, user_defined_periods: Optional[object] = None, refractory_period_ms: float = 0.8, censored_period_ms: float = 0.0, @@ -176,6 +182,8 @@ def _set_params( period_duration_s_absolute=period_duration_s_absolute, period_target_num_spikes=period_target_num_spikes, period_mode=period_mode, + relative_margin_size=relative_margin_size, + min_num_periods_relative=min_num_periods_relative, fp_threshold=fp_threshold, fn_threshold=fn_threshold, minimum_n_spikes=minimum_n_spikes, @@ -227,7 +235,7 @@ def _merge_extension_data( # remove periods of merged units good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], np.array(merged_unit_indices))] # recompute for merged units - good_periods_merged, period_centers, fps, fns = self._compute_periods( + good_periods_merged, period_centers, fps, fns = self._compute_valid_periods( new_sorting_analyzer, unit_ids=new_unit_ids, ) @@ -292,7 +300,7 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, # remove periods of split units good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], split_unit_indices)] # recompute for split units - good_periods_split, period_centers, fps, fns = self._compute_periods( + good_periods_split, period_centers, fps, fns = self._compute_valid_periods( new_sorting_analyzer, unit_ids=new_unit_ids, ) @@ -336,10 +344,11 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, return new_extension_data - def _run(self, unit_ids=None, verbose=False): - valid_unit_periods, period_centers, fps, fns = self._compute_periods( + def _run(self, unit_ids=None, verbose=False, **job_kwargs): + valid_unit_periods, period_centers, fps, fns = self._compute_valid_periods( self.sorting_analyzer, unit_ids=unit_ids, + **job_kwargs, ) self.data["valid_unit_periods"] = valid_unit_periods if period_centers is not None: @@ -349,11 +358,7 @@ def _run(self, unit_ids=None, verbose=False): if fns is not None: self.data["periods_fn_per_unit"] = fns - def _compute_periods( - self, - sorting_analyzer, - unit_ids=None, - ): + def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): from spikeinterface import get_global_job_kwargs if self.params["method"] == "user_defined": @@ -369,10 +374,12 @@ def _compute_periods( self.params["period_duration_s_absolute"], self.params["period_target_num_spikes"], self.params["period_mode"], + 
self.params["relative_margin_size"], + self.params["min_num_periods_relative"], unit_ids=unit_ids, ) - job_kwargs = get_global_job_kwargs() + job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] max_threads_per_worker = job_kwargs["max_threads_per_worker"] @@ -512,6 +519,8 @@ def compute_subperiods( period_duration_s_absolute: float = 10, period_target_num_spikes: int = 1000, period_mode: str = "absolute", + relative_margin_size: float = 1.0, + min_num_periods_relative: int = 5, unit_ids: Optional[list] = None, ) -> dict: """ @@ -534,7 +543,7 @@ def compute_subperiods( period_sizes_samples = { u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids } - margin_sizes_samples = period_sizes_samples + margin_sizes_samples = {u: np.round(relative_margin_size * period_sizes_samples[u]).astype(int) for u in unit_ids} all_subperiods = [] for unit_index, unit_id in enumerate(unit_ids): @@ -543,13 +552,24 @@ def compute_subperiods( for segment_index in range(sorting.get_num_segments()): n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples - n_subperiods = n_samples // period_size_samples + 1 + # We round the number of subperiods to ensure coverage of the entire recording + # the end of the last period is then clipped or extended to the end of the recording + n_subperiods = round(n_samples / period_size_samples) + if period_mode == "relative" and n_subperiods < min_num_periods_relative: + n_subperiods = min_num_periods_relative # at least min_num_periods_relative subperiods + period_size_samples = n_samples // n_subperiods + margin_size_samples = int(relative_margin_size * period_size_samples) starts_ends = np.array( [ - [i * period_size_samples, i * period_size_samples + 2 * margin_size_samples] + [i * period_size_samples, i * period_size_samples + period_size_samples + 2 * margin_size_samples] for i in range(n_subperiods) ] ) + # remove periods whose end is above the expected number of samples + starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] + # extend last period to the end of the recording + starts_ends[-1][1] = n_samples + periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) for i, (start, end) in enumerate(starts_ends): subperiod = np.zeros((1,), dtype=unit_period_dtype) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 757401d77c..d59193ed8b 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -327,7 +327,6 @@ def _full_update_plot(self, change=None): backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) - self._update_plot() def _update_plot(self, change=None): for ax in self.axes.flatten(): @@ -346,9 +345,6 @@ def _update_plot(self, change=None): self.figure.canvas.flush_events() -import numpy as np - - class RasterWidget(BaseRasterWidget): """ Plots spike train rasters. 
diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index e71b8b7d68..b46dc01d61 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -67,6 +67,7 @@ def setUpClass(cls): templates=dict(), noise_levels=dict(), spike_amplitudes=dict(), + amplitude_scalings=dict(max_dense_channels=None), # required by valid unit periods unit_locations=dict(), spike_locations=dict(), quality_metrics=dict( @@ -82,6 +83,18 @@ def setUpClass(cls): cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) cls.sorting_analyzer_dense.compute("random_spikes") cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) + # compute valid periods later, since it depends on amplitude_scalings + cls.sorting_analyzer_dense.compute( + dict( + valid_unit_periods=dict( + period_mode="relative", + period_target_num_spikes=200, + relative_margin_size=0.5, + min_num_periods_relative=5, + ) + ), + **job_kwargs, + ) sw.set_default_plotter_backend("matplotlib") @@ -688,6 +701,14 @@ def test_plot_motion_info(self): if backend not in self.skip_backends: sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) + def test_plot_valid_unit_periods(self): + possible_backends = list(sw.ValidUnitPeriodsWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_valid_unit_periods( + self.sorting_analyzer_dense, backend=backend, show_only_units_with_valid_periods=False + ) + if __name__ == "__main__": # unittest.main() diff --git a/src/spikeinterface/widgets/unit_valid_periods.py b/src/spikeinterface/widgets/unit_valid_periods.py index f911fab3e7..ffe9bc555f 100644 --- a/src/spikeinterface/widgets/unit_valid_periods.py +++ b/src/spikeinterface/widgets/unit_valid_periods.py @@ -7,7 +7,7 @@ from .base import BaseWidget, to_attr -class UnitValidPeriodsWidget(BaseWidget): +class ValidUnitPeriodsWidget(BaseWidget): """ Plots the valid periods for units based on valid periods extension. @@ -19,8 +19,8 @@ class UnitValidPeriodsWidget(BaseWidget): The segment index. If None, uses first segment. unit_ids : list | None, default: None List of unit ids to plot. If None, all units are plotted. - show_only_units_with_good_periods : bool, default: True - If True, only units with good periods are shown. + show_only_units_with_valid_periods : bool, default: True + If True, only units with valid periods are shown. 
""" def __init__( @@ -28,15 +28,15 @@ def __init__( sorting_analyzer: SortingAnalyzer | None = None, segment_index: int | None = None, unit_ids: list | None = None, - show_only_units_with_good_periods: bool = True, + show_only_units_with_valid_periods: bool = True, backend: str | None = None, **backend_kwargs, ): sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) self.check_extensions(sorting_analyzer, "valid_unit_periods") - good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") - if good_periods_ext.params["method"] == "user_defined": - raise ValueError("UnitValidPeriodsWidget cannot be used with 'user_defined' good periods.") + valid_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + if valid_periods_ext.params["method"] == "user_defined": + raise ValueError("UnitValidPeriodsWidget cannot be used with 'user_defined' valid periods.") if segment_index is None: nseg = sorting_analyzer.get_num_segments() @@ -45,18 +45,18 @@ def __init__( else: segment_index = 0 - good_periods = good_periods_ext.get_data(outputs="numpy") - if show_only_units_with_good_periods: - good_unit_ids = sorting_analyzer.unit_ids[np.unique(good_periods["unit_index"])] + valid_periods = valid_periods_ext.get_data(outputs="numpy") + if show_only_units_with_valid_periods: + valid_unit_ids = sorting_analyzer.unit_ids[np.unique(valid_periods["unit_index"])] else: - good_unit_ids = sorting_analyzer.unit_ids + valid_unit_ids = sorting_analyzer.unit_ids if unit_ids is not None: - good_unit_ids = [u for u in unit_ids if u in good_unit_ids] + valid_unit_ids = [u for u in unit_ids if u in valid_unit_ids] data_plot = dict( sorting_analyzer=sorting_analyzer, segment_index=segment_index, - unit_ids=good_unit_ids, + unit_ids=valid_unit_ids, ) BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) @@ -70,6 +70,8 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if backend_kwargs["axes"] is not None: axes = backend_kwargs["axes"] + if axes.ndim == 1: + axes = axes[:, None] assert np.asarray(axes).shape == (3, num_units), "Axes shape does not match number of units" else: if "figsize" not in backend_kwargs: @@ -94,12 +96,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): amp_scalings_ext = sorting_analyzer.get_extension("amplitude_scalings") amp_scalings_by_unit = amp_scalings_ext.get_data(outputs="by_unit")[segment_index] - for unit_id in dp.unit_ids: + for ui, unit_id in enumerate(dp.unit_ids): fp = fp_per_unit[unit_id] fn = fn_per_unit[unit_id] unit_index = list(sorting_analyzer.unit_ids).index(unit_id) - axs = self.axes[:, unit_index] + axs = self.axes[:, ui] # for simplicity we don't use timestamps here spiketrain = ( sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index) / sampling_frequency @@ -126,7 +128,15 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): axs[2].set_ylabel("Amplitude Scaling") axs[0].set_title(f"Unit {unit_id}") - # TODO: fix update + axs[1].sharex(axs[0]) + axs[2].sharex(axs[0]) + + for ax in self.axes.flatten(): + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + self.figure.subplots_adjust(hspace=0.4) + def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets @@ -161,25 +171,30 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): ) # a first update - self.axes = None - self._update_ipywidget() + self._full_update_plot() - self.unit_selector.observe(self._update_ipywidget, 
names="value", type="change") + self.unit_selector.observe(self._update_plot, names=["value"], type="change") if backend_kwargs["display"]: display(self.widget) - def _update_ipywidget(self, change=None): - if self.axes is None: - self.figure.clear() - else: - for ax in self.axes.flatten(): - ax.clear() + def _full_update_plot(self, change=None): + self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + + def _update_plot(self, change=None): + print(f"_update_plot called! change={change}", flush=True) + + for ax in self.axes.flatten(): + ax.clear() data_plot = self.next_data_plot data_plot["unit_ids"] = self.unit_selector.value - backend_kwargs = dict(figure=self.figure, axes=self.axes, ax=None) + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) self.figure.canvas.draw() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 8281666dd8..e74ad38053 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -37,7 +37,7 @@ from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget -from .unit_valid_periods import UnitValidPeriodsWidget +from .unit_valid_periods import ValidUnitPeriodsWidget widget_list = [ AgreementMatrixWidget, @@ -49,7 +49,7 @@ CrossCorrelogramsWidget, DriftingTemplatesWidget, DriftRasterMapWidget, - UnitValidPeriodsWidget, + ValidUnitPeriodsWidget, ISIDistributionWidget, LocationsWidget, MotionWidget, @@ -130,7 +130,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_drifting_templates = DriftingTemplatesWidget plot_drift_raster_map = DriftRasterMapWidget -plot_valid_unit_periods = UnitValidPeriodsWidget +plot_valid_unit_periods = ValidUnitPeriodsWidget plot_isi_distribution = ISIDistributionWidget plot_locations = LocationsWidget plot_motion = MotionWidget From bb46f27ad9f719bfcd0db25fae1e55e4c2cbbe8d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 11:59:57 +0100 Subject: [PATCH 25/70] Update src/spikeinterface/core/sorting_tools.py Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/core/sorting_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 9a9a3670ef..4695f9b289 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -265,7 +265,7 @@ def select_sorting_periods_mask(sorting: BaseSorting, periods): return keep_mask -def select_sorting_periods(sorting: BaseSorting, periods): +def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: """ Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. 
From 121a0b19c3c435fa3a3f7bd64508eb440b371393 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 12:00:58 +0100 Subject: [PATCH 26/70] Apply suggestion from @chrishalcrow Co-authored-by: Chris Halcrow <57948917+chrishalcrow@users.noreply.github.com> --- src/spikeinterface/core/sorting_tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4695f9b289..75e25115ae 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -271,7 +271,6 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: Parameters ---------- - S periods : numpy.array of unit_period_dtype Periods (segment_index, start_sample_index, end_sample_index, unit_index) on which to restrict the sorting. From cbf3213a4c3769eae38f5203a42025531e80f0dd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:06:43 +0100 Subject: [PATCH 27/70] refactor presence ratio and drift metrics to use periods properly --- .../metrics/quality/misc_metrics.py | 166 ++++++++++-------- .../quality/tests/test_metrics_functions.py | 5 +- src/spikeinterface/metrics/quality/utils.py | 47 ----- .../metrics/spiketrain/metrics.py | 10 +- src/spikeinterface/metrics/utils.py | 121 +++++++++++++ 5 files changed, 220 insertions(+), 129 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/utils.py create mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4a7ef04554..e720477ee6 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -19,15 +19,14 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from spikeinterface.postprocessing import correlogram_for_one_segment -from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting +from spikeinterface.core import SortingAnalyzer, get_noise_levels from spikeinterface.core.template_tools import ( get_template_extremum_channel, get_template_extremum_amplitude, get_dense_templates_array, ) -from spikeinterface.core.node_pipeline import base_period_dtype - -from ..spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.utils import compute_bin_edges_per_unit, compute_total_durations_per_unit numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -74,12 +73,16 @@ def compute_presence_ratios( unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_length = sorting_analyzer.get_total_samples() - total_duration = sorting_analyzer.get_total_duration() + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_samples = np.sum(segment_samples) bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - num_bin_edges = total_length // bin_duration_samples + 1 - bin_edges = np.arange(num_bin_edges) * bin_duration_samples + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + 
bin_duration_s=bin_duration_s, + ) mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -90,7 +93,7 @@ def compute_presence_ratios( warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_length < bin_duration_samples: + if total_samples < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." ) @@ -98,9 +101,15 @@ def compute_presence_ratios( else: for unit_id in unit_ids: spike_train = [] + bin_edges = bin_edges_per_unit[unit_id] + if len(bin_edges) < 2: + presence_ratios[unit_id] = 0.0 + continue + total_duration = total_durations[unit_id] + for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(seg_lengths[:segment_index]) + st = st + np.sum(segment_samples[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -109,7 +118,6 @@ def compute_presence_ratios( presence_ratios[unit_id] = presence_ratio( spike_train, - total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -250,7 +258,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_duration_s = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -271,7 +279,8 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) + total_duration = total_durations[unit_id] + ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -449,7 +458,7 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) @@ -477,6 +486,7 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue + duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -582,6 +592,7 @@ class Synchrony(BaseMetric): } +# TODO: refactor for periods def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution @@ -659,6 +670,7 @@ class FiringRange(BaseMetric): } +# TODO: refactor for periods def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, @@ -710,13 +722,14 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. 
It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_analyzer.sorting - total_duration = sorting_analyzer.get_total_duration() - spikes = sorting.to_spike_vector() - num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + spikes = sorting.to_spike_vector() + num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice @@ -729,6 +742,7 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -1078,34 +1092,30 @@ def compute_drift_metrics( check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") spike_locations = spike_locations_ext.get_data(periods=periods) - spikes = sorting.to_spike_vector() - spike_locations_by_unit = {} - for unit_id in unit_ids: - unit_index = sorting.id_to_index(unit_id) - # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code - spike_mask = spikes["unit_index"] == unit_index - spike_locations_by_unit[unit_id] = spike_locations[spike_mask] + spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = sorting_analyzer.get_total_duration() - if total_duration < min_num_bins * interval_s: - warnings.warn( - "The recording is too short given the specified 'interval_s' and " - "'min_num_bins'. Drift metrics will be set to NaN" - ) - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - if return_positions: - return res(empty_dict, empty_dict, empty_dict), np.nan - else: - return res(empty_dict, empty_dict, empty_dict) + # total_duration = sorting_analyzer.get_total_duration() + # if total_duration < min_num_bins * interval_s: + # warnings.warn( + # "The recording is too short given the specified 'interval_s' and " + # "'min_num_bins'. 
Drift metrics will be set to NaN" + # ) + # empty_dict = {unit_id: np.nan for unit_id in unit_ids} + # if return_positions: + # return res(empty_dict, empty_dict, empty_dict), np.nan + # else: + # return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1113,45 +1123,50 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = {} + for unit_id in unit_ids: + reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments median_position_segments = None - for segment_index in range(sorting_analyzer.get_num_segments()): - seg_length = sorting_analyzer.get_num_samples(segment_index) - num_bin_edges = seg_length // interval_samples + 1 - bins = np.arange(num_bin_edges) * interval_samples - spike_vector = sorting.to_spike_vector() - - # retrieve spikes in segment - i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) - spikes_in_segment = spike_vector[i0:i1] - spike_locations_in_segment = spike_locations[i0:i1] - - # compute median positions (if less than min_spikes_per_interval, median position is 0) - median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) - for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) - spikes_in_bin = spikes_in_segment[i0:i1] - spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] - - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_ind - if np.sum(mask) >= min_spikes_per_interval: - median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) - if median_position_segments is None: - median_position_segments = median_positions - else: - median_position_segments = np.hstack((median_position_segments, median_positions)) + spike_vector = sorting.to_spike_vector() + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=interval_s, + ) + + median_positions_per_unit = {} + for i, unit in enumerate(unit_ids): + bins = bin_edges_for_units[unit] + num_bins = len(bins) - 1 + if num_bins < min_num_bins: + warnings.warn( + f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"'min_num_bins'. 
Drift metrics will be set to NaN" + ) + drift_ptps[unit] = np.nan + drift_stds[unit] = np.nan + drift_mads[unit] = np.nan + continue + + bin_spike_indices = np.searchsorted(spike_vector["sample_index"], bins) + median_positions = np.nan * np.zeros(num_bins) + for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): + spikes_in_bin = spike_vector[i0:i1] + spike_locations_in_bin = spike_locations[i0:i1][direction] - # finally, compute deviations and drifts - position_diffs = median_position_segments - reference_positions[:, None] - for i, unit_id in enumerate(unit_ids): - position_diff = position_diffs[i] + unit_index = sorting_analyzer.sorting.id_to_index(unit) + mask = spikes_in_bin["unit_index"] == unit_index + if np.sum(mask) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) + else: + median_positions[bin_index] = np.nan + median_positions_per_unit[unit] = median_positions + + # now compute deviations and drifts for this unit + position_diff = median_positions - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1169,8 +1184,9 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions + outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1385,7 +1401,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1393,8 +1409,6 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None ---------- spike_train : np.ndarray Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). 
num_bin_edges : int, optional diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c0dd6c6033..57516d6bc3 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,11 +12,8 @@ synthesize_random_firings, ) -from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) from spikeinterface.metrics.quality import ( get_quality_metric_list, diff --git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py deleted file mode 100644 index 844a7da7f5..0000000000 --- a/src/spikeinterface/metrics/quality/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. - """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index ba66d0671c..0ddb5fabe7 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -2,7 +2,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric -def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): +def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): """ Compute the number of spike across segments. @@ -12,6 +12,8 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. 
+ periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -20,6 +22,7 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -43,7 +46,7 @@ class NumSpikes(BaseMetric): metric_columns = {"num_spikes": int} -def compute_firing_rates(sorting_analyzer, unit_ids=None): +def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): """ Compute the firing rate across segments. @@ -53,6 +56,8 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -61,6 +66,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids total_duration = sorting_analyzer.get_total_duration() diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py new file mode 100644 index 0000000000..beb9b505ff --- /dev/null +++ b/src/spikeinterface/metrics/utils.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import numpy as np + + +def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): + """ + Compute bin edges for units, optionally taking into account periods. + + Parameters + ---------- + sorting : Sorting + Sorting object containing unit information. + segment_samples : list or array-like + Number of samples in each segment. + bin_duration_s : float, default: 1 + Duration of each bin in seconds + periods : array of unit_period_dtype, default: None + Periods to consider for each unit + """ + bin_edges_for_units = {} + num_segments = len(segment_samples) + bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency) + + if periods is not None: + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + bin_edges = [] + for seg_index in range(num_segments): + seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] + if len(seg_periods) == 0: + continue + seg_start = np.sum(segment_samples[:seg_index]) + for period in seg_periods: + start_sample = seg_start + period["start_sample_index"] + end_sample = seg_start + period["end_sample_index"] + bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) + bin_edges_for_units[unit_id] = np.array(bin_edges) + else: + total_length = np.sum(segment_samples) + for unit_id in sorting.unit_ids: + bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples + return bin_edges_for_units + + +def compute_total_durations_per_unit(sorting_analyzer, periods=None): + """ + Compute total duration for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total duration for each unit. 
+ """ + if periods is not None: + total_durations = {} + sorting = sorting_analyzer.sorting + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + total_duration = 0 + for period in periods_unit: + total_duration += period["end_sample_index"] - period["start_sample_index"] + total_durations[unit_id] = total_duration / sorting.sampling_frequency + else: + total_durations = { + unit_id: sorting_analyzer.get_total_duration_per_unit() for unit_id in sorting_analyzer.unit_ids + } + return total_durations + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. + """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels From 4409aa5fc1dcd1d5fb93d3a075a598df0c18113f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:16:15 +0100 Subject: [PATCH 28/70] Fix rp_violations --- .../metrics/quality/misc_metrics.py | 11 +++-- src/spikeinterface/metrics/utils.py | 41 ++++++++++++++----- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index e720477ee6..74aca85dce 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -26,7 +26,11 @@ get_dense_templates_array, ) from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate -from spikeinterface.metrics.utils import compute_bin_edges_per_unit, compute_total_durations_per_unit +from spikeinterface.metrics.utils import ( + compute_bin_edges_per_unit, + compute_total_durations_per_unit, + compute_total_samples_per_unit, +) numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -366,7 +370,7 @@ def compute_refrac_period_violations( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) + num_spikes = sorting.count_num_spikes_per_unit() t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -377,7 +381,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) 
_compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_analyzer.get_total_samples() + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) nb_violations = {} rp_contamination = {} @@ -388,6 +392,7 @@ def compute_refrac_period_violations( nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] + T = total_samples[unit_id] if N == 0: rp_contamination[unit_id] = np.nan else: diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index beb9b505ff..446f9ce471 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -44,9 +44,9 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per return bin_edges_for_units -def compute_total_durations_per_unit(sorting_analyzer, periods=None): +def get_total_samples_per_unit(sorting_analyzer, periods=None): """ - Compute total duration for each unit, optionally taking into account periods. + Get total number of samples for each unit, optionally taking into account periods. Parameters ---------- @@ -58,22 +58,43 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): Returns ------- dict - Total duration for each unit. + Total number of samples for each unit. """ if periods is not None: - total_durations = {} + total_samples = {} sorting = sorting_analyzer.sorting for unit_id in sorting.unit_ids: unit_index = sorting.id_to_index(unit_id) periods_unit = periods[periods["unit_index"] == unit_index] - total_duration = 0 + num_samples_in_period = 0 for period in periods_unit: - total_duration += period["end_sample_index"] - period["start_sample_index"] - total_durations[unit_id] = total_duration / sorting.sampling_frequency + num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] + total_samples[unit_id] = num_samples_in_period else: - total_durations = { - unit_id: sorting_analyzer.get_total_duration_per_unit() for unit_id in sorting_analyzer.unit_ids - } + total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + return total_samples + + +def compute_total_durations_per_unit(sorting_analyzer, periods=None): + """ + Compute total duration for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total duration for each unit. 
+ """ + total_samples = get_total_samples_per_unit(sorting_analyzer, periods=periods) + total_durations = { + unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() + } return total_durations From 71f8668c5f8eb25d50c53a614f81552947ab1ede Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 13:27:25 +0100 Subject: [PATCH 29/70] implement firing range and fix drift --- .../metrics/quality/misc_metrics.py | 81 ++++++++----------- 1 file changed, 33 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 74aca85dce..b30ab068fe 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -597,7 +597,6 @@ class Synchrony(BaseMetric): } -# TODO: refactor for periods def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution @@ -630,6 +629,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) + segment_samples = [ + sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) + ] if unit_ids is None: unit_ids = sorting.unit_ids @@ -645,15 +647,25 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(sorting_analyzer.get_num_segments()): - num_samples = sorting_analyzer.get_num_samples(segment_index) - edges = np.arange(0, num_samples + 1, bin_size_samples) + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_size_s, + ) + for unit_id in unit_ids: + bin_edges = bin_edges_per_unit[unit_id] - for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_counts, _ = np.histogram(spike_times, bins=edges) - firing_rates = spike_counts / bin_size_s - firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + # we can concatenate spike trains across segments adding the cumulative number of samples + # as offset, since bin edges are already cumulative + for segment_index in range(sorting_analyzer.get_num_segments()): + st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + st = st + np.sum(segment_samples[:segment_index]) + spike_train.append(st) + spike_train = np.concatenate(spike_train) + + spike_counts, _ = np.histogram(spike_train, bins=bin_edges) + firing_rate_histograms[unit_id] = spike_counts / bin_size_s # finally we compute the percentiles firing_ranges = {} @@ -731,9 +743,9 @@ def compute_amplitude_cv_metrics( unit_ids = sorting.unit_ids sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) spikes = sorting.to_spike_vector() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) num_spikes = 
sorting.count_num_spikes_per_unit(outputs="dict") amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) @@ -1106,21 +1118,9 @@ def compute_drift_metrics( spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] - interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - # total_duration = sorting_analyzer.get_total_duration() - # if total_duration < min_num_bins * interval_s: - # warnings.warn( - # "The recording is too short given the specified 'interval_s' and " - # "'min_num_bins'. Drift metrics will be set to NaN" - # ) - # empty_dict = {unit_id: np.nan for unit_id in unit_ids} - # if return_positions: - # return res(empty_dict, empty_dict, empty_dict), np.nan - # else: - # return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1133,8 +1133,14 @@ def compute_drift_metrics( reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - median_position_segments = None spike_vector = sorting.to_spike_vector() + spike_sample_indices = spike_vector["sample_index"] + # we need to add the cumulative sum of segment samples to have global sample indices + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) + for segment_index in range(sorting_analyzer.get_num_segments()): + seg_mask = spike_vector["segment_index"] == segment_index + spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] + bin_edges_for_units = compute_bin_edges_per_unit( sorting, segment_samples=segment_samples, @@ -1143,7 +1149,7 @@ def compute_drift_metrics( ) median_positions_per_unit = {} - for i, unit in enumerate(unit_ids): + for unit in unit_ids: bins = bin_edges_for_units[unit] num_bins = len(bins) - 1 if num_bins < min_num_bins: @@ -1156,7 +1162,9 @@ def compute_drift_metrics( drift_mads[unit] = np.nan continue - bin_spike_indices = np.searchsorted(spike_vector["sample_index"], bins) + # bin_edges are global across segments, so we have to use spike_sample_indices, + # since we offseted them to be global + bin_spike_indices = np.searchsorted(spike_sample_indices, bins) median_positions = np.nan * np.zeros(num_bins) for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): spikes_in_bin = spike_vector[i0:i1] @@ -1783,29 +1791,6 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = 
np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units - - if HAVE_NUMBA: import numba From 1ea0d68074a0925d08811d31a8951728346e4c03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 14:24:19 +0100 Subject: [PATCH 30/70] fix naming issue --- src/spikeinterface/metrics/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 446f9ce471..16058e521f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -44,7 +44,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per return bin_edges_for_units -def get_total_samples_per_unit(sorting_analyzer, periods=None): +def compute_total_samples_per_unit(sorting_analyzer, periods=None): """ Get total number of samples for each unit, optionally taking into account periods. @@ -91,7 +91,7 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): dict Total duration for each unit. """ - total_samples = get_total_samples_per_unit(sorting_analyzer, periods=periods) + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) total_durations = { unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() } From a86c2d36c6ddd823e65f999a1f76f9a8607f0938 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 14:31:11 +0100 Subject: [PATCH 31/70] remove solved todos --- src/spikeinterface/metrics/quality/misc_metrics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index b30ab068fe..004f0ee56c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -687,7 +687,6 @@ class FiringRange(BaseMetric): } -# TODO: refactor for periods def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, @@ -1223,7 +1222,6 @@ class Drift(BaseMetric): depend_on = ["spike_locations"] -# TODO def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, From 84da1a2fced1c1f1eba6c2b90ff5ae7eae482f5c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 14:45:19 +0100 Subject: [PATCH 32/70] wip: test user defined --- .../tests/test_valid_unit_periods.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index bf80bc2d30..7b62706af3 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -1,5 +1,6 @@ import pytest - +import numpy as np +from spikeinterface.core.node_pipeline import unit_period_dtype from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeValidUnitPeriods @@ -17,3 +18,27 @@ def test_extension(self, params): self.run_extension_tests( ComputeValidUnitPeriods, params, extra_dependencies=["templates", "amplitude_scalings"] ) + + def test_user_defined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = 
self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) From c539f582994a7a5d02af534dd43b3e2b759f911b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:37:37 +0100 Subject: [PATCH 33/70] wip: tests --- .../tests/common_extension_tests.py | 6 +++- .../tests/test_valid_unit_periods.py | 13 +++++-- .../postprocessing/valid_unit_periods.py | 35 ++++++++++++------- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 46bf8e2235..fe71d0eb5a 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -150,7 +150,11 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) for ext_data_name, ext_data_loaded in ext_loaded.data.items(): if isinstance(ext_data_loaded, np.ndarray): - assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + if len(ext_data_loaded) > 0 and isinstance(ext_data_loaded[0], dict): + for i in range(len(ext_data_loaded)): + assert np.array_equal(np.array(ext.data[ext_data_name][i]), np.array(ext_data_loaded[i])) + else: + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) elif isinstance(ext_data_loaded, pd.DataFrame): # skip nan values for col in ext_data_loaded.columns: diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index 7b62706af3..1171f62e94 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -10,8 +10,8 @@ class TestComputeValidUnitPeriods(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(period_mode="absolute"), - dict(period_mode="relative"), + dict(period_mode="absolute", period_duration_s_absolute=1.1, minimum_valid_period_duration=1.0), + dict(period_mode="relative", period_target_num_spikes=30, minimum_valid_period_duration=1.0), ], ) def test_extension(self, params): @@ -42,3 +42,12 @@ def test_user_defined_periods(self): sorting_analyzer = self._prepare_sorting_analyzer( "memory", sparse=False, extension_class=ComputeValidUnitPeriods ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + np.testing.assert_array_equal(ext_periods, periods) diff 
--git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 4f0ea701af..808e252c9c 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -552,6 +552,9 @@ def compute_subperiods( for segment_index in range(sorting.get_num_segments()): n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + print( + f"Num samples segment {segment_index}: {n_samples} - period size: {period_size_samples} - margin size: {margin_size_samples}" + ) # We round the number of subperiods to ensure coverage of the entire recording # the end of the last period is then clipped or extended to the end of the recording n_subperiods = round(n_samples / period_size_samples) @@ -566,19 +569,25 @@ def compute_subperiods( ] ) # remove periods whose end is above the expected number of samples - starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] - # extend last period to the end of the recording - starts_ends[-1][1] = n_samples - - periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) - for i, (start, end) in enumerate(starts_ends): - subperiod = np.zeros((1,), dtype=unit_period_dtype) - subperiod["segment_index"] = segment_index - subperiod["start_sample_index"] = start - subperiod["end_sample_index"] = end - subperiod["unit_index"] = unit_index - periods_for_unit[i] = subperiod - all_subperiods.append(periods_for_unit) + if len(starts_ends) > 0: + beyond_samples_mask = starts_ends[:, 1] > n_samples + if sum(beyond_samples_mask) == len(starts_ends): + # all periods end beyond n_samples: keep only first period + starts_ends = starts_ends[:1].copy() + else: + starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] + # set last period to the end of the recording + starts_ends[-1][1] = n_samples + + periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) + for i, (start, end) in enumerate(starts_ends): + subperiod = np.zeros((1,), dtype=unit_period_dtype) + subperiod["segment_index"] = segment_index + subperiod["start_sample_index"] = start + subperiod["end_sample_index"] = end + subperiod["unit_index"] = unit_index + periods_for_unit[i] = subperiod + all_subperiods.append(periods_for_unit) return np.concatenate(all_subperiods) From 3f93f97618930203aadb724dda7a50a53e57b4b6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:42:16 +0100 Subject: [PATCH 34/70] Implement select_segment_periods in core --- .../core/analyzer_extension_core.py | 12 - .../metrics/quality/misc_metrics.py | 301 +++++++----------- .../quality/tests/test_metrics_functions.py | 5 +- src/spikeinterface/metrics/quality/utils.py | 47 +++ .../metrics/spiketrain/metrics.py | 10 +- 5 files changed, 175 insertions(+), 200 deletions(-) create mode 100644 src/spikeinterface/metrics/quality/utils.py diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 804418a2ff..7ac037a8cd 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1333,18 +1333,6 @@ class BaseSpikeVectorExtension(AnalyzerExtension): def __init__(self, sorting_analyzer): super().__init__(sorting_analyzer) - self._segment_slices = None - - @property - def segment_slices(self): - if self._segment_slices is None: - segment_slices = [] - spikes = self.sorting_analyzer.sorting.to_spike_vector() - for 
segment_index in range(self.sorting_analyzer.get_num_segments()): - i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) - segment_slices.append(slice(i0, i1)) - self._segment_slices = segment_slices - return self._segment_slices def _set_params(self, **kwargs): params = kwargs.copy() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 004f0ee56c..c6b07da52e 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -25,12 +25,8 @@ get_template_extremum_amplitude, get_dense_templates_array, ) -from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate -from spikeinterface.metrics.utils import ( - compute_bin_edges_per_unit, - compute_total_durations_per_unit, - compute_total_samples_per_unit, -) + +from ..spiketrain.metrics import NumSpikes, FiringRate numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -39,9 +35,7 @@ HAVE_NUMBA = False -def compute_presence_ratios( - sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None -): +def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -57,9 +51,6 @@ def compute_presence_ratios( mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -72,21 +63,16 @@ def compute_presence_ratios( To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) - total_samples = np.sum(segment_samples) + seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_length = sorting_analyzer.get_total_samples() + total_duration = sorting_analyzer.get_total_duration() bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - bin_edges_per_unit = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=bin_duration_s, - ) + num_bin_edges = total_length // bin_duration_samples + 1 + bin_edges = np.arange(num_bin_edges) * bin_duration_samples mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -97,7 +83,7 @@ def compute_presence_ratios( warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_samples < bin_duration_samples: + if total_length < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." 
) @@ -105,15 +91,9 @@ def compute_presence_ratios( else: for unit_id in unit_ids: spike_train = [] - bin_edges = bin_edges_per_unit[unit_id] - if len(bin_edges) < 2: - presence_ratios[unit_id] = 0.0 - continue - total_duration = total_durations[unit_id] - for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(segment_samples[:segment_index]) + st = st + np.sum(seg_lengths[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -122,6 +102,7 @@ def compute_presence_ratios( presence_ratios[unit_id] = presence_ratio( spike_train, + total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -201,7 +182,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -223,9 +204,6 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -257,12 +235,11 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_duration_s = sorting_analyzer.get_total_duration() fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -283,8 +260,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - total_duration = total_durations[unit_id] - ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) + ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -304,7 +280,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -324,9 +300,6 @@ def compute_refrac_period_violations( censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -359,8 +332,6 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - fs = sorting_analyzer.sampling_frequency num_units = len(sorting_analyzer.unit_ids) num_segments = sorting_analyzer.get_num_segments() @@ -370,7 +341,7 @@ def compute_refrac_period_violations( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids - num_spikes = sorting.count_num_spikes_per_unit() + num_spikes = compute_num_spikes(sorting_analyzer) t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -381,7 +352,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) + T = sorting_analyzer.get_total_samples() nb_violations = {} rp_contamination = {} @@ -392,7 +363,6 @@ def compute_refrac_period_violations( nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] - T = total_samples[unit_id] if N == 0: rp_contamination[unit_id] = np.nan else: @@ -422,7 +392,6 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -448,9 +417,6 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -463,10 +429,8 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + duration = sorting_analyzer.get_total_duration() sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -491,7 +455,6 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue - duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -523,7 +486,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -541,9 +504,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
References ---------- @@ -560,7 +520,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -597,7 +556,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -612,9 +571,6 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -628,11 +584,6 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - segment_samples = [ - sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) - ] - if unit_ids is None: unit_ids = sorting.unit_ids @@ -647,25 +598,15 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - bin_edges_per_unit = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=bin_size_s, - ) - for unit_id in unit_ids: - bin_edges = bin_edges_per_unit[unit_id] - - # we can concatenate spike trains across segments adding the cumulative number of samples - # as offset, since bin edges are already cumulative - for segment_index in range(sorting_analyzer.get_num_segments()): - st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(segment_samples[:segment_index]) - spike_train.append(st) - spike_train = np.concatenate(spike_train) + for segment_index in range(sorting_analyzer.get_num_segments()): + num_samples = sorting_analyzer.get_num_samples(segment_index) + edges = np.arange(0, num_samples + 1, bin_size_samples) - spike_counts, _ = np.histogram(spike_train, bins=bin_edges) - firing_rate_histograms[unit_id] = spike_counts / bin_size_s + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + spike_counts, _ = np.histogram(spike_times, bins=edges) + firing_rates = spike_counts / bin_size_s + firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) # finally we compute the percentiles firing_ranges = {} @@ -694,7 +635,6 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", - periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. 
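# A rough sketch of the amplitude-CV idea described in the docstring above: spike
# amplitudes are grouped into temporal bins sized to hold roughly
# `average_num_spikes_per_bin` spikes each, the coefficient of variation (std / mean)
# is computed per bin, and the median and percentile range of those CVs give the two
# metrics. The synthetic data, bin sizing, and handling of near-empty bins below are
# illustrative assumptions, not the library implementation.
import numpy as np

rng = np.random.default_rng(42)
sampling_frequency = 30_000.0
duration_s = 300.0
spike_samples = np.sort(rng.integers(0, int(duration_s * sampling_frequency), size=3_000))
amplitudes = rng.normal(loc=-80.0, scale=8.0, size=spike_samples.size)

average_num_spikes_per_bin = 50
percentiles = (5, 95)
firing_rate = spike_samples.size / duration_s
bin_size_samples = int((average_num_spikes_per_bin / firing_rate) * sampling_frequency)

bin_edges = np.arange(0, int(duration_s * sampling_frequency) + bin_size_samples, bin_size_samples)
cvs = []
for start, stop in zip(bin_edges[:-1], bin_edges[1:]):
    amps_in_bin = amplitudes[(spike_samples >= start) & (spike_samples < stop)]
    if amps_in_bin.size > 1 and np.mean(amps_in_bin) != 0:
        cvs.append(np.std(amps_in_bin) / np.abs(np.mean(amps_in_bin)))

amplitude_cv_median = np.median(cvs)
amplitude_cv_range = np.percentile(cvs, percentiles[1]) - np.percentile(cvs, percentiles[0])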
@@ -718,8 +658,6 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -738,15 +676,14 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" - if unit_ids is None: - unit_ids = sorting.unit_ids sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) - + total_duration = sorting_analyzer.get_total_duration() spikes = sorting.to_spike_vector() - total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") - amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) + if unit_ids is None: + unit_ids = sorting.unit_ids + + amps = sorting_analyzer.get_extension(amplitude_extension).get_data() # precompute segment slice segment_slices = [] @@ -758,7 +695,6 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: - total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -816,7 +752,6 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, - periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -835,9 +770,6 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -873,7 +805,7 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -905,7 +837,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None): """ Compute median of the amplitude distributions (in absolute value). @@ -915,9 +847,6 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -936,7 +865,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -953,9 +882,7 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs( - sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None -): +def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -979,9 +906,6 @@ def compute_noise_cutoffs( Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -1010,7 +934,7 @@ def compute_noise_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -1048,7 +972,6 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, - periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1083,9 +1006,6 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
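# A condensed sketch in the spirit of the drift metric described above: spike depths of
# one unit are binned into fixed-length time intervals, the median position per interval
# is compared to the unit's overall median position, and the peak-to-peak, standard
# deviation, and median absolute deviation of those differences give drift_ptp,
# drift_std, and drift_mad. The synthetic depths, interval length, and the minimum
# spike count per interval are illustrative values; details of the real implementation
# (per-segment handling, NaN rules) differ.
import numpy as np

rng = np.random.default_rng(0)
sampling_frequency = 30_000.0
interval_s = 60.0
spike_samples = np.sort(rng.integers(0, int(600 * sampling_frequency), size=5_000))
spike_depths = (
    200.0
    + 5.0 * np.sin(spike_samples / spike_samples.max() * np.pi)  # slow simulated drift
    + rng.normal(0.0, 2.0, spike_samples.size)
)

interval_samples = int(interval_s * sampling_frequency)
bin_edges = np.arange(0, spike_samples.max() + interval_samples, interval_samples)
reference_position = np.median(spike_depths)

median_positions = []
for start, stop in zip(bin_edges[:-1], bin_edges[1:]):
    depths_in_bin = spike_depths[(spike_samples >= start) & (spike_samples < stop)]
    # intervals with too few spikes are marked invalid (NaN)
    median_positions.append(np.median(depths_in_bin) if depths_in_bin.size >= 100 else np.nan)
median_positions = np.asarray(median_positions)

position_diff = median_positions - reference_position
valid = ~np.isnan(position_diff)
drift_ptp = np.ptp(position_diff[valid])
drift_std = np.std(position_diff[valid])
drift_mad = np.median(np.abs(position_diff[valid] - np.median(position_diff[valid])))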
@@ -1108,18 +1028,35 @@ def compute_drift_metrics( check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data(periods=periods) - spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) + spike_locations = spike_locations_ext.get_data() + # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") + spikes = sorting.to_spike_vector() + spike_locations_by_unit = {} + for unit_id in unit_ids: + unit_index = sorting.id_to_index(unit_id) + # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code + spike_mask = spikes["unit_index"] == unit_index + spike_locations_by_unit[unit_id] = spike_locations[spike_mask] - segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] + interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) + total_duration = sorting_analyzer.get_total_duration() + if total_duration < min_num_bins * interval_s: + warnings.warn( + "The recording is too short given the specified 'interval_s' and " + "'min_num_bins'. Drift metrics will be set to NaN" + ) + empty_dict = {unit_id: np.nan for unit_id in unit_ids} + if return_positions: + return res(empty_dict, empty_dict, empty_dict), np.nan + else: + return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1127,58 +1064,45 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = {} - for unit_id in unit_ids: - reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = np.zeros(len(unit_ids)) + for i, unit_id in enumerate(unit_ids): + unit_ind = sorting.id_to_index(unit_id) + reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - spike_vector = sorting.to_spike_vector() - spike_sample_indices = spike_vector["sample_index"] - # we need to add the cumulative sum of segment samples to have global sample indices - cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) + median_position_segments = None for segment_index in range(sorting_analyzer.get_num_segments()): - seg_mask = spike_vector["segment_index"] == segment_index - spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] - - bin_edges_for_units = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=interval_s, - ) - - median_positions_per_unit = {} - for unit in unit_ids: - bins = bin_edges_for_units[unit] - num_bins = len(bins) - 1 - if num_bins < min_num_bins: - warnings.warn( - f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " - f"'min_num_bins'. 
Drift metrics will be set to NaN" - ) - drift_ptps[unit] = np.nan - drift_stds[unit] = np.nan - drift_mads[unit] = np.nan - continue - - # bin_edges are global across segments, so we have to use spike_sample_indices, - # since we offseted them to be global - bin_spike_indices = np.searchsorted(spike_sample_indices, bins) - median_positions = np.nan * np.zeros(num_bins) - for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): - spikes_in_bin = spike_vector[i0:i1] - spike_locations_in_bin = spike_locations[i0:i1][direction] - - unit_index = sorting_analyzer.sorting.id_to_index(unit) - mask = spikes_in_bin["unit_index"] == unit_index - if np.sum(mask) >= min_spikes_per_interval: - median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) - else: - median_positions[bin_index] = np.nan - median_positions_per_unit[unit] = median_positions + seg_length = sorting_analyzer.get_num_samples(segment_index) + num_bin_edges = seg_length // interval_samples + 1 + bins = np.arange(num_bin_edges) * interval_samples + spike_vector = sorting.to_spike_vector() + + # retrieve spikes in segment + i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) + spikes_in_segment = spike_vector[i0:i1] + spike_locations_in_segment = spike_locations[i0:i1] + + # compute median positions (if less than min_spikes_per_interval, median position is 0) + median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) + for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): + i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) + spikes_in_bin = spikes_in_segment[i0:i1] + spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] + + for i, unit_id in enumerate(unit_ids): + unit_ind = sorting.id_to_index(unit_id) + mask = spikes_in_bin["unit_index"] == unit_ind + if np.sum(mask) >= min_spikes_per_interval: + median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) + if median_position_segments is None: + median_position_segments = median_positions + else: + median_position_segments = np.hstack((median_position_segments, median_positions)) - # now compute deviations and drifts for this unit - position_diff = median_positions - reference_positions[unit_id] + # finally, compute deviations and drifts + position_diffs = median_position_segments - reference_positions[:, None] + for i, unit_id in enumerate(unit_ids): + position_diff = position_diffs[i] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1196,9 +1120,8 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift - if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit + outs = res(drift_ptps, drift_stds, drift_mads), median_positions else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1228,7 +1151,6 @@ def compute_sd_ratio( censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, - periods=None, **kwargs, ): """ @@ -1251,9 +1173,6 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. @@ -1270,7 +1189,6 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1283,7 +1201,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() if not HAVE_NUMBA: warnings.warn( @@ -1412,7 +1330,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1420,6 +1338,8 @@ def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes ---------- spike_train : np.ndarray Spike times for this unit, in samples. + total_length : int + Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). num_bin_edges : int, optional @@ -1789,6 +1709,29 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts +def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): + # used by compute_amplitude_cutoffs and compute_amplitude_medians + + if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: + return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) + + elif sorting_analyzer.has_extension("waveforms"): + amplitudes_by_units = {} + waveforms_ext = sorting_analyzer.get_extension("waveforms") + before = waveforms_ext.nbefore + extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) + for unit_id in unit_ids: + waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) + chan_id = extremum_channels_ids[unit_id] + if sorting_analyzer.is_sparse(): + chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] + else: + chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] + amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] + + return amplitudes_by_units + + if HAVE_NUMBA: import numba diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 57516d6bc3..c0dd6c6033 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,8 +12,11 @@ synthesize_random_firings, ) -from spikeinterface.metrics.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +# from spikeinterface.metrics.quality_metric_list import ( +# _misc_metric_name_to_func, +# ) from spikeinterface.metrics.quality import ( get_quality_metric_list, diff 
--git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py new file mode 100644 index 0000000000..844a7da7f5 --- /dev/null +++ b/src/spikeinterface/metrics/quality/utils.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import numpy as np + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. + all_labels : numpy.array + Labels for each point. + """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 0ddb5fabe7..ba66d0671c 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -2,7 +2,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric -def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): +def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): """ Compute the number of spike across segments. @@ -12,8 +12,6 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the number of spikes. If None, all units are used. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -22,7 +20,6 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None, **kwargs): """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids num_segs = sorting.get_num_segments() @@ -46,7 +43,7 @@ class NumSpikes(BaseMetric): metric_columns = {"num_spikes": int} -def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): +def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -56,8 +53,6 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the firing rate. If None, all units are used. 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -66,7 +61,6 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): """ sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids total_duration = sorting_analyzer.get_total_duration() From cd854567b45d0eddc16275431ef5c6044188c0ec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:43:53 +0100 Subject: [PATCH 35/70] remove utils --- src/spikeinterface/metrics/utils.py | 142 ---------------------------- 1 file changed, 142 deletions(-) delete mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py deleted file mode 100644 index 16058e521f..0000000000 --- a/src/spikeinterface/metrics/utils.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): - """ - Compute bin edges for units, optionally taking into account periods. - - Parameters - ---------- - sorting : Sorting - Sorting object containing unit information. - segment_samples : list or array-like - Number of samples in each segment. - bin_duration_s : float, default: 1 - Duration of each bin in seconds - periods : array of unit_period_dtype, default: None - Periods to consider for each unit - """ - bin_edges_for_units = {} - num_segments = len(segment_samples) - bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency) - - if periods is not None: - for unit_id in sorting.unit_ids: - unit_index = sorting.id_to_index(unit_id) - periods_unit = periods[periods["unit_index"] == unit_index] - bin_edges = [] - for seg_index in range(num_segments): - seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] - if len(seg_periods) == 0: - continue - seg_start = np.sum(segment_samples[:seg_index]) - for period in seg_periods: - start_sample = seg_start + period["start_sample_index"] - end_sample = seg_start + period["end_sample_index"] - bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.array(bin_edges) - else: - total_length = np.sum(segment_samples) - for unit_id in sorting.unit_ids: - bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples - return bin_edges_for_units - - -def compute_total_samples_per_unit(sorting_analyzer, periods=None): - """ - Get total number of samples for each unit, optionally taking into account periods. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer object. - periods : array of unit_period_dtype, default: None - Periods to consider for each unit. - - Returns - ------- - dict - Total number of samples for each unit. 
- """ - if periods is not None: - total_samples = {} - sorting = sorting_analyzer.sorting - for unit_id in sorting.unit_ids: - unit_index = sorting.id_to_index(unit_id) - periods_unit = periods[periods["unit_index"] == unit_index] - num_samples_in_period = 0 - for period in periods_unit: - num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] - total_samples[unit_id] = num_samples_in_period - else: - total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} - return total_samples - - -def compute_total_durations_per_unit(sorting_analyzer, periods=None): - """ - Compute total duration for each unit, optionally taking into account periods. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer object. - periods : array of unit_period_dtype, default: None - Periods to consider for each unit. - - Returns - ------- - dict - Total duration for each unit. - """ - total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) - total_durations = { - unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() - } - return total_durations - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. 
- """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels From 7a42fe32354b1ffd024a32ec327d2353de2196a0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 13 Jan 2026 16:46:14 +0100 Subject: [PATCH 36/70] rebase on #4316 --- src/spikeinterface/metrics/quality/utils.py | 47 ------- src/spikeinterface/metrics/utils.py | 142 ++++++++++++++++++++ 2 files changed, 142 insertions(+), 47 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/utils.py create mode 100644 src/spikeinterface/metrics/utils.py diff --git a/src/spikeinterface/metrics/quality/utils.py b/src/spikeinterface/metrics/quality/utils.py deleted file mode 100644 index 844a7da7f5..0000000000 --- a/src/spikeinterface/metrics/quality/utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -import numpy as np - - -def create_ground_truth_pc_distributions(center_locations, total_points): - """ - Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics - Values are created for only one channel and vary along one dimension. - - Parameters - ---------- - center_locations : array-like (units, ) or (channels, units) - Mean of the multivariate gaussian at each channel for each unit. - total_points : array-like - Number of points in each unit distribution. - - Returns - ------- - all_pcs : numpy.ndarray - PC scores for each point. - all_labels : numpy.array - Labels for each point. - """ - from scipy.stats import multivariate_normal - - np.random.seed(0) - - if len(np.array(center_locations).shape) == 1: - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations, total_points) - ] - all_pcs = np.concatenate(distributions, axis=0) - - else: - all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) - for channel in range(center_locations.shape[0]): - distributions = [ - multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) - for center, size in zip(center_locations[channel], total_points) - ] - all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) - - all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) - - return all_pcs, all_labels diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py new file mode 100644 index 0000000000..16058e521f --- /dev/null +++ b/src/spikeinterface/metrics/utils.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import numpy as np + + +def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): + """ + Compute bin edges for units, optionally taking into account periods. 
+ + Parameters + ---------- + sorting : Sorting + Sorting object containing unit information. + segment_samples : list or array-like + Number of samples in each segment. + bin_duration_s : float, default: 1 + Duration of each bin in seconds + periods : array of unit_period_dtype, default: None + Periods to consider for each unit + """ + bin_edges_for_units = {} + num_segments = len(segment_samples) + bin_duration_samples = int(bin_duration_s * sorting.sampling_frequency) + + if periods is not None: + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + bin_edges = [] + for seg_index in range(num_segments): + seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] + if len(seg_periods) == 0: + continue + seg_start = np.sum(segment_samples[:seg_index]) + for period in seg_periods: + start_sample = seg_start + period["start_sample_index"] + end_sample = seg_start + period["end_sample_index"] + bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) + bin_edges_for_units[unit_id] = np.array(bin_edges) + else: + total_length = np.sum(segment_samples) + for unit_id in sorting.unit_ids: + bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples + return bin_edges_for_units + + +def compute_total_samples_per_unit(sorting_analyzer, periods=None): + """ + Get total number of samples for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total number of samples for each unit. + """ + if periods is not None: + total_samples = {} + sorting = sorting_analyzer.sorting + for unit_id in sorting.unit_ids: + unit_index = sorting.id_to_index(unit_id) + periods_unit = periods[periods["unit_index"] == unit_index] + num_samples_in_period = 0 + for period in periods_unit: + num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] + total_samples[unit_id] = num_samples_in_period + else: + total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + return total_samples + + +def compute_total_durations_per_unit(sorting_analyzer, periods=None): + """ + Compute total duration for each unit, optionally taking into account periods. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer object. + periods : array of unit_period_dtype, default: None + Periods to consider for each unit. + + Returns + ------- + dict + Total duration for each unit. + """ + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) + total_durations = { + unit_id: samples / sorting_analyzer.sorting.sampling_frequency for unit_id, samples in total_samples.items() + } + return total_durations + + +def create_ground_truth_pc_distributions(center_locations, total_points): + """ + Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics + Values are created for only one channel and vary along one dimension. + + Parameters + ---------- + center_locations : array-like (units, ) or (channels, units) + Mean of the multivariate gaussian at each channel for each unit. + total_points : array-like + Number of points in each unit distribution. + + Returns + ------- + all_pcs : numpy.ndarray + PC scores for each point. 
+ all_labels : numpy.array + Labels for each point. + """ + from scipy.stats import multivariate_normal + + np.random.seed(0) + + if len(np.array(center_locations).shape) == 1: + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations, total_points) + ] + all_pcs = np.concatenate(distributions, axis=0) + + else: + all_pcs = np.empty((np.sum(total_points), 3, center_locations.shape[0])) + for channel in range(center_locations.shape[0]): + distributions = [ + multivariate_normal.rvs(mean=[center, 0.0, 0.0], cov=[1.0, 1.0, 1.0], size=size) + for center, size in zip(center_locations[channel], total_points) + ] + all_pcs[:, :, channel] = np.concatenate(distributions, axis=0) + + all_labels = np.concatenate([np.ones((total_points[i],), dtype="int") * i for i in range(len(total_points))]) + + return all_pcs, all_labels From cbc0986cdfe485c7177e4e84673a3597d0c6dafd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:34:49 +0100 Subject: [PATCH 37/70] Fix import --- src/spikeinterface/core/sorting_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 75e25115ae..f5cf82c76f 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -281,8 +281,8 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: A new sorting object with only samples between start_sample_index and end_sample_index for the given segment_index. """ + from spikeinterface.core.base import unit_period_dtype from spikeinterface.core.numpyextractors import NumpySorting - from spikeinterface.core.node_pipeline import unit_period_dtype if periods is not None: if not isinstance(periods, np.ndarray): @@ -295,6 +295,7 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: keep_mask = select_sorting_periods_mask(sorting, periods) sliced_spike_vector = spike_vector[keep_mask] + # important: we keep the original unit ids so the unit_index field in spike vector is still valid sorting = NumpySorting( sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids ) From 046430e13bd48c099e5ed3e756b32df20ba5fd43 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:42:45 +0100 Subject: [PATCH 38/70] fix import --- .../metrics/quality/tests/test_metrics_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index fb764dac78..b4f956e6a7 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -12,7 +12,7 @@ synthesize_random_firings, ) -from spikeinterface.metrics.quality.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions # from spikeinterface.metrics.quality_metric_list import ( # _misc_metric_name_to_func, From bb8625358e22bdeeecf07872579dbe56a36a25f3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 09:48:47 +0100 Subject: [PATCH 39/70] Add misc_metric changes --- .../metrics/quality/misc_metrics.py | 368 +++++++++++------- 1 file changed, 231 insertions(+), 137 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py 
b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..2d90493756 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -25,8 +25,12 @@ get_template_extremum_amplitude, get_dense_templates_array, ) - -from ..spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.spiketrain.metrics import NumSpikes, FiringRate +from spikeinterface.metrics.utils import ( + compute_bin_edges_per_unit, + compute_total_durations_per_unit, + compute_total_samples_per_unit, +) numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: @@ -35,7 +39,9 @@ HAVE_NUMBA = False -def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0): +def compute_presence_ratios( + sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None +): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -51,6 +57,9 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -63,16 +72,21 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 To do so, spike trains across segments are concatenated to mimic a continuous segment. """ sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - seg_lengths = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] - total_length = sorting_analyzer.get_total_samples() - total_duration = sorting_analyzer.get_total_duration() + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + total_samples = np.sum(segment_samples) bin_duration_samples = int((bin_duration_s * sorting_analyzer.sampling_frequency)) - num_bin_edges = total_length // bin_duration_samples + 1 - bin_edges = np.arange(num_bin_edges) * bin_duration_samples + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_duration_s, + ) mean_fr_ratio_thresh = float(mean_fr_ratio_thresh) if mean_fr_ratio_thresh < 0: @@ -83,7 +97,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 warnings.warn("`mean_fr_ratio_thres` parameter above 1 might lead to low presence ratios.") presence_ratios = {} - if total_length < bin_duration_samples: + if total_samples < bin_duration_samples: warnings.warn( f"Bin duration of {bin_duration_s}s is larger than recording duration. " f"Presence ratios are set to NaN." 
) @@ -91,9 +105,15 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 else: for unit_id in unit_ids: spike_train = [] + bin_edges = bin_edges_per_unit[unit_id] + if len(bin_edges) < 2: + presence_ratios[unit_id] = 0.0 + continue + total_duration = total_durations[unit_id] + for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - st = st + np.sum(seg_lengths[:segment_index]) + st = st + np.sum(segment_samples[:segment_index]) spike_train.append(st) spike_train = np.concatenate(spike_train) @@ -102,7 +122,6 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0 presence_ratios[unit_id] = presence_ratio( spike_train, - total_length, bin_edges=bin_edges, bin_n_spikes_thres=bin_n_spikes_thres, ) @@ -182,7 +201,7 @@ class SNR(BaseMetric): depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0): +def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): """ Calculate Inter-Spike Interval (ISI) violations. @@ -204,6 +223,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -235,11 +257,12 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() - total_duration_s = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -260,7 +283,8 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 if not any([len(train) > 0 for train in spike_train_list]): continue - ratio, _, count = isi_violations(spike_train_list, total_duration_s, isi_threshold_s, min_isi_s) + total_duration = total_durations[unit_id] + ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) isi_violations_ratio[unit_id] = ratio isi_violations_count[unit_id] = count @@ -280,7 +304,7 @@ class ISIViolation(BaseMetric): def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 + sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None ): """ Calculate the number of refractory period violations. @@ -300,6 +324,9 @@ def compute_refrac_period_violations( censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. 
If None, the entire recording duration is used. Returns ------- @@ -322,8 +349,6 @@ def compute_refrac_period_violations( ---------- Based on metrics described in [Llobet]_ """ - from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes - res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) if not HAVE_NUMBA: @@ -332,16 +357,18 @@ def compute_refrac_period_violations( return None sorting = sorting_analyzer.sorting - fs = sorting_analyzer.sampling_frequency - num_units = len(sorting_analyzer.unit_ids) - num_segments = sorting_analyzer.get_num_segments() + sorting = sorting.select_periods(periods=periods) + + fs = sorting.sampling_frequency + num_units = len(sorting.unit_ids) + num_segments = sorting.get_num_segments() spikes = sorting.to_spike_vector(concatenated=False) if unit_ids is None: - unit_ids = sorting_analyzer.unit_ids + unit_ids = sorting.unit_ids - num_spikes = compute_num_spikes(sorting_analyzer) + num_spikes = sorting.count_num_spikes_per_unit() t_c = int(round(censored_period_ms * fs * 1e-3)) t_r = int(round(refractory_period_ms * fs * 1e-3)) @@ -352,7 +379,7 @@ def compute_refrac_period_violations( spike_labels = spikes[seg_index]["unit_index"].astype(np.int32) _compute_rp_violations_numba(nb_rp_violations, spike_times, spike_labels, t_c, t_r) - T = sorting_analyzer.get_total_samples() + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) nb_violations = {} rp_contamination = {} @@ -360,14 +387,15 @@ def compute_refrac_period_violations( for unit_index, unit_id in enumerate(sorting.unit_ids): if unit_id not in unit_ids: continue - - nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] - N = num_spikes[unit_id] - if N == 0: - rp_contamination[unit_id] = np.nan - else: - D = 1 - n_v * (T - 2 * N * t_c) / (N**2 * (t_r - t_c)) - rp_contamination[unit_id] = 1 - math.sqrt(D) if D >= 0 else 1.0 + total_samples_unit = total_samples[unit_id] + nb_violations[unit_id] = nb_rp_violations[unit_index] + rp_contamination[unit_id] = _compute_rp_contamination_one_unit( + nb_rp_violations[unit_index], + num_spikes[unit_id], + total_samples_unit, + t_c, + t_r, + ) return res(rp_contamination, nb_violations) @@ -392,6 +420,7 @@ def compute_sliding_rp_violations( exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, + periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -417,6 +446,9 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -429,8 +461,10 @@ def compute_sliding_rp_violations( This code was adapted from: https://github.com/SteinmetzLab/slidingRefractory/blob/1.0.0/python/slidingRP/metrics.py """ - duration = sorting_analyzer.get_total_duration() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() @@ -455,6 +489,7 @@ def compute_sliding_rp_violations( contamination[unit_id] = np.nan continue + duration = total_durations[unit_id] contamination[unit_id] = slidingRP_violations( spike_train_list, fs, @@ -486,7 +521,7 @@ class SlidingRPViolation(BaseMetric): } -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -504,6 +539,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. References ---------- @@ -520,6 +558,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -556,7 +595,7 @@ class Synchrony(BaseMetric): } -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -571,6 +610,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -584,6 +626,11 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent sampling_frequency = sorting_analyzer.sampling_frequency bin_size_samples = int(bin_size_s * sampling_frequency) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) + segment_samples = [ + sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting_analyzer.get_num_segments()) + ] + if unit_ids is None: unit_ids = sorting.unit_ids @@ -598,15 +645,25 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in sorting.unit_ids} - for segment_index in range(sorting_analyzer.get_num_segments()): - num_samples = sorting_analyzer.get_num_samples(segment_index) - edges = np.arange(0, num_samples + 1, bin_size_samples) + bin_edges_per_unit = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=bin_size_s, + ) + for unit_id in unit_ids: + bin_edges = bin_edges_per_unit[unit_id] - for unit_id in unit_ids: - spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_counts, _ = np.histogram(spike_times, bins=edges) - firing_rates = spike_counts / bin_size_s - firing_rate_histograms[unit_id] = np.concatenate((firing_rate_histograms[unit_id], firing_rates)) + # we can concatenate spike trains across segments adding the cumulative number of samples + # as offset, since bin edges are already cumulative + for segment_index in range(sorting_analyzer.get_num_segments()): + st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) + st = st + np.sum(segment_samples[:segment_index]) + spike_train.append(st) + spike_train = np.concatenate(spike_train) + + spike_counts, _ = np.histogram(spike_train, bins=bin_edges) + firing_rate_histograms[unit_id] = spike_counts / bin_size_s # finally we compute the percentiles firing_ranges = {} @@ -635,6 +692,7 @@ def compute_amplitude_cv_metrics( percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", + periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -658,6 +716,8 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -676,14 +736,15 @@ def compute_amplitude_cv_metrics( "spike_amplitudes", "amplitude_scalings", ), "Invalid amplitude_extension. 
It can be either 'spike_amplitudes' or 'amplitude_scalings'" - sorting = sorting_analyzer.sorting - total_duration = sorting_analyzer.get_total_duration() - spikes = sorting.to_spike_vector() - num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") if unit_ids is None: unit_ids = sorting.unit_ids + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) - amps = sorting_analyzer.get_extension(amplitude_extension).get_data() + spikes = sorting.to_spike_vector() + total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(outputs="dict") + amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods) # precompute segment slice segment_slices = [] @@ -695,6 +756,7 @@ def compute_amplitude_cv_metrics( all_unit_ids = list(sorting.unit_ids) amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( (average_num_spikes_per_bin / firing_rate) * sorting_analyzer.sampling_frequency @@ -752,6 +814,7 @@ def compute_amplitude_cutoffs( num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, + periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -770,6 +833,9 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -805,13 +871,12 @@ def compute_amplitude_cutoffs( invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] if invert_amplitudes: amplitudes = -amplitudes - all_fraction_missing[unit_id] = amplitude_cutoff( amplitudes, num_histogram_bins, histogram_smoothing_value, amplitudes_bins_min_ratio ) @@ -837,7 +902,7 @@ class AmplitudeCutoff(BaseMetric): depend_on = ["spike_amplitudes|amplitude_scalings"] -def compute_amplitude_medians(sorting_analyzer, unit_ids=None): +def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None): """ Compute median of the amplitude distributions (in absolute value). @@ -847,6 +912,9 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude medians. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
Returns ------- @@ -865,7 +933,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None): all_amplitude_medians = {} amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes") - amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id]) @@ -882,7 +950,9 @@ class AmplitudeMedian(BaseMetric): depend_on = ["spike_amplitudes"] -def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100): +def compute_noise_cutoffs( + sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None +): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -906,6 +976,9 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -934,7 +1007,7 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l invert_amplitudes = True extension = sorting_analyzer.get_extension("amplitude_scalings") - amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True) + amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods) for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] @@ -972,6 +1045,7 @@ def compute_drift_metrics( min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, + periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1006,6 +1080,9 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). 
@@ -1028,35 +1105,18 @@ def compute_drift_metrics( check_has_required_extensions("drift", sorting_analyzer) res = namedtuple("drift_metrics", ["drift_ptp", "drift_std", "drift_mad"]) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations = spike_locations_ext.get_data() - # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit") - spikes = sorting.to_spike_vector() - spike_locations_by_unit = {} - for unit_id in unit_ids: - unit_index = sorting.id_to_index(unit_id) - # TODO @alessio this is very slow this sjould be done with spike_vector_to_indices() in code - spike_mask = spikes["unit_index"] == unit_index - spike_locations_by_unit[unit_id] = spike_locations[spike_mask] + spike_locations = spike_locations_ext.get_data(periods=periods) + spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) - interval_samples = int(interval_s * sorting_analyzer.sampling_frequency) + segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] assert direction in spike_locations.dtype.names, ( f"Direction {direction} is invalid. Available directions: " f"{spike_locations.dtype.names}" ) - total_duration = sorting_analyzer.get_total_duration() - if total_duration < min_num_bins * interval_s: - warnings.warn( - "The recording is too short given the specified 'interval_s' and " - "'min_num_bins'. Drift metrics will be set to NaN" - ) - empty_dict = {unit_id: np.nan for unit_id in unit_ids} - if return_positions: - return res(empty_dict, empty_dict, empty_dict), np.nan - else: - return res(empty_dict, empty_dict, empty_dict) # we need drift_ptps = {} @@ -1064,45 +1124,58 @@ def compute_drift_metrics( drift_mads = {} # reference positions are the medians across segments - reference_positions = np.zeros(len(unit_ids)) - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - reference_positions[i] = np.median(spike_locations_by_unit[unit_id][direction]) + reference_positions = {} + for unit_id in unit_ids: + reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) # now compute median positions and concatenate them over segments - median_position_segments = None + spike_vector = sorting.to_spike_vector() + spike_sample_indices = spike_vector["sample_index"] + # we need to add the cumulative sum of segment samples to have global sample indices + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - seg_length = sorting_analyzer.get_num_samples(segment_index) - num_bin_edges = seg_length // interval_samples + 1 - bins = np.arange(num_bin_edges) * interval_samples - spike_vector = sorting.to_spike_vector() - - # retrieve spikes in segment - i0, i1 = np.searchsorted(spike_vector["segment_index"], [segment_index, segment_index + 1]) - spikes_in_segment = spike_vector[i0:i1] - spike_locations_in_segment = spike_locations[i0:i1] - - # compute median positions (if less than min_spikes_per_interval, median position is 0) - median_positions = np.nan * np.zeros((len(unit_ids), num_bin_edges - 1)) - for bin_index, (start_frame, end_frame) in enumerate(zip(bins[:-1], bins[1:])): - i0, i1 = np.searchsorted(spikes_in_segment["sample_index"], [start_frame, end_frame]) - spikes_in_bin = 
spikes_in_segment[i0:i1] - spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] - - for i, unit_id in enumerate(unit_ids): - unit_ind = sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_ind - if np.sum(mask) >= min_spikes_per_interval: - median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) - if median_position_segments is None: - median_position_segments = median_positions - else: - median_position_segments = np.hstack((median_position_segments, median_positions)) + seg_mask = spike_vector["segment_index"] == segment_index + spike_sample_indices[seg_mask] += cumulative_segment_samples[segment_index] + + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, + segment_samples=segment_samples, + periods=periods, + bin_duration_s=interval_s, + ) - # finally, compute deviations and drifts - position_diffs = median_position_segments - reference_positions[:, None] - for i, unit_id in enumerate(unit_ids): - position_diff = position_diffs[i] + median_positions_per_unit = {} + for unit in unit_ids: + bins = bin_edges_for_units[unit] + num_bins = len(bins) - 1 + if num_bins < min_num_bins: + warnings.warn( + f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"'min_num_bins'. Drift metrics will be set to NaN" + ) + drift_ptps[unit] = np.nan + drift_stds[unit] = np.nan + drift_mads[unit] = np.nan + continue + + # bin_edges are global across segments, so we have to use spike_sample_indices, + # since we offseted them to be global + bin_spike_indices = np.searchsorted(spike_sample_indices, bins) + median_positions = np.nan * np.zeros(num_bins) + for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): + spikes_in_bin = spike_vector[i0:i1] + spike_locations_in_bin = spike_locations[i0:i1][direction] + + unit_index = sorting_analyzer.sorting.id_to_index(unit) + mask = spikes_in_bin["unit_index"] == unit_index + if np.sum(mask) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) + else: + median_positions[bin_index] = np.nan + median_positions_per_unit[unit] = median_positions + + # now compute deviations and drifts for this unit + position_diff = median_positions - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): @@ -1120,8 +1193,9 @@ def compute_drift_metrics( drift_ptps[unit_id] = ptp_drift drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions + outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: outs = res(drift_ptps, drift_stds, drift_mads) return outs @@ -1151,6 +1225,7 @@ def compute_sd_ratio( censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, + periods=None, **kwargs, ): """ @@ -1173,6 +1248,9 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. 
**kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. @@ -1189,6 +1267,7 @@ def compute_sd_ratio( job_kwargs = fix_job_kwargs(job_kwargs) sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency)) if unit_ids is None: @@ -1201,7 +1280,7 @@ def compute_sd_ratio( ) return {unit_id: np.nan for unit_id in unit_ids} - spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data() + spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods) if not HAVE_NUMBA: warnings.warn( @@ -1330,7 +1409,7 @@ def check_has_required_extensions(metric_name, sorting_analyzer): ### LOW-LEVEL FUNCTIONS ### -def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): +def presence_ratio(spike_train, bin_edges=None, num_bin_edges=None, bin_n_spikes_thres=0): """ Calculate the presence ratio for a single unit. @@ -1338,8 +1417,6 @@ def presence_ratio(spike_train, total_length, bin_edges=None, num_bin_edges=None ---------- spike_train : np.ndarray Spike times for this unit, in samples. - total_length : int - Total length of the recording in samples. bin_edges : np.array, optional Pre-computed bin edges (mutually exclusive with num_bin_edges). num_bin_edges : int, optional @@ -1569,6 +1646,46 @@ def slidingRP_violations( return min_cont_with_90_confidence +def _compute_rp_contamination_one_unit( + n_v, + n_spikes, + total_samples, + t_c, + t_r, +): + """ + Compute the refractory period contamination for one unit. + + Parameters + ---------- + n_v : int + Number of refractory period violations. + n_spikes : int + Number of spikes for the unit. + total_samples : int + Total number of samples in the recording. + t_c : int + Censored period in samples. + t_r : int + Refractory period in samples. + + Returns + ------- + rp_contamination : float + The refractory period contamination for the unit. 
+ """ + if n_spikes <= 1: + return np.nan + + denom = 1 - n_v * (total_samples - 2 * n_spikes * t_c) / (n_spikes**2 * (t_r - t_c)) + if denom < 0: + return 1.0 + + rp_contamination = 1 - math.sqrt(denom) + + return rp_contamination + + def _compute_violations(obs_viol, firing_rate, spike_count, ref_period_dur, contamination_prop): contamination_rate = firing_rate * contamination_prop expected_viol = contamination_rate * ref_period_dur * 2 * spike_count @@ -1709,29 +1826,6 @@ def _get_synchrony_counts(spikes, synchrony_sizes, all_unit_ids): return synchrony_counts -def _get_amplitudes_by_units(sorting_analyzer, unit_ids, peak_sign): - # used by compute_amplitude_cutoffs and compute_amplitude_medians - - if (spike_amplitudes_extension := sorting_analyzer.get_extension("spike_amplitudes")) is not None: - return spike_amplitudes_extension.get_data(outputs="by_unit", concatenated=True) - - elif sorting_analyzer.has_extension("waveforms"): - amplitudes_by_units = {} - waveforms_ext = sorting_analyzer.get_extension("waveforms") - before = waveforms_ext.nbefore - extremum_channels_ids = get_template_extremum_channel(sorting_analyzer, peak_sign=peak_sign) - for unit_id in unit_ids: - waveforms = waveforms_ext.get_waveforms_one_unit(unit_id, force_dense=False) - chan_id = extremum_channels_ids[unit_id] - if sorting_analyzer.is_sparse(): - chan_ind = np.where(sorting_analyzer.sparsity.unit_id_to_channel_ids[unit_id] == chan_id)[0] - else: - chan_ind = sorting_analyzer.channel_ids_to_indices([chan_id])[0] - amplitudes_by_units[unit_id] = waveforms[:, before, chan_ind] - - return amplitudes_by_units - - if HAVE_NUMBA: import numba From 807f5c61115505eb9f2648cb94cc411f6aafad38 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 13:32:01 +0100 Subject: [PATCH 40/70] Add tests for user defined and combined --- .../tests/test_valid_unit_periods.py | 84 ++++++++++++++++++- .../postprocessing/valid_unit_periods.py | 33 ++++++-- 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index 1171f62e94..6f4c25f415 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -1,6 +1,7 @@ import pytest import numpy as np -from spikeinterface.core.node_pipeline import unit_period_dtype + +from spikeinterface.core.base import unit_period_dtype from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeValidUnitPeriods @@ -51,3 +52,84 @@ def test_user_defined_periods(self): # check that valid periods correspond to user defined periods ext_periods = ext.get_data(outputs="numpy") np.testing.assert_array_equal(ext_periods, periods) + + def test_user_defined_periods_as_arrays(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods_array = np.zeros((len(unit_ids) * num_segments, 4), dtype=int) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods_array[idx, 0] = 
segment_index + periods_array[idx, 1] = period_start + periods_array[idx, 2] = period_start + period_duration + periods_array[idx, 3] = unit_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + ext_periods = np.column_stack([ext_periods[field] for field in ext_periods.dtype.names]) + np.testing.assert_array_equal(ext_periods, periods_array) + + # test that dropping segment_index raises because multi-segment + with pytest.raises(ValueError): + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array[:, 1:4], # drop segment_index + minimum_valid_period_duration=1, + ) + + def test_combined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", + sparse=False, + extension_class=ComputeValidUnitPeriods, + extra_dependencies=["templates", "amplitude_scalings"], + ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="combined", + user_defined_periods=periods, + period_mode="absolute", + period_duration_s_absolute=1.0, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to intersection of auto-computed and user defined periods + ext_periods = ext.get_data(outputs="numpy") + assert len(ext_periods) <= len(periods) # should be less or equal than user defined ones diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 808e252c9c..f1f6717b65 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -12,9 +12,9 @@ from threadpoolctl import threadpool_limits from tqdm.auto import tqdm +from spikeinterface.core.base import unit_period_dtype from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from spikeinterface.core.node_pipeline import unit_period_dtype from spikeinterface.metrics.spiketrain import compute_firing_rates numba_spec = importlib.util.find_spec("numba") @@ -61,8 +61,8 @@ class ComputeValidUnitPeriods(AnalyzerExtension): Minimum duration that detected good periods must have to be kept, in seconds. 
user_defined_periods : array of unit_period_dtype or shape (num_periods, 3) or (num_periods, 4) or None, default: None Periods of unit_period_dtype (segment_index, start_sample_index, end_sample_index, unit_index) - or numpy array of shape (num_periods, 3) [unit_index, start_sample, end_sample] - or (num_periods, 4) [unit_index, segment_index, start_sample, end_sample] + or numpy array of shape (num_periods, 3) [start_sample, end_sample, unit_index] + or (num_periods, 4) [segment_index, start_sample, end_sample, unit_index] in samples, over which to compute the metric. refractory_period_ms : float, default: 0.8 Refractory period duration for violation detection (ms). @@ -156,12 +156,15 @@ def _set_params( user_defined_periods = user_defined_periods.astype(int) if user_defined_periods.shape[1] == 3: - # add segment index 0 as column 1 if missing + if self.sorting_analyzer.get_num_segments() > 1: + raise ValueError( + "For multi-segment recordings, user_defined_periods must include segment_index as column 1." + ) + # add segment index 0 as column 0 if missing user_defined_periods = np.hstack( ( - user_defined_periods[:, 0:1], np.zeros((user_defined_periods.shape[0], 1), dtype=int), - user_defined_periods[:, 1:3], + user_defined_periods, ) ) # Cast user defined periods to unit_period_dtype @@ -444,8 +447,10 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): # Combine with user-defined periods if provided if self.params["method"] == "combined": - user_defined_periods = self.params["user_defined_periods"] - all_periods = np.concatenate((valid_unit_periods, user_defined_periods), axis=0) + user_defined_periods = self.user_defined_periods + valid_unit_periods = self._sort_periods( + np.concatenate((valid_unit_periods, user_defined_periods), axis=0) + ) valid_unit_periods = merge_overlapping_periods_across_units_and_segments(valid_unit_periods) # Remove good periods that are too short @@ -513,6 +518,18 @@ def _sort_periods(self, periods): sorted_periods = periods[sort_idx] return sorted_periods + def set_data(self, ext_data_name, ext_data): + # cast back lists of dicts (required for dumping) back to arrays + if ext_data_name in ("period_centers", "periods_fp_per_unit", "periods_fn_per_unit"): + ext_data = [] + # lists of dicts to lists of dicts with arrays + for segment_dict in ext_data: + segment_dict_arrays = {} + for unit_id, values in segment_dict.items(): + segment_dict_arrays[unit_id] = np.array(values) + ext_data.append(segment_dict_arrays) + self.data[ext_data_name] = ext_data + def compute_subperiods( sorting_analyzer, From 89d563b999af030bb308ee8aaf0dbc517ab6046c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 14:39:14 +0100 Subject: [PATCH 41/70] Add to built_in extensions --- src/spikeinterface/core/sortinganalyzer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0c3b6b9615..61fcb3d0f2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2816,6 +2816,7 @@ def set_data(self, ext_data_name, ext_data): "spike_locations": "spikeinterface.postprocessing", "template_similarity": "spikeinterface.postprocessing", "unit_locations": "spikeinterface.postprocessing", + "valid_unit_periods": "spikeinterface.postprocessing", # from metrics "quality_metrics": "spikeinterface.metrics", "template_metrics": "spikeinterface.metrics", From 50f33f0c8a1e94f7e5fcf7b3aa14058e4a836d4e Mon Sep 17 00:00:00 
2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 15:44:53 +0100 Subject: [PATCH 42/70] fix tests --- .../metrics/quality/misc_metrics.py | 20 ++++++++++--------- .../metrics/quality/tests/conftest.py | 2 +- .../quality/tests/test_metrics_functions.py | 2 +- src/spikeinterface/metrics/utils.py | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 2d90493756..b0791d00d7 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -111,6 +111,7 @@ def compute_presence_ratios( continue total_duration = total_durations[unit_id] + spike_train = [] for segment_index in range(num_segs): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) st = st + np.sum(segment_samples[:segment_index]) @@ -656,6 +657,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent # we can concatenate spike trains across segments adding the cumulative number of samples # as offset, since bin edges are already cumulative + spike_train = [] for segment_index in range(sorting_analyzer.get_num_segments()): st = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) st = st + np.sum(segment_samples[:segment_index]) @@ -737,7 +739,7 @@ def compute_amplitude_cv_metrics( "amplitude_scalings", ), "Invalid amplitude_extension. It can be either 'spike_amplitudes' or 'amplitude_scalings'" if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_analyzer.unit_ids sorting = sorting_analyzer.sorting sorting = sorting.select_periods(periods=periods) @@ -1145,17 +1147,17 @@ def compute_drift_metrics( ) median_positions_per_unit = {} - for unit in unit_ids: - bins = bin_edges_for_units[unit] + for unit_id in unit_ids: + bins = bin_edges_for_units[unit_id] num_bins = len(bins) - 1 if num_bins < min_num_bins: warnings.warn( - f"Unit {unit} has only {num_bins} bins given the specified 'interval_s' and " + f"Unit {unit_id} has only {num_bins} bins given the specified 'interval_s' and " f"'min_num_bins'. 
Drift metrics will be set to NaN" ) - drift_ptps[unit] = np.nan - drift_stds[unit] = np.nan - drift_mads[unit] = np.nan + drift_ptps[unit_id] = np.nan + drift_stds[unit_id] = np.nan + drift_mads[unit_id] = np.nan continue # bin_edges are global across segments, so we have to use spike_sample_indices, @@ -1166,13 +1168,13 @@ def compute_drift_metrics( spikes_in_bin = spike_vector[i0:i1] spike_locations_in_bin = spike_locations[i0:i1][direction] - unit_index = sorting_analyzer.sorting.id_to_index(unit) + unit_index = sorting_analyzer.sorting.id_to_index(unit_id) mask = spikes_in_bin["unit_index"] == unit_index if np.sum(mask) >= min_spikes_per_interval: median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) else: median_positions[bin_index] = np.nan - median_positions_per_unit[unit] = median_positions + median_positions_per_unit[unit_id] = median_positions # now compute deviations and drifts for this unit position_diff = median_positions - reference_positions[unit_id] diff --git a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py index c2a6c6fe82..5313e763c1 100644 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ b/src/spikeinterface/metrics/quality/tests/conftest.py @@ -10,7 +10,7 @@ def make_small_analyzer(): recording, sorting = generate_ground_truth_recording( - durations=[2.0], + durations=[10.0], num_units=10, seed=1205, ) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index b4f956e6a7..8b6e67d119 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -223,7 +223,7 @@ def test_unit_structure_in_output(small_sorting_analyzer): "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, "firing_range": {"bin_size_s": 1}, "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, } diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 16058e521f..91538498aa 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -40,7 +40,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per else: total_length = np.sum(segment_samples) for unit_id in sorting.unit_ids: - bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) * bin_duration_samples + bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) return bin_edges_for_units From f2d48bab694ec0c4b8ae2c7a420e274065875f86 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 14 Jan 2026 17:38:05 +0100 Subject: [PATCH 43/70] Remove debug print --- src/spikeinterface/postprocessing/valid_unit_periods.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index f1f6717b65..5da7c3559d 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -531,6 +531,7 @@ def set_data(self, ext_data_name, ext_data): 
self.data[ext_data_name] = ext_data +# TODO: deal with margin when returning periods def compute_subperiods( sorting_analyzer, period_duration_s_absolute: float = 10, @@ -569,9 +570,6 @@ def compute_subperiods( for segment_index in range(sorting.get_num_segments()): n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples - print( - f"Num samples segment {segment_index}: {n_samples} - period size: {period_size_samples} - margin size: {margin_size_samples}" - ) # We round the number of subperiods to ensure coverage of the entire recording # the end of the last period is then clipped or extended to the end of the recording n_subperiods = round(n_samples / period_size_samples) From e173a63a4d7843be2fcea0609041757f73c926d1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 14:56:04 +0100 Subject: [PATCH 44/70] wip: fix intervals --- .../postprocessing/valid_unit_periods.py | 101 +++++++++++------- 1 file changed, 62 insertions(+), 39 deletions(-) diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 5da7c3559d..b7e69a9291 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -372,7 +372,7 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): elif self.params["method"] in ["false_positives_and_negatives", "combined"]: # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields - all_periods = compute_subperiods( + all_periods, all_periods_w_margins, period_centers = compute_subperiods( sorting_analyzer, self.params["period_duration_s_absolute"], self.params["period_target_num_spikes"], @@ -396,7 +396,7 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): init_args = (sorting_analyzer.sorting, all_amplitudes_by_unit, self.params, max_threads_per_worker) # Each item is one computation of fp and fn for one period and one unit - items = [(period,) for period in all_periods] + items = [(period,) for period in all_periods_w_margins] job_name = f"computing false positives and negatives" # parallel @@ -460,16 +460,16 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): valid_mask = duration_samples >= min_valid_period_samples valid_unit_periods = valid_unit_periods[valid_mask] - # Convert subperiods per unit in period_centers_s - period_centers = [] - for segment_index in range(sorting_analyzer.sorting.get_num_segments()): - periods_segment = all_periods[all_periods["segment_index"] == segment_index] - period_centers_dict = {} - for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): - periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] - centers = list(0.5 * (periods_unit["start_sample_index"] + periods_unit["end_sample_index"])) - period_centers_dict[unit_id] = centers - period_centers.append(period_centers_dict) + # # Convert subperiods per unit in period_centers_s + # period_centers = [] + # for segment_index in range(sorting_analyzer.sorting.get_num_segments()): + # periods_segment = all_periods[all_periods["segment_index"] == segment_index] + # period_centers_dict = {} + # for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + # periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] + # centers = list(0.5 * (periods_unit["start_sample_index"] + periods_unit["end_sample_index"])) + # 
period_centers_dict[unit_id] = centers + # period_centers.append(period_centers_dict) # Store data: here we have to make sure every dict is JSON serializable, so everything is lists return valid_unit_periods, period_centers, fps, fns @@ -564,12 +564,15 @@ def compute_subperiods( margin_sizes_samples = {u: np.round(relative_margin_size * period_sizes_samples[u]).astype(int) for u in unit_ids} all_subperiods = [] - for unit_index, unit_id in enumerate(unit_ids): - period_size_samples = period_sizes_samples[unit_id] - margin_size_samples = margin_sizes_samples[unit_id] - - for segment_index in range(sorting.get_num_segments()): - n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + all_subperiods_w_margins = [] + all_period_centers = [] + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + period_centers = {} + for unit_index, unit_id in enumerate(unit_ids): + period_centers[unit_id] = [] + period_size_samples = period_sizes_samples[unit_id] + margin_size_samples = margin_sizes_samples[unit_id] # We round the number of subperiods to ensure coverage of the entire recording # the end of the last period is then clipped or extended to the end of the recording n_subperiods = round(n_samples / period_size_samples) @@ -583,27 +586,47 @@ def compute_subperiods( for i in range(n_subperiods) ] ) - # remove periods whose end is above the expected number of samples - if len(starts_ends) > 0: - beyond_samples_mask = starts_ends[:, 1] > n_samples - if sum(beyond_samples_mask) == len(starts_ends): - # all periods end beyond n_samples: keep only first period - starts_ends = starts_ends[:1].copy() - else: - starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] - # set last period to the end of the recording - starts_ends[-1][1] = n_samples - - periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) - for i, (start, end) in enumerate(starts_ends): - subperiod = np.zeros((1,), dtype=unit_period_dtype) - subperiod["segment_index"] = segment_index - subperiod["start_sample_index"] = start - subperiod["end_sample_index"] = end - subperiod["unit_index"] = unit_index - periods_for_unit[i] = subperiod - all_subperiods.append(periods_for_unit) - return np.concatenate(all_subperiods) + starts = np.arange(0, n_samples, period_size_samples) + periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) + periods_for_unit_w_margins = np.zeros(len(starts_ends), dtype=unit_period_dtype) + for i, start in enumerate(starts): + end = start + period_size_samples + ext_start = max(0, start - margin_size_samples) + ext_end = min(n_samples, end + margin_size_samples) + center = start + period_size_samples // 2 + period_centers[unit_id].append((segment_index, unit_id, center)) + periods_for_unit[i]["segment_index"] = segment_index + periods_for_unit[i]["start_sample_index"] = start + periods_for_unit[i]["end_sample_index"] = end + periods_for_unit[i]["unit_index"] = unit_index + periods_for_unit_w_margins[i]["segment_index"] = segment_index + periods_for_unit_w_margins[i]["start_sample_index"] = ext_start + periods_for_unit_w_margins[i]["end_sample_index"] = ext_end + periods_for_unit_w_margins[i]["unit_index"] = unit_index + + # # remove periods whose end is above the expected number of samples + # if len(starts_ends) > 0: + # beyond_samples_mask = starts_ends[:, 1] > n_samples + # if sum(beyond_samples_mask) == len(starts_ends): + # # all periods end beyond n_samples: 
keep only first period + # starts_ends = starts_ends[:1].copy() + # else: + # starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] + # # set last period to the end of the recording + # starts_ends[-1][1] = n_samples + + # periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) + # for i, (start, end) in enumerate(starts_ends): + # subperiod = np.zeros((1,), dtype=unit_period_dtype) + # subperiod["segment_index"] = segment_index + # subperiod["start_sample_index"] = start + # subperiod["end_sample_index"] = end + # subperiod["unit_index"] = unit_index + # periods_for_unit[i] = subperiod + all_subperiods.append(periods_for_unit) + all_subperiods_w_margins.append(periods_for_unit_w_margins) + all_period_centers.append(period_centers) + return np.concatenate(all_subperiods), np.concatenate(all_subperiods_w_margins), all_period_centers def merge_overlapping_periods(subperiods): From 80bc50fa59071ac9640b73818f8708e0fae41757 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 15:05:37 +0100 Subject: [PATCH 45/70] Change base_period_dtype order and fix select_sorting_periods array input --- src/spikeinterface/core/base.py | 2 +- src/spikeinterface/core/sorting_tools.py | 26 ++++++++++++++++--- .../core/tests/test_basesorting.py | 14 ++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3505853835..4520c19819 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -42,9 +42,9 @@ minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] base_period_dtype = [ + ("segment_index", "int64"), ("start_sample_index", "int64"), ("end_sample_index", "int64"), - ("segment_index", "int64"), ] unit_period_dtype = base_period_dtype + [ diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index f5cf82c76f..d05cc33869 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -271,9 +271,11 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: Parameters ---------- - periods : numpy.array of unit_period_dtype + periods : numpy.ndarray Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to restrict the sorting. + on which to restrict the sorting. Periods can be either a numpy array of unit_period_dtype + or an array with (num_periods, 4) shape. In the latter case, the fields are assumed to be + in the order: segment_index, start_sample_index, end_sample_index, unit_index. Returns ------- @@ -286,7 +288,25 @@ def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting: if periods is not None: if not isinstance(periods, np.ndarray): - periods = np.array([periods], dtype=unit_period_dtype) + raise ValueError("periods must be a numpy array") + if not periods.dtype == unit_period_dtype: + if periods.ndim != 2 or periods.shape[1] != 4: + raise ValueError( + "If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)" + ) + warnings.warn( + "periods is not of dtype unit_period_dtype. 
Assuming fields are in order: " + "(segment_index, start_sample_index, end_sample_index, unit_index).", + UserWarning, + ) + # convert to structured array + periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype) + periods_converted["segment_index"] = periods[:, 0] + periods_converted["start_sample_index"] = periods[:, 1] + periods_converted["end_sample_index"] = periods[:, 2] + periods_converted["unit_index"] = periods[:, 3] + periods = periods_converted + required = set(np.dtype(unit_period_dtype).names) if not required.issubset(periods.dtype.names): raise ValueError(f"Period must have the following fields: {required}") diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 963320c2a1..ed1931e87a 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -225,7 +225,7 @@ def test_time_slice(): def test_select_periods(): sampling_frequency = 10_000.0 - duration = 1_000 + duration = 100 num_samples = int(sampling_frequency * duration) num_units = 1000 sorting = generate_sorting( @@ -235,7 +235,7 @@ def test_select_periods(): rng = np.random.default_rng() # number of random periods - n_periods = 10_000 + n_periods = 1_000 # generate random periods segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods) start_samples = rng.integers(0, num_samples, n_periods) @@ -280,6 +280,16 @@ def test_select_periods(): spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) assert len(spiketrain_in_periods) == len(spiketrain_sliced) + # now test with input as numpy array with shape (n_periods, 4) + periods_array = np.zeros((len(periods), 4), dtype="int64") + periods_array[:, 0] = periods["segment_index"] + periods_array[:, 1] = periods["start_sample_index"] + periods_array[:, 2] = periods["end_sample_index"] + periods_array[:, 3] = periods["unit_index"] + + sliced_sorting_array = sorting.select_periods(periods=periods_array) + np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) + if __name__ == "__main__": test_BaseSorting() From 96e6a5317e1e8a76b77d37e9a56f5e399ad5e1b1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:02:02 +0100 Subject: [PATCH 46/70] fix tests --- src/spikeinterface/metrics/quality/misc_metrics.py | 9 ++++----- src/spikeinterface/metrics/spiketrain/metrics.py | 7 ++++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 9942c2f707..835465a4c1 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -382,10 +382,10 @@ def compute_refrac_period_violations( for segment_index in range(sorting_analyzer.get_num_segments()): spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - nb_rp_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) + nb_violations[unit_id] += _compute_rp_violations_numba(spike_times, t_c, t_r) rp_contamination[unit_id] = _compute_rp_contamination_one_unit( - nb_rp_violations[unit_id], + nb_violations[unit_id], num_spikes[unit_id], total_samples_unit, t_c, @@ -1122,9 +1122,8 @@ def compute_drift_metrics( # we need to add the cumulative sum of segment samples to have global sample indices cumulative_segment_samples = np.cumsum([0] + 
segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - spike_sample_indices[sorting._get_spike_vector_segment_slices()[segment_index]] += cumulative_segment_samples[ - segment_index - ] + segment_slice = sorting._get_spike_vector_segment_slices()[segment_index] + spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index] bin_edges_for_units = compute_bin_edges_per_unit( sorting, diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 600ae2e406..669733f47a 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -26,7 +26,12 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, periods=None): sorting = sorting.select_periods(periods) if unit_ids is None: unit_ids = sorting.unit_ids - return sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + # re-order dict to match unit_ids order + count_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + num_spikes = {} + for unit_id in unit_ids: + num_spikes[unit_id] = count_spikes[unit_id] + return num_spikes class NumSpikes(BaseMetric): From 319891137e9d395862aa26d7b78e54774626922c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:38:59 +0100 Subject: [PATCH 47/70] Fix generation of bins --- src/spikeinterface/metrics/quality/misc_metrics.py | 5 +++-- src/spikeinterface/metrics/utils.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 835465a4c1..85bd507c2a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -648,6 +648,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent periods=periods, bin_duration_s=bin_size_s, ) + cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: bin_edges = bin_edges_per_unit[unit_id] @@ -656,9 +657,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent spike_trains = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_times = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) - spike_times = spike_times + np.sum(segment_samples[:segment_index]) + spike_times = spike_times + cumulative_segment_samples[segment_index] spike_trains.append(spike_times) - spike_train = np.concatenate(spike_trains) + spike_train = np.concatenate(spike_trains, dtype="int64") spike_counts, _ = np.histogram(spike_train, bins=bin_edges) firing_rate_histograms[unit_id] = spike_counts / bin_size_s diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 222503a730..83ddfcf90b 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -35,12 +35,22 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per for period in seg_periods: start_sample = seg_start + period["start_sample_index"] end_sample = seg_start + period["end_sample_index"] + end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) bin_edges_for_units[unit_id] = np.array(bin_edges) else: - total_length = np.sum(segment_samples) for unit_id in sorting.unit_ids: - 
bin_edges_for_units[unit_id] = np.arange(0, total_length, bin_duration_samples) + bin_edges = [] + for seg_index in range(num_segments): + seg_start = np.sum(segment_samples[:seg_index]) + seg_end = seg_start + segment_samples[seg_index] + # for segments which are not the last, we don't need to correct the end + # since the first index of the next segment will be the end of the current segment + if seg_index == num_segments - 1: + seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1 # align to bin + bins = np.arange(seg_start, seg_end, bin_duration_samples) + bin_edges.extend(bins) + bin_edges_for_units[unit_id] = np.array(bin_edges) return bin_edges_for_units From bbc28c523f3abd29e86ab47aab41064905124902 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:47:10 +0100 Subject: [PATCH 48/70] Refactor generation of subperiods --- src/spikeinterface/core/base.py | 2 +- .../postprocessing/valid_unit_periods.py | 33 +++---------------- 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3505853835..4520c19819 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -42,9 +42,9 @@ minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] base_period_dtype = [ + ("segment_index", "int64"), ("start_sample_index", "int64"), ("end_sample_index", "int64"), - ("segment_index", "int64"), ] unit_period_dtype = base_period_dtype + [ diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index b7e69a9291..539d587b5b 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -580,21 +580,17 @@ def compute_subperiods( n_subperiods = min_num_periods_relative # at least min_num_periods_relative subperiods period_size_samples = n_samples // n_subperiods margin_size_samples = int(relative_margin_size * period_size_samples) - starts_ends = np.array( - [ - [i * period_size_samples, i * period_size_samples + period_size_samples + 2 * margin_size_samples] - for i in range(n_subperiods) - ] - ) + + # we generate periods starting from 0 up to n_samples, with and without margins, and period centers starts = np.arange(0, n_samples, period_size_samples) - periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) - periods_for_unit_w_margins = np.zeros(len(starts_ends), dtype=unit_period_dtype) + periods_for_unit = np.zeros(len(starts), dtype=unit_period_dtype) + periods_for_unit_w_margins = np.zeros(len(starts), dtype=unit_period_dtype) for i, start in enumerate(starts): end = start + period_size_samples ext_start = max(0, start - margin_size_samples) ext_end = min(n_samples, end + margin_size_samples) center = start + period_size_samples // 2 - period_centers[unit_id].append((segment_index, unit_id, center)) + period_centers[unit_id].append(center) periods_for_unit[i]["segment_index"] = segment_index periods_for_unit[i]["start_sample_index"] = start periods_for_unit[i]["end_sample_index"] = end @@ -604,25 +600,6 @@ def compute_subperiods( periods_for_unit_w_margins[i]["end_sample_index"] = ext_end periods_for_unit_w_margins[i]["unit_index"] = unit_index - # # remove periods whose end is above the expected number of samples - # if len(starts_ends) > 0: - # beyond_samples_mask = starts_ends[:, 1] > n_samples - # if sum(beyond_samples_mask) == len(starts_ends): - # # all 
periods end beyond n_samples: keep only first period - # starts_ends = starts_ends[:1].copy() - # else: - # starts_ends = starts_ends[starts_ends[:, 1] <= n_subperiods * period_size_samples] - # # set last period to the end of the recording - # starts_ends[-1][1] = n_samples - - # periods_for_unit = np.zeros(len(starts_ends), dtype=unit_period_dtype) - # for i, (start, end) in enumerate(starts_ends): - # subperiod = np.zeros((1,), dtype=unit_period_dtype) - # subperiod["segment_index"] = segment_index - # subperiod["start_sample_index"] = start - # subperiod["end_sample_index"] = end - # subperiod["unit_index"] = unit_index - # periods_for_unit[i] = subperiod all_subperiods.append(periods_for_unit) all_subperiods_w_margins.append(periods_for_unit_w_margins) all_period_centers.append(period_centers) From 8312db2c0a7ec7da816f5989a51ae5913db4a766 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 15 Jan 2026 17:50:16 +0100 Subject: [PATCH 49/70] fix conflicts2 --- src/spikeinterface/core/analyzer_extension_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index edf7244347..16f0551087 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -1446,7 +1446,7 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, # use the cache of indices spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices() data_by_units = {} - for segment_index in range(sorting.get_num_segments()): + for segment_index in range(self.sorting_analyzer.get_num_segments()): data_by_units[segment_index] = {} for unit_id in unit_ids: inds = spike_indices[segment_index][unit_id] From 7446a43187f4434a466a8ae72153d57585833cb5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 16 Jan 2026 14:48:12 +0100 Subject: [PATCH 50/70] Use cached get_spike_vector_to_indices --- src/spikeinterface/core/sorting_tools.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index d05cc33869..bc0a1871af 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -246,9 +246,8 @@ def select_sorting_periods_mask(sorting: BaseSorting, periods): A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods. 
""" spike_vector = sorting.to_spike_vector() - spike_vector_list = sorting.to_spike_vector(concatenated=False) keep_mask = np.zeros(len(spike_vector), dtype=bool) - all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True) + all_global_indices = sorting.get_spike_vector_to_indices() for segment_index in range(sorting.get_num_segments()): global_indices_segment = all_global_indices[segment_index] # filter periods by segment From 51e906a5ef93b4f13b32f7a2e5c273f5b7073ae0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 16 Jan 2026 17:43:09 +0100 Subject: [PATCH 51/70] Fix error in merging --- src/spikeinterface/metrics/quality/misc_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 6b217d197d..8176e07628 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -349,6 +349,8 @@ def compute_refrac_period_violations( """ res = namedtuple("rp_violations", ["rp_contamination", "rp_violations"]) + sorting = sorting_analyzer.sorting + sorting = sorting.select_periods(periods=periods) if unit_ids is None: unit_ids = sorting.unit_ids @@ -357,8 +359,6 @@ def compute_refrac_period_violations( warnings.warn("compute_refrac_period_violations cannot run without numba.") return {unit_id: np.nan for unit_id in unit_ids} - sorting = sorting_analyzer.sorting - sorting = sorting.select_periods(periods=periods) num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) fs = sorting_analyzer.sampling_frequency From 220951425b4e1f608c0a97fcab8c6002404da343 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 16:06:34 +0100 Subject: [PATCH 52/70] Add supports_periods in BaseMetric/Extension --- .../core/analyzer_extension_core.py | 21 +++- .../metrics/quality/misc_metrics.py | 100 ++++++++++-------- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 0fe7fc81c1..bc81f063d1 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -823,10 +823,9 @@ class BaseMetric: metric_columns = {} # column names and their dtypes of the dataframe metric_descriptions = {} # descriptions of each metric column needs_recording = False # whether the metric needs recording - needs_tmp_data = ( - False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level - ) - needs_job_kwargs = False + needs_tmp_data = False # whether the metric needs temporary data computed with MetricExtension._prepare_data + needs_job_kwargs = False # whether the metric needs job_kwargs + supports_periods = False # whether the metric function supports periods depend_on = [] # extensions the metric depends on # the metric function must have the signature: @@ -839,7 +838,7 @@ class BaseMetric: metric_function = None # to be defined in subclass @classmethod - def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs): + def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs, periods=None): """Compute the metric. 
Parameters @@ -854,6 +853,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs Temporary data to pass to the metric function job_kwargs : dict Job keyword arguments to control parallelization + periods : np.ndarray | None + Numpy array of unit periods of unit_period_dtype if supports_periods is True Returns ------- @@ -865,6 +866,8 @@ def compute(cls, sorting_analyzer, unit_ids, metric_params, tmp_data, job_kwargs args += (tmp_data,) if cls.needs_job_kwargs: args += (job_kwargs,) + if cls.supports_periods: + args += (periods,) results = cls.metric_function(*args, **metric_params) @@ -988,6 +991,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods: np.ndarray | None = None, **other_params, ): """ @@ -1004,6 +1008,8 @@ def _set_params( If True, existing metrics in the extension will be deleted before computing new ones. metrics_to_compute : list[str] | None List of metric names to compute. If None, all metrics in `metric_names` are computed. + periods : np.ndarray | None + Numpy array of unit_period_dtype defining periods to compute metrics over. other_params : dict Additional parameters for metric computation. @@ -1079,6 +1085,7 @@ def _set_params( metrics_to_compute=metrics_to_compute, delete_existing_metrics=delete_existing_metrics, metric_params=metric_params, + periods=periods, **other_params, ) return params @@ -1129,6 +1136,8 @@ def _compute_metrics( if metric_names is None: metric_names = self.params["metric_names"] + periods = self.params.get("periods", None) + column_names_dtypes = {} for metric_name in metric_names: metric = [m for m in self.metric_list if m.metric_name == metric_name][0] @@ -1153,6 +1162,7 @@ def _compute_metrics( metric_params=metric_params, tmp_data=tmp_data, job_kwargs=job_kwargs, + periods=periods, ) except Exception as e: warnings.warn(f"Error computing metric {metric_name}: {e}") @@ -1179,6 +1189,7 @@ def _run(self, **job_kwargs): metrics_to_compute = self.params["metrics_to_compute"] delete_existing_metrics = self.params["delete_existing_metrics"] + periods = self.params.get("periods", None) _, job_kwargs = split_job_kwargs(job_kwargs) job_kwargs = fix_job_kwargs(job_kwargs) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 8176e07628..ec0caf8138 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -39,7 +39,7 @@ def compute_presence_ratios( - sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None + sorting_analyzer, unit_ids=None, periods=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0 ): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -50,15 +50,15 @@ def compute_presence_ratios( A SortingAnalyzer object. unit_ids : list or None The list of unit ids to compute the presence ratio. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. bin_duration_s : float, default: 60 The duration of each bin in seconds. If the duration is less than this value, presence_ratio is set to NaN. mean_fr_ratio_thresh : float, default: 0 The unit is considered active in a bin if its firing rate during that bin. 
is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -136,6 +136,7 @@ class PresenceRatio(BaseMetric): metric_params = {"bin_duration_s": 60, "mean_fr_ratio_thresh": 0.0} metric_columns = {"presence_ratio": float} metric_descriptions = {"presence_ratio": "Fraction of time the unit is active."} + supports_periods = True def compute_snrs( @@ -199,10 +200,11 @@ class SNR(BaseMetric): metric_params = {"peak_sign": "neg", "peak_mode": "extremum"} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} + supports_periods = True depend_on = ["noise_levels", "templates"] -def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None): +def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_threshold_ms=1.5, min_isi_ms=0): """ Calculate Inter-Spike Interval (ISI) violations. @@ -217,6 +219,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 The SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the ISI violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. isi_threshold_ms : float, default: 1.5 Threshold for classifying adjacent spikes as an ISI violation, in ms. This is the biophysical refractory period. @@ -224,9 +229,6 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5 Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced. by the data acquisition system or post-processing algorithms. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -299,10 +301,11 @@ class ISIViolation(BaseMetric): "isi_violations_ratio": "Ratio of ISI violations for each unit.", "isi_violations_count": "Count of ISI violations for each unit.", } + supports_periods = True def compute_refrac_period_violations( - sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None + sorting_analyzer, unit_ids=None, periods=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0 ): """ Calculate the number of refractory period violations. @@ -317,14 +320,14 @@ def compute_refrac_period_violations( The SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the refractory period violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. refractory_period_ms : float, default: 1.0 The period (in ms) where no 2 good spikes can occur. censored_period_ms : float, default: 0.0 The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were removed by another mean). 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -397,18 +400,19 @@ class RPViolation(BaseMetric): "rp_contamination": "Refractory period contamination described in Llobet & Wyngaard 2022.", "rp_violations": "Number of refractory period violations.", } + supports_periods = True def compute_sliding_rp_violations( sorting_analyzer, unit_ids=None, + periods=None, min_spikes=0, bin_size_ms=0.25, window_size_s=1, exclude_ref_period_below_ms=0.5, max_ref_period_ms=10, contamination_values=None, - periods=None, ): """ Compute sliding refractory period violations, a metric developed by IBL which computes @@ -421,6 +425,9 @@ def compute_sliding_rp_violations( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the sliding RP violations. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. min_spikes : int, default: 0 Contamination is set to np.nan if the unit has less than this many spikes across all segments. @@ -434,9 +441,6 @@ def compute_sliding_rp_violations( Maximum refractory period to test in ms. contamination_values : 1d array or None, default: None The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -508,9 +512,10 @@ class SlidingRPViolation(BaseMetric): metric_descriptions = { "sliding_rp_violation": "Minimum contamination at 90% confidence using sliding refractory period method." } + supports_periods = True -def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None): +def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, synchrony_sizes=None): """ Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact same sample index, with synchrony sizes 2, 4 and 8. @@ -521,6 +526,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the synchrony metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. synchrony_sizes: None, default: None Deprecated argument. Please use private `_get_synchrony_counts` if you need finer control over number of synchronous spikes. @@ -528,9 +536,6 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N ------- sync_spike_{X} : dict The synchrony metric for synchrony size X. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. 
References ---------- @@ -583,9 +588,10 @@ class Synchrony(BaseMetric): "sync_spike_4": "Fraction of spikes that are synchronous with at least three other spikes.", "sync_spike_8": "Fraction of spikes that are synchronous with at least seven other spikes.", } + supports_periods = True -def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None): +def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_size_s=5, percentiles=(5, 95)): """ Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution computed in non-overlapping time bins. @@ -596,13 +602,13 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the firing range. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. bin_size_s : float, default: 5 The size of the bin in seconds. percentiles : tuple, default: (5, 95) The percentiles to compute. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -675,16 +681,17 @@ class FiringRange(BaseMetric): metric_descriptions = { "firing_range": "Range between the percentiles (default: 5th and 95th) of the firing rates distribution." } + supports_periods = True def compute_amplitude_cv_metrics( sorting_analyzer, unit_ids=None, + periods=None, average_num_spikes_per_bin=50, percentiles=(5, 95), min_num_bins=10, amplitude_extension="spike_amplitudes", - periods=None, ): """ Calculate coefficient of variation of spike amplitudes within defined temporal bins. @@ -697,6 +704,9 @@ def compute_amplitude_cv_metrics( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude spread. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. average_num_spikes_per_bin : int, default: 50 The average number of spikes per bin. This is used to estimate a temporal bin size using the firing rate of each unit. For example, if a unit has a firing rate of 10 Hz, amd the average number of spikes per bin is @@ -708,8 +718,6 @@ def compute_amplitude_cv_metrics( the median and range are set to NaN. amplitude_extension : str, default: "spike_amplitudes" The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings". 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) Returns ------- @@ -786,16 +794,17 @@ class AmplitudeCV(BaseMetric): "amplitude_cv_median": "Median of the coefficient of variation of spike amplitudes within temporal bins.", "amplitude_cv_range": "Range of the coefficient of variation of spike amplitudes within temporal bins.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_amplitude_cutoffs( sorting_analyzer, unit_ids=None, + periods=None, num_histogram_bins=500, histogram_smoothing_value=3, amplitudes_bins_min_ratio=5, - periods=None, ): """ Calculate approximate fraction of spikes missing from a distribution of amplitudes. @@ -806,6 +815,9 @@ def compute_amplitude_cutoffs( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. num_histogram_bins : int, default: 100 The number of bins to use to compute the amplitude histogram. histogram_smoothing_value : int, default: 3 @@ -814,9 +826,6 @@ def compute_amplitude_cutoffs( The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -881,6 +890,7 @@ class AmplitudeCutoff(BaseMetric): metric_descriptions = { "amplitude_cutoff": "Estimated fraction of missing spikes, based on the amplitude distribution." } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] @@ -929,11 +939,12 @@ class AmplitudeMedian(BaseMetric): metric_descriptions = { "amplitude_median": "Median of the amplitude distributions (in absolute value) for each unit in uV." } + supports_periods = True depend_on = ["spike_amplitudes"] def compute_noise_cutoffs( - sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None + sorting_analyzer, unit_ids=None, periods=None, high_quantile=0.25, low_quantile=0.1, n_bins=100 ): """ A metric to determine if a unit's amplitude distribution is cut off as it approaches zero, without assuming a Gaussian distribution. @@ -952,15 +963,15 @@ def compute_noise_cutoffs( A SortingAnalyzer object. unit_ids : list or None List of unit ids to compute the amplitude cutoffs. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. high_quantile : float, default: 0.25 Quantile of the amplitude range above which values are treated as "high" (e.g. 0.25 = top 25%), the reference region. low_quantile : int, default: 0.1 Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region. n_bins: int, default: 100 The number of bins to use to compute the amplitude histogram. 
- periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. Returns ------- @@ -1015,19 +1026,20 @@ class NoiseCutoff(BaseMetric): ), "noise_ratio": "Ratio of counts in the lower-amplitude bins to the count in the highest bin.", } + supports_periods = True depend_on = ["spike_amplitudes|amplitude_scalings"] def compute_drift_metrics( sorting_analyzer, unit_ids=None, + periods=None, interval_s=60, min_spikes_per_interval=100, direction="y", min_fraction_valid_intervals=0.5, min_num_bins=2, return_positions=False, - periods=None, ): """ Compute drifts metrics using estimated spike locations. @@ -1049,6 +1061,9 @@ def compute_drift_metrics( A SortingAnalyzer object. unit_ids : list or None, default: None List of unit ids to compute the drift metrics. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. interval_s : int, default: 60 Interval length is seconds for computing spike depth. min_spikes_per_interval : int, default: 100 @@ -1062,9 +1077,6 @@ def compute_drift_metrics( min_num_bins : int, default: 2 Minimum number of bins required to return a valid metric value. In case there are less bins, the metric values are set to NaN. - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. return_positions : bool, default: False If True, median positions are returned (for debugging). @@ -1198,16 +1210,17 @@ class Drift(BaseMetric): "drift_std": "Standard deviation of the drift signal in um.", "drift_mad": "Median absolute deviation of the drift signal in um.", } + supports_periods = True depend_on = ["spike_locations"] def compute_sd_ratio( sorting_analyzer: SortingAnalyzer, unit_ids=None, + periods=None, censored_period_ms: float = 4.0, correct_for_drift: bool = True, correct_for_template_itself: bool = True, - periods=None, **kwargs, ): """ @@ -1223,6 +1236,9 @@ def compute_sd_ratio( A SortingAnalyzer object. unit_ids : list or None, default: None The list of unit ids to compute this metric. If None, all units are used. + periods : array of unit_period_dtype | None, default: None + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to compute the metric. If None, the entire recording duration is used. censored_period_ms : float, default: 4.0 The censored period in milliseconds. This is to remove any potential bursts that could affect the SD. correct_for_drift : bool, default: True @@ -1230,9 +1246,6 @@ def compute_sd_ratio( correct_for_template_itself : bool, default: True If true, will take into account that the template itself impacts the standard deviation of the noise, and will make a rough estimation of what that impact is (and remove it). - periods : array of unit_period_dtype | None, default: None - Periods (segment_index, start_sample_index, end_sample_index, unit_index) - on which to compute the metric. If None, the entire recording duration is used. **kwargs : dict, default: {} Keyword arguments for computing spike amplitudes and extremum channel. 
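# --- illustrative sketch, not part of the patch ---
# The `periods` argument added to these metric functions in this commit expects a
# structured numpy array of `unit_period_dtype` (segment_index, start_sample_index,
# end_sample_index, unit_index). A minimal example of restricting one metric to two
# hand-picked chunks; `sorting_analyzer` is assumed to be an existing SortingAnalyzer
# and the sample indices are arbitrary illustrative values.
import numpy as np
from spikeinterface.core.base import unit_period_dtype
from spikeinterface.metrics.quality.misc_metrics import compute_presence_ratios

periods = np.zeros(2, dtype=unit_period_dtype)
periods["segment_index"] = [0, 0]
periods["start_sample_index"] = [0, 300_000]
periods["end_sample_index"] = [300_000, 600_000]
periods["unit_index"] = [0, 1]  # position of each unit in sorting_analyzer.unit_ids

# presence ratio computed only inside the requested periods of each unit
presence_ratios = compute_presence_ratios(sorting_analyzer, periods=periods)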
@@ -1346,6 +1359,7 @@ class SDRatio(BaseMetric): "sd_ratio": "Ratio between the standard deviation of spike amplitudes and the standard deviation of noise." } needs_recording = True + supports_periods = True depend_on = ["templates", "spike_amplitudes"] From b23c431bc595650b336821811bb87bd3b8026424 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 16:56:16 +0100 Subject: [PATCH 53/70] wip: test metrics with periods --- .../core/analyzer_extension_core.py | 14 +- .../metrics/quality/misc_metrics.py | 1 - .../quality/tests/test_metrics_functions.py | 687 ++++++++++-------- 3 files changed, 397 insertions(+), 305 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index bc81f063d1..a21404e58f 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -13,6 +13,7 @@ import numpy as np from collections import namedtuple +from .numpyextractors import NumpySorting from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator from .recording_tools import get_noise_levels @@ -1463,6 +1464,16 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods, ) all_data = all_data[keep_mask] + # since we have the mask already, we can use it directly to avoid double computation + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=True) + sliced_spike_vector = spike_vector[keep_mask] + sorting = NumpySorting( + sliced_spike_vector, + sampling_frequency=self.sorting_analyzer.sampling_frequency, + unit_ids=self.sorting_analyzer.unit_ids, + ) + else: + sorting = self.sorting_analyzer.sorting if outputs == "numpy": if copy: @@ -1474,8 +1485,7 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, if keep_mask is not None: # since we are filtering spikes, we need to recompute the spike indices - spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) - spike_vector = spike_vector[keep_mask] + spike_vector = sorting.to_spike_vector(concatenated=False) spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) else: # use the cache of indices diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index ec0caf8138..04d451202d 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -200,7 +200,6 @@ class SNR(BaseMetric): metric_params = {"peak_sign": "neg", "peak_mode": "extremum"} metric_columns = {"snr": float} metric_descriptions = {"snr": "Signal to noise ratio for each unit."} - supports_periods = True depend_on = ["noise_levels", "templates"] diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 99e2e5606a..2e31c53135 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from copy import deepcopy import csv @@ -14,13 +13,8 @@ from spikeinterface.metrics.utils import create_ground_truth_pc_distributions -# from spikeinterface.metrics.quality_metric_list import ( -# _misc_metric_name_to_func, -# ) - 
from spikeinterface.metrics.quality import ( get_quality_metric_list, - get_quality_pca_metric_list, compute_quality_metrics, ) from spikeinterface.metrics.quality.misc_metrics import ( @@ -28,8 +22,6 @@ compute_amplitude_cutoffs, compute_presence_ratios, compute_isi_violations, - # compute_firing_rates, - # compute_num_spikes, compute_snrs, compute_refrac_period_violations, compute_sliding_rp_violations, @@ -44,7 +36,6 @@ ) from spikeinterface.metrics.quality.pca_metrics import ( - pca_metrics_list, mahalanobis_metrics, d_prime_metric, nearest_neighbors_metrics, @@ -53,258 +44,10 @@ ) -from spikeinterface.core.base import minimum_spike_dtype - - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def test_noise_cutoff(): - """ - Generate two artifical gaussian, one truncated and one not. Check the metrics are higher for the truncated one. - """ - np.random.seed(1) - amps = np.random.normal(0, 1, 1000) - amps_trunc = amps[amps > -1] - - cutoff1, ratio1 = _noise_cutoff(amps=amps) - cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) - - assert cutoff1 <= cutoff2 - assert ratio1 <= ratio2 - - -def test_compute_new_quality_metrics(small_sorting_analyzer): - """ - Computes quality metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. - """ - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "firing_range": {"bin_size_s": 1}, - } - - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - qm_extension = small_sorting_analyzer.get_extension("quality_metrics") - calculated_metrics = list(qm_extension.get_data().keys()) - - assert calculated_metrics == ["snr"] - - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} - ) - small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - - quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - - # Check old metrics are not deleted and the new one is added to the data and metadata - assert set(list(quality_metric_extension.get_data().keys())) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - assert set(list(quality_metric_extension.params.get("metric_names"))) == set( - [ - "amplitude_cutoff", - "firing_range", - "presence_ratio", - "snr", - ] - ) - - # check that, when parameters are changed, the data and metadata are updated - old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) - small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} - ) - new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") - new_snr_data = new_quality_metric_extension.get_data()["snr"].values - - assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" - - -def test_metric_names_in_same_order(small_sorting_analyzer): - """ - Computes sepecified quality metrics and checks order is propagated. 
- """ - specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] - small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) - qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() - for i in range(3): - assert specified_metric_names[i] == qm_keys[i] - - -def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): - """ - Computes quality metrics in binary folder format. Then computes subsets of quality - metrics and checks if they are saved correctly. - """ - - # can't use _misc_metric_name_to_func as some functions compute several qms - # e.g. isi_violation and synchrony - quality_metrics = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violations_ratio", - "isi_violations_count", - "rp_contamination", - "rp_violations", - "sliding_rp_violation", - "amplitude_cutoff", - "amplitude_median", - "amplitude_cv_median", - "amplitude_cv_range", - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - "firing_range", - "drift_ptp", - "drift_std", - "drift_mad", - "sd_ratio", - "isolation_distance", - "l_ratio", - "d_prime", - "silhouette", - "nn_hit_rate", - "nn_miss_rate", - ] - - small_sorting_analyzer.compute("quality_metrics") - - cache_folder = create_cache_folder - output_folder = cache_folder / "sorting_analyzer" - - folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) - quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - assert metric_name in metric_names - - folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) - - with open(quality_metrics_filename) as metrics_file: - saved_metrics = csv.reader(metrics_file) - metric_names = next(saved_metrics) - - for metric_name in quality_metrics: - if metric_name == "snr": - assert metric_name in metric_names - else: - assert metric_name not in metric_names - - -def test_unit_structure_in_output(small_sorting_analyzer): - - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, - "firing_range": {"bin_size_s": 1}, - "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, - "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, - "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, - } - - for metric in misc_metrics_list: - metric_name = metric.metric_name - metric_fun = metric.metric_function - try: - qm_param = qm_params[metric_name] - except: - qm_param = {} - - result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) - result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) - - error = "Problem with metric: " + metric_name - - if isinstance(result_all, dict): - assert list(result_all.keys()) == ["#3", "#9", "#4"], error - 
assert list(result_sub.keys()) == ["#4", "#9"], error - assert result_sub["#9"] == result_all["#9"], error - assert result_sub["#4"] == result_all["#4"], error - - else: - for result_ind, result in enumerate(result_sub): - - assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error - assert result_sub[result_ind].keys() == set(["#4", "#9"]), error - - assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error - assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error - - -def test_unit_id_order_independence(small_sorting_analyzer): - """ - Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels, - and checks that their calculated quality metrics are independent of the ordering and labelling. - """ - - recording = small_sorting_analyzer.recording - sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) - - small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - small_sorting_analyzer_2.compute(extensions_to_compute) - - # need special params to get non-nan results on a short recording - qm_params = { - "presence_ratio": {"bin_duration_s": 0.1}, - "amplitude_cutoff": {"num_histogram_bins": 3}, - "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, - "firing_range": {"bin_size_s": 1}, - "isi_violation": {"isi_threshold_ms": 10}, - "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, - "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, - } - - quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True - ) - quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True - ) - - for metric, metric_2_data in quality_metrics_2.items(): - error = "Problem with the metric " + metric - assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error - assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error - assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error +from spikeinterface.core.base import minimum_spike_dtype, unit_period_dtype +### HELPER FUNCTIONS AND FIXTURES ### def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -335,7 +78,6 @@ def _sorting_violation(): def _sorting_analyzer_violations(): - sorting = _sorting_violation() duration = (sorting.to_spike_vector()["sample_index"][-1] + 1) / sorting.sampling_frequency @@ -352,9 +94,87 @@ def _sorting_analyzer_violations(): return sorting_analyzer -@pytest.fixture(scope="module") -def sorting_analyzer_violations(): - return _sorting_analyzer_violations() +@pytest.fixture(scope="module") +def sorting_analyzer_violations(): + return _sorting_analyzer_violations() + + +def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): + """ + Computes and sets periods for each unit in the sorting analyzer. + The periods span the total duration of the recording, but divide it into + smaller periods either by specifying the number of periods or the size of each bin. 
+ + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer containing the units and recording information. + num_periods : int + The number of periods to divide the total duration into (used if bin_size_s is None). + bin_size_s : float, defaut: None + If given, periods will be multiple of this size in seconds. + + Returns + ------- + periods + np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index. + """ + all_periods = [] + for segment_index in range(sorting_analyzer.recording.get_num_segments()): + samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods + if bin_size_s is not None: + print(f"Original samples_per_period: {samples_per_period} - num_periods: {num_periods}") + bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) + print(samples_per_period / bin_size_samples) + samples_per_period = samples_per_period // bin_size_samples * bin_size_samples + num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) + print(f"Adjusted samples_per_period: {samples_per_period} - num_periods: {num_periods}") + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) + for i, period_start in enumerate(period_starts): + period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) + periods_per_unit[i]["segment_index"] = segment_index + periods_per_unit[i]["start_sample_index"] = period_start + periods_per_unit[i]["end_sample_index"] = period_end + periods_per_unit[i]["unit_index"] = unit_index + print(periods_per_unit, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + all_periods.append(periods_per_unit) + return np.concatenate(all_periods) + + +@pytest.fixture +def periods_simple(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + periods = compute_periods(sorting_analyzer, num_periods=5) + return periods + + +@pytest.fixture +def periods_violations(sorting_analyzer_violations): + sorting_analyzer = sorting_analyzer_violations + periods = compute_periods(sorting_analyzer, num_periods=5) + return periods + + +# Common job kwargs +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +### LOW-LEVEL TESTS ### +def test_noise_cutoff(): + """ + Generate two artifical gaussian, one truncated and one not. Check the metrics are higher for the truncated one. + """ + np.random.seed(1) + amps = np.random.normal(0, 1, 1000) + amps_trunc = amps[amps > -1] + + cutoff1, ratio1 = _noise_cutoff(amps=amps) + cutoff2, ratio2 = _noise_cutoff(amps=amps_trunc) + + assert cutoff1 <= cutoff2 + assert ratio1 <= ratio2 def test_synchrony_counts_no_sync(): @@ -489,22 +309,13 @@ def test_simplified_silhouette_score_metrics(): assert sim_sil_score1 < sim_sil_score2 -# def test_calculate_firing_rate_num_spikes(sorting_analyzer_simple): -# sorting_analyzer = sorting_analyzer_simple -# firing_rates = compute_firing_rates(sorting_analyzer) -# num_spikes = compute_num_spikes(sorting_analyzer) - -# testing method accuracy with magic number is not a good pratcice, I remove this. 
-# firing_rates_gt = {0: 10.01, 1: 5.03, 2: 5.09} -# num_spikes_gt = {0: 1001, 1: 503, 2: 509} -# assert np.allclose(list(firing_rates_gt.values()), list(firing_rates.values()), rtol=0.05) -# np.testing.assert_array_equal(list(num_spikes_gt.values()), list(num_spikes.values())) - - +### TEST METRICS FUNCTIONS ### def test_calculate_firing_range(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - firing_ranges = compute_firing_ranges(sorting_analyzer) - print(firing_ranges) + firing_ranges = compute_firing_ranges(sorting_analyzer, bin_size_s=1) + periods = compute_periods(sorting_analyzer, num_periods=5, bin_size_s=1) + firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) + assert firing_ranges == firing_ranges_periods with pytest.warns(UserWarning) as w: firing_ranges_nan = compute_firing_ranges( @@ -517,6 +328,9 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_cuts = compute_amplitude_cutoffs(sorting_analyzer, num_histogram_bins=10) + periods = compute_periods(sorting_analyzer, num_periods=5) + amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) + assert amp_cuts == amp_cuts_periods # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -528,18 +342,26 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() amp_medians = compute_amplitude_medians(sorting_analyzer) - # print(amp_medians) + periods = compute_periods(sorting_analyzer, num_periods=5) + amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) + assert amp_medians == amp_medians_periods # testing method accuracy with magic number is not a good pratcice, I remove this. 
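# Illustrative sketch: a periods array can also be built by hand, filling the same
# unit_period_dtype fields that compute_periods() sets ("segment_index",
# "start_sample_index", "end_sample_index", "unit_index"). The sample boundaries below
# are made-up values (one segment, 30 kHz), not taken from the fixtures above.
import numpy as np
from spikeinterface.core.base import unit_period_dtype

manual_periods = np.zeros(2, dtype=unit_period_dtype)
for i, (start, end) in enumerate([(0, 15_000), (15_000, 30_000)]):
    manual_periods[i]["segment_index"] = 0
    manual_periods[i]["start_sample_index"] = start
    manual_periods[i]["end_sample_index"] = end
    manual_periods[i]["unit_index"] = 0
# such an array is accepted wherever a metric exposes a `periods` argument, e.g.
# compute_amplitude_medians(sorting_analyzer, periods=manual_periods)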
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) -def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): +def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple amp_cv_median, amp_cv_range = compute_amplitude_cv_metrics(sorting_analyzer, average_num_spikes_per_bin=20) - print(amp_cv_median) - print(amp_cv_range) + periods = periods_simple + amp_cv_median_periods, amp_cv_range_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + ) + assert amp_cv_median == amp_cv_median_periods + assert amp_cv_range == amp_cv_range_periods # amps_scalings = compute_amplitude_scalings(sorting_analyzer) sorting_analyzer.compute("amplitude_scalings", **job_kwargs) @@ -549,34 +371,46 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple): amplitude_extension="amplitude_scalings", min_num_bins=5, ) - print(amp_cv_median_scalings) - print(amp_cv_range_scalings) + amp_cv_median_scalings_periods, amp_cv_range_scalings_periods = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=periods, + average_num_spikes_per_bin=20, + amplitude_extension="amplitude_scalings", + min_num_bins=5, + ) + assert amp_cv_median_scalings == amp_cv_median_scalings_periods + assert amp_cv_range_scalings == amp_cv_range_scalings_periods -def test_calculate_snrs(sorting_analyzer_simple): +def test_calculate_snrs(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple snrs = compute_snrs(sorting_analyzer) - print(snrs) + # SNR doesn't support periods # testing method accuracy with magic number is not a good pratcice, I remove this. # snrs_gt = {0: 12.92, 1: 12.99, 2: 12.99} # assert np.allclose(list(snrs_gt.values()), list(snrs.values()), rtol=0.05) -def test_calculate_presence_ratio(sorting_analyzer_simple): +def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple ratios = compute_presence_ratios(sorting_analyzer, bin_duration_s=10) - print(ratios) - + periods = periods_simple + ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10) + assert ratios == ratios_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0} # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) -def test_calculate_isi_violations(sorting_analyzer_violations): +def test_calculate_isi_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations isi_viol, counts = compute_isi_violations(sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0) - print(isi_viol) + periods = periods_violations + isi_viol_periods, counts_periods = compute_isi_violations( + sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods + ) + assert isi_viol == isi_viol_periods # testing method accuracy with magic number is not a good pratcice, I remove this. 
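# Rough illustration of the quantity being counted here: an ISI violation is an
# inter-spike interval shorter than the threshold. This shows only the counting step,
# not the duration-normalised ratio returned by compute_isi_violations; the spike
# times below are made up.
import numpy as np

spike_times_s = np.array([0.1000, 0.1004, 0.2500, 0.2530, 0.9000])  # seconds
isi_threshold_s = 0.001  # 1 ms, matching isi_threshold_ms=1 above
isis = np.diff(spike_times_s)
violation_count = int(np.sum(isis < isi_threshold_s))  # 1: only the 0.4 ms interval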
# isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754} @@ -585,23 +419,30 @@ def test_calculate_isi_violations(sorting_analyzer_violations): # np.testing.assert_array_equal(list(counts_gt.values()), list(counts.values())) -def test_calculate_sliding_rp_violations(sorting_analyzer_violations): +def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations contaminations = compute_sliding_rp_violations(sorting_analyzer, bin_size_ms=0.25, window_size_s=1) - print(contaminations) + periods = periods_violations + contaminations_periods = compute_sliding_rp_violations( + sorting_analyzer, periods=periods, bin_size_ms=0.25, window_size_s=1 + ) + assert contaminations == contaminations_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) -def test_calculate_rp_violations(sorting_analyzer_violations): +def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations): sorting_analyzer = sorting_analyzer_violations rp_contamination, counts = compute_refrac_period_violations( sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0 ) - print(rp_contamination, counts) - + periods = periods_violations + rp_contamination_periods, counts_periods = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods + ) + assert rp_contamination == rp_contamination_periods # testing method accuracy with magic number is not a good pratcice, I remove this. # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} @@ -620,10 +461,13 @@ def test_calculate_rp_violations(sorting_analyzer_violations): assert np.isnan(rp_contamination[1]) -def test_synchrony_metrics(sorting_analyzer_simple): +def test_synchrony_metrics(sorting_analyzer_simple, periods_simple): sorting_analyzer = sorting_analyzer_simple sorting = sorting_analyzer.sorting synchrony_metrics = compute_synchrony_metrics(sorting_analyzer) + periods = periods_simple + synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods) + assert synchrony_metrics == synchrony_metrics_periods synchrony_sizes = np.array([2, 4, 8]) @@ -679,6 +523,13 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): drifts_ptps, drifts_stds, drift_mads = compute_drift_metrics( sorting_analyzer, interval_s=10, min_spikes_per_interval=10 ) + periods = compute_periods(sorting_analyzer, num_periods=5, bin_size_s=10) + drifts_ptps_periods, drifts_stds_periods, drift_mads_periods = compute_drift_metrics( + sorting_analyzer, periods=periods, min_spikes_per_interval=10, interval_s=10 + ) + assert drifts_ptps == drifts_ptps_periods + assert drifts_stds == drifts_stds_periods + assert drift_mads == drift_mads_periods # print(drifts_ptps, drifts_stds, drift_mads) @@ -691,25 +542,257 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): # assert np.allclose(list(drift_mads_gt.values()), list(drift_mads.values()), rtol=0.05) -def test_calculate_sd_ratio(sorting_analyzer_simple): +def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple): sd_ratio = compute_sd_ratio( sorting_analyzer_simple, ) + periods = periods_simple + sd_ratio_periods = compute_sd_ratio(sorting_analyzer_simple, periods=periods) + assert sd_ratio == sd_ratio_periods 
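# Usage sketch: periods do not have to tile the whole recording. Restricting every unit
# to the first half of segment 0 makes the metric reflect only that portion. Assumes
# `sorting_analyzer` is the violations fixture used in the surrounding test;
# `first_half_periods` is an illustrative name.
import numpy as np
from spikeinterface.core.base import unit_period_dtype

half_samples = sorting_analyzer.get_num_samples(0) // 2
first_half_periods = np.zeros(len(sorting_analyzer.unit_ids), dtype=unit_period_dtype)
for unit_index in range(len(sorting_analyzer.unit_ids)):
    first_half_periods[unit_index]["segment_index"] = 0
    first_half_periods[unit_index]["start_sample_index"] = 0
    first_half_periods[unit_index]["end_sample_index"] = half_samples
    first_half_periods[unit_index]["unit_index"] = unit_index
rp_first_half, counts_first_half = compute_refrac_period_violations(
    sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=first_half_periods
)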
assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) -if __name__ == "__main__": +### MACHINERY TESTS ### +def test_compute_new_quality_metrics(small_sorting_analyzer): + """ + Computes quality metrics then computes a subset of quality metrics, and checks + that the old quality metrics are not deleted. + """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + qm_extension = small_sorting_analyzer.get_extension("quality_metrics") + calculated_metrics = list(qm_extension.get_data().keys()) - sorting_analyzer = _sorting_analyzer_simple() - print(sorting_analyzer) + assert calculated_metrics == ["snr"] - test_unit_structure_in_output(_small_sorting_analyzer()) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) - # test_calculate_firing_rate_num_spikes(sorting_analyzer) + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert set(list(quality_metric_extension.get_data().keys())) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + assert set(list(quality_metric_extension.params.get("metric_names"))) == set( + [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + ) + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" + + +def test_metric_names_in_same_order(small_sorting_analyzer): + """ + Computes sepecified quality metrics and checks order is propagated. + """ + specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"] + small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names) + qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys() + for i in range(3): + assert specified_metric_names[i] == qm_keys[i] + + +def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes quality metrics in binary folder format. Then computes subsets of quality + metrics and checks if they are saved correctly. + """ + + # can't use _misc_metric_name_to_func as some functions compute several qms + # e.g. 
isi_violation and synchrony + quality_metrics = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violations_ratio", + "isi_violations_count", + "rp_contamination", + "rp_violations", + "sliding_rp_violation", + "amplitude_cutoff", + "amplitude_median", + "amplitude_cv_median", + "amplitude_cv_range", + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + "firing_range", + "drift_ptp", + "drift_std", + "drift_mad", + "sd_ratio", + "isolation_distance", + "l_ratio", + "d_prime", + "silhouette", + "nn_hit_rate", + "nn_miss_rate", + ] + + small_sorting_analyzer.compute("quality_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + +def test_unit_structure_in_output(small_sorting_analyzer): + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5, "min_fraction_valid_intervals": 0.2}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, + } + + for metric in misc_metrics_list: + metric_name = metric.metric_name + metric_fun = metric.metric_function + try: + qm_param = qm_params[metric_name] + except: + qm_param = {} + + result_all = metric_fun(sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = metric_fun(sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param) + + error = "Problem with metric: " + metric_name + + if isinstance(result_all, dict): + assert list(result_all.keys()) == ["#3", "#9", "#4"], error + assert list(result_sub.keys()) == ["#4", "#9"], error + assert result_sub["#9"] == result_all["#9"], error + assert result_sub["#4"] == result_all["#4"], error + + else: + for result_ind, result in enumerate(result_sub): + + assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"], error + assert result_sub[result_ind].keys() == set(["#4", "#9"]), error + + assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"], error + assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"], error + + +def test_unit_id_order_independence(small_sorting_analyzer): + """ + Takes two almost-identical sorting_analyzers, 
whose unit_ids are in different orders and have different labels, + and checks that their calculated quality metrics are independent of the ordering and labelling. + """ + + recording = small_sorting_analyzer.recording + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [1, 7, 2]) + + small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer_2.compute(extensions_to_compute) + + # need special params to get non-nan results on a short recording + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + } + quality_metrics_1 = compute_quality_metrics( + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + quality_metrics_2 = compute_quality_metrics( + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params, skip_pc_metrics=True + ) + + for metric, metric_2_data in quality_metrics_2.items(): + error = "Problem with the metric " + metric + assert quality_metrics_1[metric]["#3"] == metric_2_data[2], error + assert quality_metrics_1[metric]["#9"] == metric_2_data[7], error + assert quality_metrics_1[metric]["#4"] == metric_2_data[1], error + + +if __name__ == "__main__": + pass + # sorting_analyzer = _sorting_analyzer_simple() + # print(sorting_analyzer) + # test_unit_structure_in_output(_small_sorting_analyzer()) + # test_calculate_firing_rate_num_spikes(sorting_analyzer) # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) From 0fe7f3e7778a826562dd2d3a9667684a49861cb1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 17:38:18 +0100 Subject: [PATCH 54/70] Fix periods arg in MetricExtensions --- src/spikeinterface/metrics/quality/quality_metrics.py | 2 ++ src/spikeinterface/metrics/template/template_metrics.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index 8e96f4dcaf..e5cc2aa323 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -70,6 +70,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods=None, # common extension kwargs peak_sign=None, seed=None, @@ -90,6 +91,7 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + periods=periods, peak_sign=peak_sign, seed=seed, skip_pc_metrics=skip_pc_metrics, diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index e27f16963d..85ef9e22cb 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ 
b/src/spikeinterface/metrics/template/template_metrics.py @@ -131,6 +131,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + periods=None, # common extension kwargs peak_sign="neg", upsampling_factor=10, @@ -160,6 +161,7 @@ def _set_params( metric_params=metric_params, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, + periods=periods, # template metrics do not use periods peak_sign=peak_sign, upsampling_factor=upsampling_factor, include_multi_channel_metrics=include_multi_channel_metrics, From f087e08ec7b5df7b7fdb122d78e6a4b90b2d116c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Jan 2026 17:57:57 +0100 Subject: [PATCH 55/70] Make bin edges unique --- .../metrics/quality/tests/test_metrics_functions.py | 3 --- src/spikeinterface/metrics/utils.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 2e31c53135..f29e72d153 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -123,12 +123,10 @@ def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): for segment_index in range(sorting_analyzer.recording.get_num_segments()): samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods if bin_size_s is not None: - print(f"Original samples_per_period: {samples_per_period} - num_periods: {num_periods}") bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) print(samples_per_period / bin_size_samples) samples_per_period = samples_per_period // bin_size_samples * bin_size_samples num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) - print(f"Adjusted samples_per_period: {samples_per_period} - num_periods: {num_periods}") for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) @@ -138,7 +136,6 @@ def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): periods_per_unit[i]["start_sample_index"] = period_start periods_per_unit[i]["end_sample_index"] = period_end periods_per_unit[i]["unit_index"] = unit_index - print(periods_per_unit, sorting_analyzer.get_num_samples(segment_index), samples_per_period) all_periods.append(periods_per_unit) return np.concatenate(all_periods) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 83ddfcf90b..00db100c1f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -37,7 +37,7 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per end_sample = seg_start + period["end_sample_index"] end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.array(bin_edges) + bin_edges_for_units[unit_id] = np.unique(np.array(bin_edges)) else: for unit_id in sorting.unit_ids: bin_edges = [] From 173e7473034089ed27b080958511d939415b5533 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 10:05:28 +0100 Subject: [PATCH 56/70] Add 
support_periods to spike train metrics and tests --- src/spikeinterface/metrics/conftest.py | 81 +++++++++++++++++- .../metrics/quality/tests/conftest.py | 85 ------------------- .../quality/tests/test_metrics_functions.py | 43 +--------- .../metrics/spiketrain/metrics.py | 4 +- .../spiketrain/tests/test_metric_functions.py | 33 +++++++ src/spikeinterface/metrics/utils.py | 42 +++++++++ 6 files changed, 158 insertions(+), 130 deletions(-) delete mode 100644 src/spikeinterface/metrics/quality/tests/conftest.py create mode 100644 src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py diff --git a/src/spikeinterface/metrics/conftest.py b/src/spikeinterface/metrics/conftest.py index 8d32c103fa..5313e763c1 100644 --- a/src/spikeinterface/metrics/conftest.py +++ b/src/spikeinterface/metrics/conftest.py @@ -1,8 +1,85 @@ import pytest -from spikeinterface.postprocessing.tests.conftest import _small_sorting_analyzer +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + + +def make_small_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[10.0], + num_units=10, + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer @pytest.fixture(scope="module") def small_sorting_analyzer(): - return _small_sorting_analyzer() + return make_small_analyzer() + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + channel_ids_as_integers = [id for id in range(recording.get_num_channels())] + unit_ids_as_integers = [id for id in range(sorting.get_num_units())] + recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) + sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) + + return sorting_analyzer diff --git 
a/src/spikeinterface/metrics/quality/tests/conftest.py b/src/spikeinterface/metrics/quality/tests/conftest.py deleted file mode 100644 index 5313e763c1..0000000000 --- a/src/spikeinterface/metrics/quality/tests/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytest - -from spikeinterface.core import ( - generate_ground_truth_recording, - create_sorting_analyzer, -) - -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - - -def make_small_analyzer(): - recording, sorting = generate_ground_truth_recording( - durations=[10.0], - num_units=10, - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting = sorting.select_units([2, 7, 0], ["#3", "#9", "#4"]) - - sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") - - extensions_to_compute = { - "random_spikes": {"seed": 1205}, - "noise_levels": {"seed": 1205}, - "waveforms": {}, - "templates": {"operators": ["average", "median"]}, - "spike_amplitudes": {}, - "spike_locations": {}, - "principal_components": {}, - } - - sorting_analyzer.compute(extensions_to_compute) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def small_sorting_analyzer(): - return make_small_analyzer() - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=1205, - ) - - channel_ids_as_integers = [id for id in range(recording.get_num_channels())] - unit_ids_as_integers = [id for id in range(sorting.get_num_units())] - recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers) - sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute(["spike_amplitudes", "spike_locations"], **job_kwargs) - - return sorting_analyzer diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index f29e72d153..0356e24ed0 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -11,7 +11,7 @@ synthesize_random_firings, ) -from spikeinterface.metrics.utils import create_ground_truth_pc_distributions +from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, compute_periods from spikeinterface.metrics.quality import ( get_quality_metric_list, @@ -99,47 +99,6 @@ def sorting_analyzer_violations(): return _sorting_analyzer_violations() -def 
compute_periods(sorting_analyzer, num_periods, bin_size_s=None): - """ - Computes and sets periods for each unit in the sorting analyzer. - The periods span the total duration of the recording, but divide it into - smaller periods either by specifying the number of periods or the size of each bin. - - Parameters - ---------- - sorting_analyzer : SortingAnalyzer - The sorting analyzer containing the units and recording information. - num_periods : int - The number of periods to divide the total duration into (used if bin_size_s is None). - bin_size_s : float, defaut: None - If given, periods will be multiple of this size in seconds. - - Returns - ------- - periods - np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index. - """ - all_periods = [] - for segment_index in range(sorting_analyzer.recording.get_num_segments()): - samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods - if bin_size_s is not None: - bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) - print(samples_per_period / bin_size_samples) - samples_per_period = samples_per_period // bin_size_samples * bin_size_samples - num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) - for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): - period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) - periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) - for i, period_start in enumerate(period_starts): - period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) - periods_per_unit[i]["segment_index"] = segment_index - periods_per_unit[i]["start_sample_index"] = period_start - periods_per_unit[i]["end_sample_index"] = period_end - periods_per_unit[i]["unit_index"] = unit_index - all_periods.append(periods_per_unit) - return np.concatenate(all_periods) - - @pytest.fixture def periods_simple(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple diff --git a/src/spikeinterface/metrics/spiketrain/metrics.py b/src/spikeinterface/metrics/spiketrain/metrics.py index 669733f47a..652be32955 100644 --- a/src/spikeinterface/metrics/spiketrain/metrics.py +++ b/src/spikeinterface/metrics/spiketrain/metrics.py @@ -40,6 +40,7 @@ class NumSpikes(BaseMetric): metric_params = {} metric_descriptions = {"num_spikes": "Total number of spikes for each unit across all segments."} metric_columns = {"num_spikes": int} + supports_periods = True def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): @@ -68,7 +69,7 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None, periods=None): total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) firing_rates = {} - num_spikes = compute_num_spikes(sorting_analyzer, unit_ids=unit_ids) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) for unit_id in unit_ids: if num_spikes[unit_id] == 0: firing_rates[unit_id] = np.nan @@ -83,6 +84,7 @@ class FiringRate(BaseMetric): metric_params = {} metric_descriptions = {"firing_rate": "Firing rate (spikes per second) for each unit across all segments."} metric_columns = {"firing_rate": float} + supports_periods = True spiketrain_metrics = [NumSpikes, FiringRate] diff --git a/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py new file mode 100644 index 
0000000000..86a5e9db2d --- /dev/null +++ b/src/spikeinterface/metrics/spiketrain/tests/test_metric_functions.py @@ -0,0 +1,33 @@ +import numpy as np + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.metrics.utils import compute_periods +from spikeinterface.metrics.spiketrain.metrics import compute_num_spikes, compute_firing_rates + + +def test_calculate_num_spikes(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + num_spikes = compute_num_spikes(sorting_analyzer) + periods = compute_periods(sorting_analyzer, num_periods=5) + num_spikes_periods = compute_num_spikes(sorting_analyzer, periods=periods) + assert num_spikes == num_spikes_periods + + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + num_spikes_empty_periods = compute_num_spikes(sorting_analyzer, periods=empty_periods) + assert num_spikes_empty_periods == {unit_id: 0 for unit_id in sorting_analyzer.sorting.unit_ids} + + +def test_calculate_firing_rates(sorting_analyzer_simple): + sorting_analyzer = sorting_analyzer_simple + # spike_amps = sorting_analyzer.get_extension("spike_amplitudes").get_data() + firing_rates = compute_firing_rates(sorting_analyzer) + periods = compute_periods(sorting_analyzer, num_periods=5) + firing_rates_periods = compute_firing_rates(sorting_analyzer, periods=periods) + assert firing_rates == firing_rates_periods + + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + firing_rates_empty_periods = compute_firing_rates(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(firing_rates_empty_periods.values())))) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index 00db100c1f..e007b19c05 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from spikeinterface.core.base import unit_period_dtype def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): @@ -108,6 +109,47 @@ def compute_total_durations_per_unit(sorting_analyzer, periods=None): return total_durations +def compute_periods(sorting_analyzer, num_periods, bin_size_s=None): + """ + Computes and sets periods for each unit in the sorting analyzer. + The periods span the total duration of the recording, but divide it into + smaller periods either by specifying the number of periods or the size of each bin. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The sorting analyzer containing the units and recording information. + num_periods : int + The number of periods to divide the total duration into (used if bin_size_s is None). + bin_size_s : float, defaut: None + If given, periods will be multiple of this size in seconds. + + Returns + ------- + periods + np.ndarray of dtype unit_period_dtype containing the segment, start, end samples and unit index. 
+ """ + all_periods = [] + for segment_index in range(sorting_analyzer.recording.get_num_segments()): + samples_per_period = sorting_analyzer.get_num_samples(segment_index) // num_periods + if bin_size_s is not None: + bin_size_samples = int(bin_size_s * sorting_analyzer.sampling_frequency) + print(samples_per_period / bin_size_samples) + samples_per_period = samples_per_period // bin_size_samples * bin_size_samples + num_periods = int(np.round(sorting_analyzer.get_num_samples(segment_index) / samples_per_period)) + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + period_starts = np.arange(0, sorting_analyzer.get_num_samples(segment_index), samples_per_period) + periods_per_unit = np.zeros(len(period_starts), dtype=unit_period_dtype) + for i, period_start in enumerate(period_starts): + period_end = min(period_start + samples_per_period, sorting_analyzer.get_num_samples(segment_index)) + periods_per_unit[i]["segment_index"] = segment_index + periods_per_unit[i]["start_sample_index"] = period_start + periods_per_unit[i]["end_sample_index"] = period_end + periods_per_unit[i]["unit_index"] = unit_index + all_periods.append(periods_per_unit) + return np.concatenate(all_periods) + + def create_ground_truth_pc_distributions(center_locations, total_points): """ Simulate PCs as multivariate Gaussians, for testing PC-based quality metrics From 066c3787171c46ef2bd39727bcbede3e015ec86d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:11:30 +0100 Subject: [PATCH 57/70] Force NaN/-1 values for float/int metrics if num_spikes is 0 --- .../metrics/quality/misc_metrics.py | 44 +++++++++-- .../quality/tests/test_metrics_functions.py | 73 ++++++++++++++++++- 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 04d451202d..7bf7ff0f86 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -75,6 +75,7 @@ def compute_presence_ratios( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids num_segs = sorting_analyzer.get_num_segments() + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(num_segs)] total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) @@ -104,6 +105,9 @@ def compute_presence_ratios( else: for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + presence_ratios[unit_id] = np.nan + continue spike_train = [] bin_edges = bin_edges_per_unit[unit_id] if len(bin_edges) < 2: @@ -264,6 +268,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th unit_ids = sorting_analyzer.unit_ids total_durations = compute_total_durations_per_unit(sorting_analyzer, periods=periods) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) fs = sorting_analyzer.sampling_frequency isi_threshold_s = isi_threshold_ms / 1000 @@ -273,15 +278,17 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, periods=None, isi_th isi_violations_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + isi_violations_ratio[unit_id] = np.nan + isi_violations_count[unit_id] = -1 + continue + spike_train_list = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) if len(spike_train) > 0: spike_train_list.append(spike_train / fs) - if 
not any([len(train) > 0 for train in spike_train_list]): - continue - total_duration = total_durations[unit_id] ratio, _, count = isi_violations(spike_train_list, total_duration, isi_threshold_s, min_isi_s) @@ -359,7 +366,7 @@ def compute_refrac_period_violations( if not HAVE_NUMBA: warnings.warn("Error: numba is not installed.") warnings.warn("compute_refrac_period_violations cannot run without numba.") - return {unit_id: np.nan for unit_id in unit_ids} + return res({unit_id: np.nan for unit_id in unit_ids}, {unit_id: 0 for unit_id in unit_ids}) num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) @@ -372,6 +379,11 @@ def compute_refrac_period_violations( nb_violations = {} rp_contamination = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + rp_contamination[unit_id] = np.nan + nb_violations[unit_id] = -1 + continue + nb_violations[unit_id] = 0 total_samples_unit = total_samples[unit_id] @@ -556,7 +568,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn if unit_ids is None: unit_ids = sorting.unit_ids - spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) spikes = sorting.to_spike_vector() all_unit_ids = sorting.unit_ids @@ -569,10 +581,10 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, periods=None, syn for i, unit_id in enumerate(all_unit_ids): if unit_id not in unit_ids: continue - if spike_counts[unit_id] != 0: - sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / spike_counts[unit_id] + if num_spikes[unit_id] != 0: + sync_id_metrics_dict[unit_id] = synchrony_counts[sync_idx][i] / num_spikes[unit_id] else: - sync_id_metrics_dict[unit_id] = 0 + sync_id_metrics_dict[unit_id] = -1 synchrony_metrics_dict[f"sync_spike_{synchrony_size}"] = sync_id_metrics_dict return res(**synchrony_metrics_dict) @@ -629,6 +641,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz if unit_ids is None: unit_ids = sorting.unit_ids + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + if all( [ sorting_analyzer.get_num_samples(segment_index) < bin_size_samples @@ -648,6 +662,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz ) cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + continue bin_edges = bin_edges_per_unit[unit_id] # we can concatenate spike trains across segments adding the cumulative number of samples @@ -665,6 +681,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + firing_ranges[unit_id] = np.nan + continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( firing_rate_histograms[unit_id], percentiles[0] ) @@ -748,6 +767,10 @@ def compute_amplitude_cv_metrics( amplitude_cv_medians, amplitude_cv_ranges = {}, {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + amplitude_cv_medians[unit_id] = np.nan + amplitude_cv_ranges[unit_id] = np.nan + continue total_duration = total_durations[unit_id] firing_rate = num_spikes[unit_id] / total_duration temporal_bin_size_samples = int( @@ -1267,6 +1290,8 @@ def compute_sd_ratio( if unit_ids is None: unit_ids = sorting_analyzer.unit_ids + num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) + if not 
sorting_analyzer.has_recording(): warnings.warn( "The `sd_ratio` metric cannot work with a recordless SortingAnalyzer object" @@ -1297,6 +1322,9 @@ def compute_sd_ratio( sd_ratio = {} for unit_id in unit_ids: + if num_spikes[unit_id] == 0: + sd_ratio[unit_id] = np.nan + continue spk_amp = [] for segment_index in range(sorting_analyzer.get_num_segments()): spike_train = sorting.get_unit_spike_train(unit_id, segment_index) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index 0356e24ed0..c13f1ffbaa 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -273,6 +273,10 @@ def test_calculate_firing_range(sorting_analyzer_simple): firing_ranges_periods = compute_firing_ranges(sorting_analyzer, periods=periods, bin_size_s=1) assert firing_ranges == firing_ranges_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + firing_ranges_empty = compute_firing_ranges(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(firing_ranges_empty.values())))) + with pytest.warns(UserWarning) as w: firing_ranges_nan = compute_firing_ranges( sorting_analyzer, bin_size_s=sorting_analyzer.get_total_duration() + 1 @@ -287,6 +291,10 @@ def test_calculate_amplitude_cutoff(sorting_analyzer_simple): periods = compute_periods(sorting_analyzer, num_periods=5) amp_cuts_periods = compute_amplitude_cutoffs(sorting_analyzer, periods=periods, num_histogram_bins=10) assert amp_cuts == amp_cuts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cuts_empty = compute_amplitude_cutoffs(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_cuts_empty.values())))) # print(amp_cuts) # testing method accuracy with magic number is not a good pratcice, I remove this. @@ -302,6 +310,10 @@ def test_calculate_amplitude_median(sorting_analyzer_simple): amp_medians_periods = compute_amplitude_medians(sorting_analyzer, periods=periods) assert amp_medians == amp_medians_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_medians_empty = compute_amplitude_medians(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(amp_medians_empty.values())))) + # testing method accuracy with magic number is not a good pratcice, I remove this. 
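# Summary sketch of the fallback convention these empty-period checks exercise: with no
# spikes selected, float-valued metrics are expected to come back as NaN and
# integer-valued counts as -1. `expected_empty_value` is a hypothetical helper name,
# used only for this illustration.
import numpy as np

def expected_empty_value(metric_dtype):
    return np.nan if np.issubdtype(metric_dtype, np.floating) else -1

assert np.isnan(expected_empty_value(np.float64))
assert expected_empty_value(np.int64) == -1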
# amp_medians_gt = {0: 130.77323354628675, 1: 130.7461997791725, 2: 130.7461997791725} # assert np.allclose(list(amp_medians_gt.values()), list(amp_medians.values()), rtol=0.05) @@ -319,6 +331,15 @@ def test_calculate_amplitude_cv_metrics(sorting_analyzer_simple, periods_simple) assert amp_cv_median == amp_cv_median_periods assert amp_cv_range == amp_cv_range_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + amp_cv_median_empty, amp_cv_range_empty = compute_amplitude_cv_metrics( + sorting_analyzer, + periods=empty_periods, + average_num_spikes_per_bin=20, + ) + assert np.all(np.isnan(np.array(list(amp_cv_median_empty.values())))) + assert np.all(np.isnan(np.array(list(amp_cv_range_empty.values())))) + # amps_scalings = compute_amplitude_scalings(sorting_analyzer) sorting_analyzer.compute("amplitude_scalings", **job_kwargs) amp_cv_median_scalings, amp_cv_range_scalings = compute_amplitude_cv_metrics( @@ -354,6 +375,10 @@ def test_calculate_presence_ratio(sorting_analyzer_simple, periods_simple): periods = periods_simple ratios_periods = compute_presence_ratios(sorting_analyzer, periods=periods, bin_duration_s=10) assert ratios == ratios_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + ratios_periods_empty = compute_presence_ratios(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(ratios_periods_empty.values())))) # testing method accuracy with magic number is not a good pratcice, I remove this. # ratios_gt = {0: 1.0, 1: 1.0, 2: 1.0} # np.testing.assert_array_equal(list(ratios_gt.values()), list(ratios.values())) @@ -367,6 +392,12 @@ def test_calculate_isi_violations(sorting_analyzer_violations, periods_violation sorting_analyzer, isi_threshold_ms=1, min_isi_ms=0.0, periods=periods ) assert isi_viol == isi_viol_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + isi_viol_empty, isi_counts_empty = compute_isi_violations(sorting_analyzer, periods=empty_periods) + assert np.all(np.isnan(np.array(list(isi_viol_empty.values())))) + assert np.array_equal(np.array(list(isi_counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) # testing method accuracy with magic number is not a good pratcice, I remove this. # isi_viol_gt = {0: 0.0998002996004994, 1: 0.7904857139469347, 2: 1.929898371551754} @@ -384,6 +415,12 @@ def test_calculate_sliding_rp_violations(sorting_analyzer_violations, periods_vi ) assert contaminations == contaminations_periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + contaminations_periods_empty = compute_sliding_rp_violations( + sorting_analyzer, periods=empty_periods, bin_size_ms=0.25, window_size_s=1 + ) + assert np.all(np.isnan(np.array(list(contaminations_periods_empty.values())))) + # testing method accuracy with magic number is not a good pratcice, I remove this. 
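# Minimal check of the invariant behind these equality assertions: when per-unit periods
# tile the recording without gaps (half-open [start, end) spans assumed), selecting
# spikes period by period recovers every spike exactly once, so the periods-based result
# matches the whole-recording result. Toy sample indices only.
import numpy as np

spike_samples = np.array([3, 10, 27, 31, 59])
edges = np.array([0, 20, 40, 60])  # three contiguous periods
selected = np.concatenate(
    [spike_samples[(spike_samples >= start) & (spike_samples < end)] for start, end in zip(edges[:-1], edges[1:])]
)
assert np.array_equal(selected, spike_samples)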
# contaminations_gt = {0: 0.03, 1: 0.185, 2: 0.325} # assert np.allclose(list(contaminations_gt.values()), list(contaminations.values()), rtol=0.05) @@ -399,6 +436,15 @@ def test_calculate_rp_violations(sorting_analyzer_violations, periods_violations sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=periods ) assert rp_contamination == rp_contamination_periods + assert counts == counts_periods + + empty_periods = np.empty(0, dtype=unit_period_dtype) + rp_contamination_empty, counts_empty = compute_refrac_period_violations( + sorting_analyzer, refractory_period_ms=1, censored_period_ms=0.0, periods=empty_periods + ) + assert np.all(np.isnan(np.array(list(rp_contamination_empty.values())))) + assert np.array_equal(np.array(list(counts_empty.values())), -1 * np.ones(len(sorting_analyzer.unit_ids))) + # testing method accuracy with magic number is not a good pratcice, I remove this. # counts_gt = {0: 2, 1: 4, 2: 10} # rp_contamination_gt = {0: 0.10534956502609294, 1: 1.0, 2: 1.0} @@ -425,8 +471,19 @@ def test_synchrony_metrics(sorting_analyzer_simple, periods_simple): synchrony_metrics_periods = compute_synchrony_metrics(sorting_analyzer, periods=periods) assert synchrony_metrics == synchrony_metrics_periods - synchrony_sizes = np.array([2, 4, 8]) + empty_periods = np.empty(0, dtype=unit_period_dtype) + synchrony_metrics_empty = compute_synchrony_metrics(sorting_analyzer, periods=empty_periods) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_2.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_4.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + assert np.array_equal( + np.array(list(synchrony_metrics_empty.sync_spike_8.values())), -1 * np.ones(len(sorting_analyzer.unit_ids)) + ) + synchrony_sizes = np.array([2, 4, 8]) # check returns for size in synchrony_sizes: assert f"sync_spike_{size}" in synchrony_metrics._fields @@ -487,6 +544,15 @@ def test_calculate_drift_metrics(sorting_analyzer_simple): assert drifts_stds == drifts_stds_periods assert drift_mads == drift_mads_periods + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + drifts_ptps_empty, drifts_stds_empty, drift_mads_empty = compute_drift_metrics( + sorting_analyzer_simple, periods=empty_periods + ) + assert np.all(np.isnan(np.array(list(drifts_ptps_empty.values())))) + assert np.all(np.isnan(np.array(list(drifts_stds_empty.values())))) + assert np.all(np.isnan(np.array(list(drift_mads_empty.values())))) + # print(drifts_ptps, drifts_stds, drift_mads) # testing method accuracy with magic number is not a good pratcice, I remove this. 
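# Sketch of the per-column check used for empty units later in this series
# (test_empty_units): every value stored for an empty unit must be NaN, 0, or -1.
# `metrics_df` and `empty_unit_ids` below are toy placeholders for the quality-metrics
# DataFrame and the ids of units without spikes.
import numpy as np
import pandas as pd

metrics_df = pd.DataFrame(
    {"firing_rate": [5.2, np.nan], "rp_violations": [12, -1]},
    index=["u0", "u1"],
)
empty_unit_ids = ["u1"]
for col in metrics_df.columns:
    values = metrics_df.loc[empty_unit_ids, col].values
    assert np.all(pd.isnull(values)) or np.all(values == 0) or np.all(values == -1), col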
@@ -507,6 +573,11 @@ def test_calculate_sd_ratio(sorting_analyzer_simple, periods_simple): assert sd_ratio == sd_ratio_periods assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) + + # calculate num spikes with empty periods + empty_periods = np.empty(0, dtype=unit_period_dtype) + sd_ratios_empty_periods = compute_sd_ratio(sorting_analyzer_simple, periods=empty_periods) + assert np.all(np.isnan(np.array(list(sd_ratios_empty_periods.values())))) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) From 65e18488860ce63d8766834bec4efc09602f6df7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:27:52 +0100 Subject: [PATCH 58/70] Fix test_empty_units: -1 is a valid value for ints --- .../metrics/quality/tests/test_quality_metric_calculator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index ec72fdc178..2e87002018 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -168,7 +168,8 @@ def test_empty_units(sorting_analyzer_simple): for col in metrics_empty.columns: all_nans = np.all(isnull(metrics_empty.loc[empty_unit_ids, col].values)) all_zeros = np.all(metrics_empty.loc[empty_unit_ids, col].values == 0) - assert all_nans or all_zeros + all_neg_ones = np.all(metrics_empty.loc[empty_unit_ids, col].values == -1) + assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" if __name__ == "__main__": From f1c46828a5cf40944c2d0f8800bba6fdee484197 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:39:45 +0100 Subject: [PATCH 59/70] Fix firing range if unit samples < bin samples --- src/spikeinterface/metrics/quality/misc_metrics.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 7bf7ff0f86..4556fcf09a 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -642,15 +642,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz unit_ids = sorting.unit_ids num_spikes = sorting.count_num_spikes_per_unit(unit_ids=unit_ids) - - if all( - [ - sorting_analyzer.get_num_samples(segment_index) < bin_size_samples - for segment_index in range(sorting_analyzer.get_num_segments()) - ] - ): - warnings.warn(f"Bin size of {bin_size_s}s is larger than each segment duration. 
Firing ranges are set to NaN.") - return {unit_id: np.nan for unit_id in unit_ids} + total_samples = compute_total_samples_per_unit(sorting_analyzer, periods=periods) # for each segment, we compute the firing rate histogram and we concatenate them firing_rate_histograms = {unit_id: np.array([], dtype=float) for unit_id in unit_ids} @@ -662,7 +654,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz ) cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for unit_id in unit_ids: - if num_spikes[unit_id] == 0: + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: continue bin_edges = bin_edges_per_unit[unit_id] @@ -681,7 +673,7 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} for unit_id in unit_ids: - if num_spikes[unit_id] == 0: + if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: firing_ranges[unit_id] = np.nan continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( From 32916382980290b6b75a99c0e9c0f524486af92b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 12:51:19 +0100 Subject: [PATCH 60/70] fix noise_cutoff if empty units --- src/spikeinterface/metrics/quality/misc_metrics.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 4556fcf09a..83ac82bd73 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -1018,6 +1018,10 @@ def compute_noise_cutoffs( for unit_id in unit_ids: amplitudes = amplitudes_by_units[unit_id] + if len(amplitudes) == 0: + cutoff, ratio = np.nan, np.nan + continue + if invert_amplitudes: amplitudes = -amplitudes From b5bf3c3f03fde75a007133ff356dd6a278a6f34c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 21 Jan 2026 13:10:56 +0100 Subject: [PATCH 61/70] Move warnings at the end of the loop for firing range and drift --- .../metrics/quality/misc_metrics.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index 83ac82bd73..ab8dc670d2 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -672,13 +672,20 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, periods=None, bin_siz # finally we compute the percentiles firing_ranges = {} + failed_units = [] for unit_id in unit_ids: if num_spikes[unit_id] == 0 or total_samples[unit_id] < bin_size_samples: + failed_units.append(unit_id) firing_ranges[unit_id] = np.nan continue firing_ranges[unit_id] = np.percentile(firing_rate_histograms[unit_id], percentiles[1]) - np.percentile( firing_rate_histograms[unit_id], percentiles[0] ) + if len(failed_units) > 0: + warnings.warn( + f"Firing range could not be computed for units {failed_units} " + f"because they have no spikes or the total duration is less than bin size." 
+ ) return firing_ranges @@ -1156,18 +1163,16 @@ def compute_drift_metrics( bin_duration_s=interval_s, ) + failed_units = [] median_positions_per_unit = {} for unit_id in unit_ids: bins = bin_edges_for_units[unit_id] num_bins = len(bins) - 1 if num_bins < min_num_bins: - warnings.warn( - f"Unit {unit_id} has only {num_bins} bins given the specified 'interval_s' and " - f"'min_num_bins'. Drift metrics will be set to NaN" - ) drift_ptps[unit_id] = np.nan drift_stds[unit_id] = np.nan drift_mads[unit_id] = np.nan + failed_units.append(unit_id) continue # bin_edges are global across segments, so we have to use spike_sample_indices, @@ -1191,6 +1196,7 @@ def compute_drift_metrics( if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): + failed_units.append(unit_id) ptp_drift = np.nan std_drift = np.nan mad_drift = np.nan @@ -1206,6 +1212,12 @@ def compute_drift_metrics( drift_stds[unit_id] = std_drift drift_mads[unit_id] = mad_drift + if len(failed_units) > 0: + warnings.warn( + f"Drift metrics could not be computed for units {failed_units} because they have less than " + f"{min_num_bins} bins given the specified 'interval_s' and 'min_num_bins' or not enough valid intervals." + ) + if return_positions: outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit else: From 8aeedccf3fe36808f0ec145ca39b62bbe24e7855 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 09:55:49 +0100 Subject: [PATCH 62/70] clean up tests and add get_available_metric_names --- .../core/analyzer_extension_core.py | 11 +++++ .../quality/tests/test_metrics_functions.py | 45 +++---------------- 2 files changed, 18 insertions(+), 38 deletions(-) diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index a21404e58f..30038bc270 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -901,6 +901,17 @@ class BaseMetricExtension(AnalyzerExtension): need_backward_compatibility_on_load = False metric_list: list[BaseMetric] = None # list of BaseMetric + @classmethod + def get_available_metric_names(cls): + """Get the available metric names. + + Returns + ------- + available_metric_names : list[str] + List of available metric names. + """ + return [m.metric_name for m in cls.metric_list] + @classmethod def get_default_metric_params(cls): """Get the default metric parameters. diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c13f1ffbaa..61f014c289 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -13,10 +13,7 @@ from spikeinterface.metrics.utils import create_ground_truth_pc_distributions, compute_periods -from spikeinterface.metrics.quality import ( - get_quality_metric_list, - compute_quality_metrics, -) +from spikeinterface.metrics.quality import get_quality_metric_list, compute_quality_metrics, ComputeQualityMetrics from spikeinterface.metrics.quality.misc_metrics import ( misc_metrics_list, compute_amplitude_cutoffs, @@ -657,37 +654,9 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): # can't use _misc_metric_name_to_func as some functions compute several qms # e.g. 
isi_violation and synchrony - quality_metrics = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violations_ratio", - "isi_violations_count", - "rp_contamination", - "rp_violations", - "sliding_rp_violation", - "amplitude_cutoff", - "amplitude_median", - "amplitude_cv_median", - "amplitude_cv_range", - "sync_spike_2", - "sync_spike_4", - "sync_spike_8", - "firing_range", - "drift_ptp", - "drift_std", - "drift_mad", - "sd_ratio", - "isolation_distance", - "l_ratio", - "d_prime", - "silhouette", - "nn_hit_rate", - "nn_miss_rate", - ] - - small_sorting_analyzer.compute("quality_metrics") + quality_metric_columns = ComputeQualityMetrics.get_metric_columns() + all_metrics = ComputeQualityMetrics.get_available_metric_names() + small_sorting_analyzer.compute("quality_metrics", metric_names=all_metrics) cache_folder = create_cache_folder output_folder = cache_folder / "sorting_analyzer" @@ -699,7 +668,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: assert metric_name in metric_names folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) @@ -708,7 +677,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: assert metric_name in metric_names folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) @@ -717,7 +686,7 @@ def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in quality_metrics: + for metric_name in quality_metric_columns: if metric_name == "snr": assert metric_name in metric_names else: From d4db43cab085ef362c6200f35f10914ff5273019 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 09:57:34 +0100 Subject: [PATCH 63/70] simplify total samples --- src/spikeinterface/metrics/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index e007b19c05..dea652985f 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -82,7 +82,8 @@ def compute_total_samples_per_unit(sorting_analyzer, periods=None): num_samples_in_period += period["end_sample_index"] - period["start_sample_index"] total_samples[unit_id] = num_samples_in_period else: - total_samples = {unit_id: sorting_analyzer.get_total_samples() for unit_id in sorting_analyzer.unit_ids} + total = sorting_analyzer.get_total_samples() + total_samples = {unit_id: total for unit_id in sorting_analyzer.unit_ids} return total_samples From d0a1e66c68127e41875ecb7b9d90d4dabf95be99 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 10:24:01 +0100 Subject: [PATCH 64/70] Go back to Pierre's implementation for drifts --- .../metrics/quality/misc_metrics.py | 78 ++++++++----------- src/spikeinterface/metrics/utils.py | 37 +++++++-- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index ab8dc670d2..198e98037c 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ 
b/src/spikeinterface/metrics/quality/misc_metrics.py @@ -1129,13 +1129,20 @@ def compute_drift_metrics( unit_ids = sorting.unit_ids spike_locations_ext = sorting_analyzer.get_extension("spike_locations") - spike_locations_array = spike_locations_ext.get_data(periods=periods) + spike_locations_by_unit_and_segments = spike_locations_ext.get_data( + outputs="by_unit", concatenated=False, periods=periods + ) spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit", concatenated=True, periods=periods) segment_samples = [sorting_analyzer.get_num_samples(i) for i in range(sorting_analyzer.get_num_segments())] - assert direction in spike_locations_array.dtype.names, ( - f"Direction {direction} is invalid. Available directions: " f"{spike_locations_array.dtype.names}" + data = spike_locations_by_unit[unit_ids[0]] + assert direction in data.dtype.names, ( + f"Direction {direction} is invalid. Available directions: " f"{data.dtype.names}" + ) + bin_edges_for_units = compute_bin_edges_per_unit( + sorting, segment_samples=segment_samples, periods=periods, bin_duration_s=interval_s, concatenated=False ) + failed_units = [] # we need drift_ptps = {} @@ -1144,62 +1151,43 @@ def compute_drift_metrics( # reference positions are the medians across segments reference_positions = {} + median_position_segments = {unit_id: np.array([]) for unit_id in unit_ids} + for unit_id in unit_ids: reference_positions[unit_id] = np.median(spike_locations_by_unit[unit_id][direction]) - # now compute median positions and concatenate them over segments - spike_vector = sorting.to_spike_vector() - spike_sample_indices = spike_vector["sample_index"].copy() - # we need to add the cumulative sum of segment samples to have global sample indices - cumulative_segment_samples = np.cumsum([0] + segment_samples[:-1]) for segment_index in range(sorting_analyzer.get_num_segments()): - segment_slice = sorting._get_spike_vector_segment_slices()[segment_index] - spike_sample_indices[segment_slice[0] : segment_slice[1]] += cumulative_segment_samples[segment_index] - - bin_edges_for_units = compute_bin_edges_per_unit( - sorting, - segment_samples=segment_samples, - periods=periods, - bin_duration_s=interval_s, - ) - - failed_units = [] - median_positions_per_unit = {} + for unit_id in unit_ids: + bins = bin_edges_for_units[unit_id][segment_index] + num_bin_edges = len(bins) + if (num_bin_edges - 1) < min_num_bins: + failed_units.append(unit_id) + continue + median_positions = np.nan * np.zeros((num_bin_edges - 1)) + spikes_in_segment_of_unit = sorting.get_unit_spike_train(unit_id, segment_index) + bounds = np.searchsorted(spikes_in_segment_of_unit, bins, side="left") + for bin_index, (i0, i1) in enumerate(zip(bounds[:-1], bounds[1:])): + spike_locations_in_bin = spike_locations_by_unit_and_segments[segment_index][unit_id][i0:i1][direction] + if (i1 - i0) >= min_spikes_per_interval: + median_positions[bin_index] = np.median(spike_locations_in_bin) + median_position_segments[unit_id] = np.concatenate((median_position_segments[unit_id], median_positions)) + + # finally, compute deviations and drifts for unit_id in unit_ids: - bins = bin_edges_for_units[unit_id] - num_bins = len(bins) - 1 - if num_bins < min_num_bins: + # Skip units that already failed because not enough bins in at least one segment + if unit_id in failed_units: drift_ptps[unit_id] = np.nan drift_stds[unit_id] = np.nan drift_mads[unit_id] = np.nan - failed_units.append(unit_id) continue - - # bin_edges are global across segments, so we have to use 
spike_sample_indices, - # since we offseted them to be global - bin_spike_indices = np.searchsorted(spike_sample_indices, bins) - median_positions = np.nan * np.zeros(num_bins) - for bin_index, (i0, i1) in enumerate(zip(bin_spike_indices[:-1], bin_spike_indices[1:])): - spikes_in_bin = spike_vector[i0:i1] - spike_locations_in_bin = spike_locations_array[i0:i1][direction] - - unit_index = sorting_analyzer.sorting.id_to_index(unit_id) - mask = spikes_in_bin["unit_index"] == unit_index - if np.sum(mask) >= min_spikes_per_interval: - median_positions[bin_index] = np.median(spike_locations_in_bin[mask]) - else: - median_positions[bin_index] = np.nan - median_positions_per_unit[unit_id] = median_positions - - # now compute deviations and drifts for this unit - position_diff = median_positions - reference_positions[unit_id] + position_diff = median_position_segments[unit_id] - reference_positions[unit_id] if np.any(np.isnan(position_diff)): # deal with nans: if more than 50% nans --> set to nan if np.sum(np.isnan(position_diff)) > min_fraction_valid_intervals * len(position_diff): - failed_units.append(unit_id) ptp_drift = np.nan std_drift = np.nan mad_drift = np.nan + failed_units.append(unit_id) else: ptp_drift = np.nanmax(position_diff) - np.nanmin(position_diff) std_drift = np.nanstd(np.abs(position_diff)) @@ -1219,7 +1207,7 @@ ) if return_positions: - outs = res(drift_ptps, drift_stds, drift_mads), median_positions_per_unit + outs = res(drift_ptps, drift_stds, drift_mads), median_positions else: outs = res(drift_ptps, drift_stds, drift_mads) return outs diff --git a/src/spikeinterface/metrics/utils.py b/src/spikeinterface/metrics/utils.py index dea652985f..235ae5cd16 100644 --- a/src/spikeinterface/metrics/utils.py +++ b/src/spikeinterface/metrics/utils.py @@ -4,7 +4,7 @@ from spikeinterface.core.base import unit_period_dtype -def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None): +def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, periods=None, concatenated=True): """ Compute bin edges for units, optionally taking into account periods. @@ -18,6 +18,16 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per Duration of each bin in seconds periods : array of unit_period_dtype, default: None Periods to consider for each unit + concatenated : bool, default: True + Whether the bins are concatenated across segments or not. + If False, the bin edges are computed per segment and the first index of each segment is 0. + If True, the bin edges are computed on the concatenated segments, with the correct offsets. + + Returns + ------- + dict + Bin edges for each unit. If concatenated is True, the bin edges are a 1D array. + If False, the bin edges are a list of arrays, one per segment.
""" bin_edges_for_units = {} num_segments = len(segment_samples) @@ -31,27 +41,38 @@ def compute_bin_edges_per_unit(sorting, segment_samples, bin_duration_s=1.0, per for seg_index in range(num_segments): seg_periods = periods_unit[periods_unit["segment_index"] == seg_index] if len(seg_periods) == 0: + if not concatenated: + bin_edges.append(np.array([])) continue - seg_start = np.sum(segment_samples[:seg_index]) + seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0 + bin_edges_segment = [] for period in seg_periods: start_sample = seg_start + period["start_sample_index"] end_sample = seg_start + period["end_sample_index"] end_sample = end_sample // bin_duration_samples * bin_duration_samples + 1 # align to bin - bin_edges.extend(np.arange(start_sample, end_sample, bin_duration_samples)) - bin_edges_for_units[unit_id] = np.unique(np.array(bin_edges)) + bin_edges_segment.extend(np.arange(start_sample, end_sample, bin_duration_samples)) + bin_edges_segment = np.unique(np.array(bin_edges_segment)) + if concatenated: + bin_edges.extend(bin_edges_segment) + else: + bin_edges.append(bin_edges_segment) + bin_edges_for_units[unit_id] = bin_edges else: for unit_id in sorting.unit_ids: bin_edges = [] for seg_index in range(num_segments): - seg_start = np.sum(segment_samples[:seg_index]) + seg_start = np.sum(segment_samples[:seg_index]) if concatenated else 0 seg_end = seg_start + segment_samples[seg_index] # for segments which are not the last, we don't need to correct the end # since the first index of the next segment will be the end of the current segment if seg_index == num_segments - 1: seg_end = seg_end // bin_duration_samples * bin_duration_samples + 1 # align to bin - bins = np.arange(seg_start, seg_end, bin_duration_samples) - bin_edges.extend(bins) - bin_edges_for_units[unit_id] = np.array(bin_edges) + bin_edges_segment = np.arange(seg_start, seg_end, bin_duration_samples) + if concatenated: + bin_edges.extend(bin_edges_segment) + else: + bin_edges.append(bin_edges_segment) + bin_edges_for_units[unit_id] = bin_edges return bin_edges_for_units From 4909bfb0ddd079add88e14cd9eed1a27465ea37a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 12:49:40 +0100 Subject: [PATCH 65/70] Add in docs --- doc/index.rst | 2 +- doc/modules/core.rst | 35 ++++++++++++--------- doc/modules/exporters.rst | 38 ++++++++++++---------- doc/modules/postprocessing.rst | 52 ++++++++++++++++++++++++++++--- doc/modules/sortingcomponents.rst | 4 ++- 5 files changed, 95 insertions(+), 36 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index ce4053ca43..ae08a48fed 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -30,7 +30,7 @@ SpikeInterface is made of several modules to deal with different aspects of the - visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) - export a report and/or export to phy - curate your sorting with several strategies (ml-based, metrics based, manual, ...) -- offer a powerful Qt-based or we-based viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. +- offer a powerful Qt-based or web-based viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. - have powerful sorting components to build your own sorter. 
- have a full motion/drift correction framework (See :ref:`motion_correction`) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 078db82201..bf9edbe9d6 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -93,10 +93,11 @@ with 16 channels: timestamps = np.arange(num_samples) / sampling_frequency + 300 recording.set_times(times=timestamps, segment_index=0) -**Note**: -Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. -Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This -is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. +.. note:: + + Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. + Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This + is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. Sorting @@ -458,9 +459,11 @@ It represents unsorted waveform cutouts. Some acquisition systems, in fact, allo threshold and only record the times at which a peak was detected and the waveform cut out around the peak. -**NOTE**: while we support this class (mainly for legacy formats), this approach is a bad practice -and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform -template matching to recover spikes! +.. note:: + + While we support this class (mainly for legacy formats), this approach is a bad practice + and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform + template matching to recover spikes! Here we assume :code:`snippets` is a :py:class:`~spikeinterface.core.BaseSnippets` object with 16 channels: @@ -548,9 +551,11 @@ Sparsity is defined as the subset of channels on which waveforms (and related in sparsity is not global, but it is unit-specific. Importantly, saving sparse waveforms, especially for high-density probes, dramatically reduces the size of the waveforms extension if computed. -**NOTE** As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first -introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense -waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. +.. note:: + + As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first + introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense + waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. Sparsity can be computed from a :py:class:`~spikeinterface.core.SortingAnalyzer` object with the @@ -854,10 +859,12 @@ The same functions are also available for :py:func:`~spikeinterface.core.select_segment_sorting`). 
-**Note** :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` -have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the -multi segments concept, the other one is breaking it! -See this example for more detail :ref:`example_segments`. +.. note:: + + :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` + have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the + multi segments concept, the other one is breaking it! + See this example for more detail :ref:`example_segments`. diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 4eeaa23b81..ab25e8803d 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -12,9 +12,11 @@ and behavioral data. It can be used to decode behavior, make tuning curves, comp The :py:func:`~spikeinterface.exporters.to_pynapple_tsgroup` function allows you to convert a SortingAnalyzer to Pynapple's ``TsGroup`` object on the fly. -**Note** : When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. -How this works depends on your acquisition system. You can use the ``get_times`` method on a recording -(``my_recording.get_times()``) to find the time support of your recording. +.. note:: + + When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. + How this works depends on your acquisition system. You can use the ``get_times`` method on a recording + (``my_recording.get_times()``) to find the time support of your recording. When constructed, if ``attach_unit_metadata`` is set to ``True``, any relevant unit information is propagated to the ``TsGroup``. The ``to_pynapple_tsgroup`` checks if unit locations, quality @@ -54,13 +56,15 @@ The :py:func:`~spikeinterface.exporters.export_to_phy` function allows you to us `Phy template GUI `_ for visual inspection and manual curation of spike sorting results. -**Note** : :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend -on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. -The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. -So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), -then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). -If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) -can be computed directly at export time. +.. note:: + + :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend + on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. + The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. + So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), + then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). 
+ If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) + can be computed directly at export time. The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:`SortingAnalyzer` object. @@ -131,12 +135,14 @@ The report includes summary figures of the spike sorting output (e.g. amplitude depth VS amplitude) as well as unit-specific reports, that include waveforms, templates, template maps, ISI distributions, and more. -**Note** : similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the -:py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and -on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots -will be automatically included in the report if the associated extensions are computed in advance. -The function can perform these computations as well, but it is a better practice to compute everything that's needed -beforehand. +.. note:: + + Similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the + :py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and + on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots + will be automatically included in the report if the associated extensions are computed in advance. + The function can perform these computations as well, but it is a better practice to compute everything that's needed + beforehand. Note that every unit will generate a summary unit figure, so the export process can be slow for spike sorting outputs with many units! diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..5c4e29b359 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -163,8 +163,10 @@ Extensions are generally saved in two ways, suitable for two workflows: :code:`sorting_analyzer.compute('waveforms', save=False)`). -**NOTE**: We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until -you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code +.. note:: + + We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until + you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code .. code:: @@ -257,15 +259,35 @@ spike_amplitudes This extension computes the amplitude of each spike as the value of the traces on the extremum channel at the times of each spike. The extremum channel is computed from the templates. + **NOTE:** computing spike amplitudes is highly recommended before calculating amplitude-based quality metrics, such as :ref:`amp_cutoff` and :ref:`amp_median`. .. code-block:: python - amplitudes = sorting_analyzer.compute(input="spike_amplitudes", peak_sign="neg") + amplitudes = sorting_analyzer.compute(input="spike_amplitudes") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes` + +.. _postprocessing_amplitude_scalings: + +amplitude_scalings +^^^^^^^^^^^^^^^^^^ + +This extension computes the amplitude scaling of each spike as the value of the linear fit between the template and the +spike waveform. 
In case of spatio-temporal collisions, a multi-linear fit is performed using the templates of all units +involved in the collision. + +**NOTE:** computing amplitude scalings is highly recommended before calculating amplitude-based quality metrics, such as +:ref:`amp_cutoff` and :ref:`amp_median`. + +.. code-block:: python + + amplitude_scalings = sorting_analyzer.compute(input="amplitude_scalings") + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings` + .. _postprocessing_spike_locations: spike_locations @@ -367,6 +389,29 @@ This extension computes the histograms of inter-spike-intervals. The computed ou method="auto" ) -For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` +valid_unit_periods +^^^^^^^^^^^^^^^^^^ + +This extension computes the valid unit periods for each unit based on the estimation of false positive rates +(using RP violation - see :doc:`metrics/qualitymetrics/isi_violations`) and false negative rates +(using amplitude cutoff - see :doc:`metrics/qualitymetrics/amplitude_cutoff`) computed over chunks of the recording. +The valid unit periods are the periods where both false positive and false negative rates are below specified +thresholds. Periods can be either absolute (in seconds), same for all units, or relative, where +chunks will be unit-specific depending on firing rate (with a target number of spikes per chunk). + +.. code-block:: python + + valid_periods = sorting_analyzer.compute( + input="valid_unit_periods", + period_mode='relative', + target_num_spikes=300, + fp_threshold=0.1, + fn_threshold=0.1, + ) + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_valid_unit_periods`. + + Other postprocessing tools diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index e76cf3f99d..f81c63588c 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -81,7 +81,9 @@ Other variants are also implemented (but less tested or not so useful): * **'by_channel_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for time deduplication * **'locally_exclusive_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for space-time deduplication -**NOTE**: the torch implementations give slightly different results due to a different implementation. +.. note:: + + The torch implementations give slightly different results due to a different implementation. Peak detection, as many of the other sorting components, can be run in parallel.
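Bridging the documentation just added above and the quality-metrics change in the next patch, here is a compact sketch of the intended workflow. It is an illustration, not part of either patch; it reuses the calls that appear in the new tests (computing "amplitude_scalings" and "valid_unit_periods" first, then restricting quality metrics to those periods), and note that the default "false_positives_and_negatives" method additionally requires numba per a later commit in this series.

from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.metrics.quality import compute_quality_metrics

recording, sorting = generate_ground_truth_recording(durations=[60.0], seed=0)
analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")

# valid_unit_periods depends on amplitude_scalings, which itself needs random_spikes/templates
analyzer.compute(["random_spikes", "templates", "amplitude_scalings", "valid_unit_periods"])

# inspect the computed periods as a structured numpy array ...
periods = analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy")

# ... or let the quality metrics consume them via the flag introduced in the next patch
metrics = compute_quality_metrics(analyzer, skip_pc_metrics=True, use_valid_periods=True)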
From 2739df8f186b3e6f63ad66e5ada012435aaf9eec Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 14:43:00 +0100 Subject: [PATCH 66/70] Add use_valid_periods param to quality metrics --- doc/index.rst | 2 +- doc/modules/core.rst | 5 +- doc/modules/sortingcomponents.rst | 2 +- src/spikeinterface/core/sortinganalyzer.py | 12 ++++ .../metrics/quality/quality_metrics.py | 13 ++++ .../tests/test_quality_metric_calculator.py | 71 ++++++++++++++++++- .../tests/common_extension_tests.py | 41 +++++++---- .../tests/test_principal_component.py | 10 +-- .../tests/test_valid_unit_periods.py | 18 ++--- .../postprocessing/valid_unit_periods.py | 10 +++ .../widgets/tests/test_widgets.py | 18 ++--- 11 files changed, 155 insertions(+), 47 deletions(-) diff --git a/doc/index.rst b/doc/index.rst index ae08a48fed..659efe85a8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -30,7 +30,7 @@ SpikeInterface is made of several modules to deal with different aspects of the - visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) - export a report and/or export to phy - curate your sorting with several strategies (ml-based, metrics based, manual, ...) -- offer a powerful Qt-based or web-based viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. +- offer a powerful desktop or web viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. - have powerful sorting components to build your own sorter. - have a full motion/drift correction framework (See :ref:`motion_correction`) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index bf9edbe9d6..0faed14fba 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -181,8 +181,9 @@ a numpy.array with dtype `[("sample_index", "int64"), ("unit_index", "int64"), ( For computations which are done unit-by-unit, like computing isi-violations per unit, it is better that spikes from a single unit are concurrent in memory. For these other cases, we can re-order the `spike_vector` in different ways: - * order by unit, then segment, then sample - * order by segment, then unit, then sample + +* order by unit, then segment, then sample +* order by segment, then unit, then sample This is done using `sorting.to_reordered_spike_vector()`. The first time a reordering is done, the reordered spiketrain is cached in memory by default. Users should rarely have to worry about these diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index f81c63588c..5549fd0317 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -276,7 +276,7 @@ handle drift can benefit from drift estimation/correction. Especially for acute Neuropixels-like probes, this is a crucial step. The motion estimation step comes after peak detection and peak localization. Read more about -it in the :ref:`_motion_correction` modules doc, and a more practical guide in the +it in the :ref:`motion_correction` modules doc, and a more practical guide in the :ref:`handle-drift-in-your-recording` How To. 
Here is an example with non-rigid motion estimation: diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7b7eea796e..9883885c15 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2471,6 +2471,18 @@ def get_any_dependencies(cls, **params): all_dependencies = list(chain.from_iterable([dep.split("|") for dep in all_dependencies])) return all_dependencies + @classmethod + def get_default_params(cls): + """ + Get the default params for the extension. + + Returns + ------- + default_params : dict + The default parameters for the extension. + """ + return get_default_analyzer_extension_params(cls.extension_name) + def load_run_info(self): run_info = None if self.format == "binary_folder": diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index e5cc2aa323..5476aa405a 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -49,6 +49,13 @@ class ComputeQualityMetrics(BaseMetricExtension): need_backward_compatibility_on_load = True metric_list = misc_metrics_list + pca_metrics_list + @classmethod + def get_required_dependencies(cls, **params): + if params.get("use_valid_periods", False): + return ["valid_unit_periods"] + else: + return [] + def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames qm_params as metric_params if (qm_params := self.params.get("qm_params")) is not None: @@ -70,6 +77,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + use_valid_periods=False, periods=None, # common extension kwargs peak_sign=None, @@ -86,6 +94,11 @@ def _set_params( pc_metric_names = [m.metric_name for m in pca_metrics_list] metric_names = [m for m in metric_names if m not in pc_metric_names] + if use_valid_periods: + if periods is not None: + raise ValueError("If use_valid_periods is True, periods should not be provided.") + periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") + return super()._set_params( metric_names=metric_names, metric_params=metric_params, diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index 2e87002018..dfd47c4df9 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from spikeinterface.core import ( @@ -172,6 +171,76 @@ def test_empty_units(sorting_analyzer_simple): assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" +def test_quality_metrics_with_periods(): + """ + Test that quality metrics can be computed using valid unit periods. 
+ """ + from spikeinterface.core.base import unit_period_dtype + + recording, sorting = generate_ground_truth_recording() + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") + + # compute dependencies + sorting_analyzer.compute(["random_spikes", "templates", "amplitude_scalings", "valid_unit_periods"], **job_kwargs) + print(sorting_analyzer) + + # compute quality metrics using valid periods + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + print(metrics) + + # test with external periods: 1 period per segment from 10 to 90% of recording + num_segments = recording.get_num_segments() + periods = np.zeros(len(sorting.unit_ids) * num_segments, dtype=unit_period_dtype) + for i, unit_id in enumerate(sorting.unit_ids): + unit_index = sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = int(num_samples * 0.1) + period_end = int(num_samples * 0.9) + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_end + periods[idx]["segment_index"] = segment_index + + metrics_ext_periods = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=False, + periods=periods, + seed=2205, + ) + + # test failure when both periods and use_valid_periods are set + with pytest.raises(ValueError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + periods=periods, + seed=2205, + ) + + # test failure if use valid_periods is True but valid_unit_periods extension is missing + sorting_analyzer.delete_extension("valid_unit_periods") + with pytest.raises(AssertionError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + + if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index fe71d0eb5a..e78c0a6b47 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -4,9 +4,13 @@ import shutil import numpy as np -from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core import estimate_sparsity +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, + load_sorting_analyzer, + estimate_sparsity, +) +from spikeinterface.core.sortinganalyzer import get_extension_class def get_dataset(): @@ -95,7 +99,19 @@ def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=Non return sorting_analyzer - def _prepare_sorting_analyzer(self, format, sparse, extension_class, extra_dependencies=None): + def _compute_extensions_recursively(self, sorting_analyzer, extension_class, params): + # compute dependencies of the extension class with default params + dependencies = extension_class.get_required_dependencies(**params) + for dependency_name in dependencies: + if "|" in dependency_name: + dependency_name = dependency_name.split("|")[0] + if not 
sorting_analyzer.has_extension(dependency_name): + # compute dependencies of the dependency + self._compute_extensions_recursively(sorting_analyzer, get_extension_class(dependency_name), {}) + # compute the dependency itself + sorting_analyzer.compute(dependency_name) + + def _prepare_sorting_analyzer(self, format, sparse, extension_class, extension_params=None): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None sorting_analyzer = self.get_sorting_analyzer( @@ -103,14 +119,11 @@ def _prepare_sorting_analyzer(self, format, sparse, extension_class, extra_depen ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) - for dependency_name in extension_class.depend_on: - if "|" in dependency_name: - dependency_name = dependency_name.split("|")[0] - sorting_analyzer.compute(dependency_name) - if extra_dependencies is not None: - for dependency_name in extra_dependencies: - print("Computing extra dependency:", dependency_name) - sorting_analyzer.compute(dependency_name) + # default params for dependencies + params = sorting_analyzer.get_default_extension_params(extension_class.extension_name) + if extension_params is not None: + params.update(extension_params) + self._compute_extensions_recursively(sorting_analyzer, extension_class, params) return sorting_analyzer @@ -168,7 +181,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): else: continue - def run_extension_tests(self, extension_class, params, extra_dependencies=None): + def run_extension_tests(self, extension_class, params): """ Convenience function to perform all checks on the extension of interest with the passed parameters. Will perform tests @@ -178,6 +191,6 @@ def run_extension_tests(self, extension_class, params, extra_dependencies=None): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) sorting_analyzer = self._prepare_sorting_analyzer( - format, sparse, extension_class, extra_dependencies=extra_dependencies + format, sparse, extension_class, extension_params=params ) self._check_one(sorting_analyzer, extension_class, params) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index fc9d3643bc..77bff7a3d8 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -99,9 +99,9 @@ def test_get_projections(self, sparse): some_unit_ids = sorting_analyzer.unit_ids[::2] some_channel_ids = sorting_analyzer.channel_ids[::2] - random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() - all_num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit() - unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids) + random_spikes_ext = sorting_analyzer.get_extension("random_spikes") + random_spikes_indices = random_spikes_ext.get_data() + unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) @@ -113,7 +113,7 @@ def test_get_projections(self, sparse): # this should be some spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert 
spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans assert 1 not in spike_unit_index @@ -123,7 +123,7 @@ def test_get_projections(self, sparse): channel_ids=some_channel_ids, unit_ids=some_unit_ids ) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index 6f4c25f415..d198f67295 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -16,9 +16,7 @@ class TestComputeValidUnitPeriods(AnalyzerExtensionCommonTestSuite): ], ) def test_extension(self, params): - self.run_extension_tests( - ComputeValidUnitPeriods, params, extra_dependencies=["templates", "amplitude_scalings"] - ) + self.run_extension_tests(ComputeValidUnitPeriods, params) def test_user_defined_periods(self): unit_ids = self.sorting.unit_ids @@ -116,20 +114,18 @@ def test_combined_periods(self): periods[idx]["end_sample_index"] = period_start + period_duration periods[idx]["segment_index"] = segment_index - sorting_analyzer = self._prepare_sorting_analyzer( - "memory", - sparse=False, - extension_class=ComputeValidUnitPeriods, - extra_dependencies=["templates", "amplitude_scalings"], - ) - ext = sorting_analyzer.compute( - ComputeValidUnitPeriods.extension_name, + unit_valid_periods_params = dict( method="combined", user_defined_periods=periods, period_mode="absolute", period_duration_s_absolute=1.0, minimum_valid_period_duration=1, ) + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods, extension_params=unit_valid_periods_params + ) + ext = sorting_analyzer.compute(ComputeValidUnitPeriods.extension_name, **unit_valid_periods_params) # check that valid periods correspond to intersection of auto-computed and user defined periods ext_periods = ext.get_data(outputs="numpy") assert len(ext_periods) <= len(periods) # should be less or equal than user defined ones diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 539d587b5b..41c3aa4303 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -90,6 +90,16 @@ class ComputeValidUnitPeriods(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = False + @classmethod + def get_required_dependencies(cls, **params): + ext_params = cls.get_default_params() + ext_params.update(params) + method = ext_params.get("method", None) + if method is not None and method in ("false_positives_and_negatives", "combined"): + return ["amplitude_scalings"] + else: + return [] + def _set_params( self, method: str = "false_positives_and_negatives", diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4295b80c7c..154d6e4ed3 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -75,6 
+75,12 @@ def setUpClass(cls): template_metrics=dict(), correlograms=dict(), template_similarity=dict(), + valid_unit_periods=dict( + period_mode="relative", + period_target_num_spikes=200, + relative_margin_size=0.5, + min_num_periods_relative=5, + ), ) job_kwargs = dict(n_jobs=-1) @@ -82,18 +88,6 @@ def setUpClass(cls): cls.sorting_analyzer_dense = create_sorting_analyzer(cls.sorting, cls.recording, format="memory", sparse=False) cls.sorting_analyzer_dense.compute("random_spikes") cls.sorting_analyzer_dense.compute(extensions_to_compute, **job_kwargs) - # compute valid periods later, since it depends on amplitude_scalings - cls.sorting_analyzer_dense.compute( - dict( - valid_unit_periods=dict( - period_mode="relative", - period_target_num_spikes=200, - relative_margin_size=0.5, - min_num_periods_relative=5, - ) - ), - **job_kwargs, - ) sw.set_default_plotter_backend("matplotlib") From 2b27c40e64a0c8b7a9787008ff146f8d335ddc0e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 15:10:54 +0100 Subject: [PATCH 67/70] Force int64 in tests --- .../postprocessing/tests/test_valid_unit_periods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py index d198f67295..6d34264eac 100644 --- a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -56,7 +56,7 @@ def test_user_defined_periods_as_arrays(self): num_segments = self.sorting.get_num_segments() # unit periods of unit_period_dtypes - periods_array = np.zeros((len(unit_ids) * num_segments, 4), dtype=int) + periods_array = np.zeros((len(unit_ids) * num_segments, 4), dtype="int64") # for each unit we 1 valid period per segment for i, unit_id in enumerate(unit_ids): From a0aad71c5ce4a655e2a42cf8ce78aa146d7ad5cc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 22 Jan 2026 15:32:18 +0100 Subject: [PATCH 68/70] Add clip_amplitude_scalings arg in plot --- src/spikeinterface/widgets/unit_valid_periods.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_valid_periods.py b/src/spikeinterface/widgets/unit_valid_periods.py index ffe9bc555f..9ea4bbb899 100644 --- a/src/spikeinterface/widgets/unit_valid_periods.py +++ b/src/spikeinterface/widgets/unit_valid_periods.py @@ -29,6 +29,7 @@ def __init__( segment_index: int | None = None, unit_ids: list | None = None, show_only_units_with_valid_periods: bool = True, + clip_amplitude_scalings: float | None = 5.0, backend: str | None = None, **backend_kwargs, ): @@ -57,6 +58,7 @@ def __init__( sorting_analyzer=sorting_analyzer, segment_index=segment_index, unit_ids=valid_unit_ids, + clip_amplitude_scalings=clip_amplitude_scalings, ) BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) @@ -112,7 +114,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): axs[0].axhline(fp_threshold, color="gray", ls="--") axs[1].plot(center_bins_s, fn, ls="", marker="o") axs[1].axhline(fn_threshold, color="gray", ls="--") - axs[2].plot(spiketrain, amp_scalings_by_unit[unit_id], ls="", marker="o", color="gray", alpha=0.5) + amp_scalings_data = amp_scalings_by_unit[unit_id] + if dp.clip_amplitude_scalings is not None: + amp_scalings_data = np.clip(amp_scalings_data, -dp.clip_amplitude_scalings, dp.clip_amplitude_scalings) + axs[2].plot(spiketrain, amp_scalings_data, ls="", marker="o", 
color="gray", alpha=0.5) + axs[2].axhline(1.0, color="k", ls="--") # plot valid periods valid_period_for_units = good_periods[good_periods["unit_index"] == unit_index] for valid_period in valid_period_for_units: From 4c203d5b0cd46013a3cfea0ecf75a4e997beadaf Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 23 Jan 2026 16:48:48 +0100 Subject: [PATCH 69/70] Fix serialization/deserialization to Zarr --- src/spikeinterface/core/sortinganalyzer.py | 7 +++++-- .../postprocessing/valid_unit_periods.py | 13 ++----------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9883885c15..2f6c797ff2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2710,9 +2710,12 @@ def _save_data(self): for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] - if isinstance(ext_data, dict): + if isinstance(ext_data, (dict, list)): + # These could be dicts or lists of dicts. The check_json makes sure + # that everything is json serializable + ext_data_ = check_json(ext_data) extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() + name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON() ) extension_group[ext_data_name].attrs["dict"] = True elif isinstance(ext_data, np.ndarray): diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 41c3aa4303..78d2dad837 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -132,6 +132,8 @@ def _set_params( if method in ["false_positives_and_negatives", "combined"]: if not self.sorting_analyzer.has_extension("amplitude_scalings"): raise ValueError("Requires 'amplitude_scalings' extension; please compute it first.") + if not HAVE_NUMBA: + raise ImportError("Numba is required to compute RP violations (false positives).") # subperiods assert period_mode in ("absolute", "relative"), f"Invalid subperiod_size_mode: {period_mode}" @@ -470,17 +472,6 @@ def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): valid_mask = duration_samples >= min_valid_period_samples valid_unit_periods = valid_unit_periods[valid_mask] - # # Convert subperiods per unit in period_centers_s - # period_centers = [] - # for segment_index in range(sorting_analyzer.sorting.get_num_segments()): - # periods_segment = all_periods[all_periods["segment_index"] == segment_index] - # period_centers_dict = {} - # for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): - # periods_unit = periods_segment[periods_segment["unit_index"] == unit_index] - # centers = list(0.5 * (periods_unit["start_sample_index"] + periods_unit["end_sample_index"])) - # period_centers_dict[unit_id] = centers - # period_centers.append(period_centers_dict) - # Store data: here we have to make sure every dict is JSON serializable, so everything is lists return valid_unit_periods, period_centers, fps, fns From d42b6a25f9541025a1cc0acabeb2206feec1e231 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 23 Jan 2026 17:07:48 +0100 Subject: [PATCH 70/70] Fix reloading extension data in zarr --- src/spikeinterface/postprocessing/valid_unit_periods.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py index 78d2dad837..aa0bc0f771 100644 --- a/src/spikeinterface/postprocessing/valid_unit_periods.py +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -522,14 +522,16 @@ def _sort_periods(self, periods): def set_data(self, ext_data_name, ext_data): # cast back lists of dicts (required for dumping) back to arrays if ext_data_name in ("period_centers", "periods_fp_per_unit", "periods_fn_per_unit"): - ext_data = [] + ext_data_ = [] # lists of dicts to lists of dicts with arrays for segment_dict in ext_data: segment_dict_arrays = {} for unit_id, values in segment_dict.items(): segment_dict_arrays[unit_id] = np.array(values) - ext_data.append(segment_dict_arrays) - self.data[ext_data_name] = ext_data + ext_data_.append(segment_dict_arrays) + else: + ext_data_ = ext_data + self.data[ext_data_name] = ext_data_ # TODO: deal with margin when returning periods
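As a closing illustration of what the last two commits guard against, here is a tiny self-contained sketch, using plain json and numpy rather than SpikeInterface internals, of the round trip they enable: dict/list extension data is reduced to JSON-friendly lists before the Zarr object codec writes it, and cast back to numpy arrays when the extension is reloaded. The keys and values below are made up for the example.

import json
import numpy as np

# Saving: numpy arrays inside dict/list extension data become plain lists
# (the role check_json plays) so the payload can be stored with a JSON codec.
period_centers = [{"unit0": np.array([0.5, 1.5, 2.5]), "unit1": np.array([0.75])}]
serializable = [{k: v.tolist() for k, v in seg.items()} for seg in period_centers]
payload = json.dumps(serializable)

# Loading: cast the lists of dicts back to arrays, building a new structure
# instead of clobbering the input while iterating over it, as in the last patch.
loaded = json.loads(payload)
restored = [{k: np.array(v) for k, v in seg.items()} for seg in loaded]
assert np.allclose(restored[0]["unit0"], period_centers[0]["unit0"])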