From 0506f28005c562842ec46f0169fb4206d36c0ea2 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Mon, 19 Jan 2026 13:56:09 +0100 Subject: [PATCH 1/8] minor: docstring fixes --- src/spikeinterface/core/node_pipeline.py | 2 +- src/spikeinterface/postprocessing/localization_tools.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..d197ccc5b2 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -155,7 +155,7 @@ class SpikeRetriever(PeakSource): radius_um : float, default: 50 The radius to find the real max channel. Used only when channel_from_template=False - peak_sign : "neg" | "pos", default: "neg" + peak_sign : "neg" | "pos" | "both", default: "neg" Peak sign to find the max channel. Used only when channel_from_template=False include_spikes_in_margin : bool, default False diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 3f34a106e2..49a65e0bf1 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -46,9 +46,9 @@ def compute_monopolar_triangulation( ---------- sorting_analyzer_or_templates : SortingAnalyzer | Templates A SortingAnalyzer or Templates object - unit_ids: str | int | None + unit_ids : str | int | None A list of unit_id to restrci the computation - method : "least_square" | "minimize_with_log_penality", default: "least_square" + optimizer : "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um : float, default: 75 For channel sparsity From c1b8a6aa5c37bd49aa6aeaa85a64510d172c5146 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Mon, 19 Jan 2026 14:11:50 +0100 Subject: [PATCH 2/8] simplify ComputeSpikeLocations to only receive peak_sign --- .../postprocessing/spike_locations.py | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d4e226aa99..54e7d1bf75 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -1,8 +1,5 @@ from __future__ import annotations -import numpy as np - -from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever @@ -19,18 +16,8 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs : dict - A dictionary to control the behavior for getting the maximum channel for each spike - This dictionary contains: - - * channel_from_template: bool, default: True - For each spike is the maximum channel computed from template or re estimated at every spikes - channel_from_template = True is old behavior but less acurate - channel_from_template = False is slower but more accurate - * radius_um: float, default: 50 - In case channel_from_template=False, this is the radius to get the true peak - * peak_sign, default: "neg" - In case channel_from_template=False, this is the peak sign. + peak_sign: "neg" | "pos" | "both", default: "neg" + The peak sign to consider when looking for the template extremum channel for each spike. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use method_kwargs : dict, default: dict() @@ -50,21 +37,14 @@ def _set_params( self, ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + peak_sign="neg", method="center_of_mass", method_kwargs={}, ): - spike_retriver_kwargs_ = dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg", - ) - if spike_retriver_kwargs is not None: - spike_retriver_kwargs_.update(spike_retriver_kwargs) return super()._set_params( ms_before=ms_before, ms_after=ms_after, - spike_retriver_kwargs=spike_retriver_kwargs_, + peak_sign=peak_sign, method=method, method_kwargs=method_kwargs, ) @@ -74,7 +54,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] + peak_sign = self.params["peak_sign"] extremum_channels_indices = get_template_extremum_channel( self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) From 18417673b0f526459e73011d85259413da678aef Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Mon, 19 Jan 2026 15:26:55 +0100 Subject: [PATCH 3/8] change references to spike_retriver_kwargs (to peak_sign) --- doc/modules/postprocessing.rst | 6 +----- .../benchmark/benchmark_peak_localization.py | 10 +++++----- .../postprocessing/tests/test_spike_locations.py | 3 +-- src/spikeinterface/sortingcomponents/tools.py | 2 +- 4 files changed, 8 insertions(+), 13 deletions(-) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..b45ff75934 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -285,11 +285,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), input="spike_locations", ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg" - ), + peak_sign="neg", method="center_of_mass" ) diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 546beff6bb..0bcdf96d8c 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -24,11 +24,11 @@ def __init__(self, recording, gt_sorting, params, gt_positions, channel_from_tem self.params[key] = self.params.get(key, 2) self.templates_params[key] = self.params[key] - if not self.channel_from_template: - self.params["spike_retriver_kwargs"] = {"channel_from_template": False} - else: - ## TODO - pass + # if not self.channel_from_template: + # self.params["spike_retriver_kwargs"] = {"channel_from_template": False} + # else: + # ## TODO + # pass def run(self, **job_kwargs): sorting_analyzer = create_sorting_analyzer( diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 46a39d23ea..26aa221a5a 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,8 +8,7 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass", peak_sign='both'), dict(method="center_of_mass"), dict(method="monopolar_triangulation"), dict(method="grid_convolution"), diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 8a2b64b6d2..1d0c54610a 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -556,7 +556,7 @@ def create_sorting_analyzer_with_existing_templates( sa.extensions["spike_locations"].params = dict( ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + peak_sign="neg", method="center_of_mass", method_kwargs={}, ) From 20a02c7033d6559dbd641abcb4f58aba9943a772 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Jan 2026 14:27:31 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/postprocessing/tests/test_spike_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 26aa221a5a..5576157213 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,7 +8,7 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", peak_sign='both'), + dict(method="center_of_mass", peak_sign="both"), dict(method="center_of_mass"), dict(method="monopolar_triangulation"), dict(method="grid_convolution"), From f22fb5d0d7d4b03a5a0aca7b8f2b7ddca9db71c9 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Fri, 23 Jan 2026 00:59:27 +0100 Subject: [PATCH 5/8] Revert "change references to spike_retriver_kwargs (to peak_sign)" This reverts commit 18417673b0f526459e73011d85259413da678aef. --- doc/modules/postprocessing.rst | 6 +++++- .../benchmark/benchmark_peak_localization.py | 10 +++++----- .../postprocessing/tests/test_spike_locations.py | 3 ++- src/spikeinterface/sortingcomponents/tools.py | 2 +- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index b45ff75934..5442b4728c 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -285,7 +285,11 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), input="spike_locations", ms_before=0.5, ms_after=0.5, - peak_sign="neg", + spike_retriver_kwargs=dict( + channel_from_template=True, + radius_um=50, + peak_sign="neg" + ), method="center_of_mass" ) diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 0bcdf96d8c..546beff6bb 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -24,11 +24,11 @@ def __init__(self, recording, gt_sorting, params, gt_positions, channel_from_tem self.params[key] = self.params.get(key, 2) self.templates_params[key] = self.params[key] - # if not self.channel_from_template: - # self.params["spike_retriver_kwargs"] = {"channel_from_template": False} - # else: - # ## TODO - # pass + if not self.channel_from_template: + self.params["spike_retriver_kwargs"] = {"channel_from_template": False} + else: + ## TODO + pass def run(self, **job_kwargs): sorting_analyzer = create_sorting_analyzer( diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 5576157213..46a39d23ea 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,7 +8,8 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", peak_sign="both"), + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass"), dict(method="monopolar_triangulation"), dict(method="grid_convolution"), diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 1d0c54610a..8a2b64b6d2 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -556,7 +556,7 @@ def create_sorting_analyzer_with_existing_templates( sa.extensions["spike_locations"].params = dict( ms_before=0.5, ms_after=0.5, - peak_sign="neg", + spike_retriver_kwargs=None, method="center_of_mass", method_kwargs={}, ) From 3f8fadce6db5ee27998795d1c59f92c6e1218baf Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Fri, 23 Jan 2026 01:13:14 +0100 Subject: [PATCH 6/8] send spike_retriever_kwargs to the SpikeRetriever --- .../postprocessing/spike_locations.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 54e7d1bf75..3a14475c7c 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -16,8 +16,18 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - peak_sign: "neg" | "pos" | "both", default: "neg" - The peak sign to consider when looking for the template extremum channel for each spike. + spike_retriver_kwargs : dict + A dictionary to control the behavior for getting the maximum channel for each spike + This dictionary contains: + + * channel_from_template: bool, default: True + For each spike is the maximum channel computed from template or re estimated at every spikes + channel_from_template = True is old behavior but less acurate + channel_from_template = False is slower but more accurate + * radius_um: float, default: 50 + In case channel_from_template=False, this is the radius to get the true peak + * peak_sign, default: "neg" + In case channel_from_template=False, this is the peak sign. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use method_kwargs : dict, default: dict() @@ -37,14 +47,21 @@ def _set_params( self, ms_before=0.5, ms_after=0.5, - peak_sign="neg", + spike_retriver_kwargs=None, method="center_of_mass", method_kwargs={}, ): + spike_retriver_kwargs_ = dict( + channel_from_template=True, + radius_um=50, + peak_sign="neg", + ) + if spike_retriver_kwargs is not None: + spike_retriver_kwargs_.update(spike_retriver_kwargs) return super()._set_params( ms_before=ms_before, ms_after=ms_after, - peak_sign=peak_sign, + spike_retriver_kwargs=spike_retriver_kwargs_, method=method, method_kwargs=method_kwargs, ) @@ -54,7 +71,7 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["peak_sign"] + peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] extremum_channels_indices = get_template_extremum_channel( self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) @@ -62,8 +79,8 @@ def _get_pipeline_nodes(self): retriever = SpikeRetriever( sorting, recording, - channel_from_template=True, extremum_channel_inds=extremum_channels_indices, + **self.params["spike_retriver_kwargs"], ) nodes = get_localization_pipeline_nodes( recording, From be02fbd3b9400ca5227d1ce9b0110e9b53b512f1 Mon Sep 17 00:00:00 2001 From: Erick Cobos Date: Fri, 23 Jan 2026 13:50:21 +0100 Subject: [PATCH 7/8] add peak_sign and change retriver to retriever in spike_locations --- doc/modules/postprocessing.rst | 6 +-- .../benchmark/benchmark_peak_localization.py | 2 +- .../postprocessing/spike_locations.py | 51 +++++++++---------- .../tests/test_spike_locations.py | 5 +- src/spikeinterface/sortingcomponents/tools.py | 2 +- 5 files changed, 29 insertions(+), 37 deletions(-) diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..b45ff75934 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -285,11 +285,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), input="spike_locations", ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg" - ), + peak_sign="neg", method="center_of_mass" ) diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 546beff6bb..024a88592f 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -25,7 +25,7 @@ def __init__(self, recording, gt_sorting, params, gt_positions, channel_from_tem self.templates_params[key] = self.params[key] if not self.channel_from_template: - self.params["spike_retriver_kwargs"] = {"channel_from_template": False} + self.params["spike_retriever_kwargs"] = {"channel_from_template": False} else: ## TODO pass diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 3a14475c7c..991cc899ac 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -16,18 +16,11 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs : dict - A dictionary to control the behavior for getting the maximum channel for each spike - This dictionary contains: - - * channel_from_template: bool, default: True - For each spike is the maximum channel computed from template or re estimated at every spikes - channel_from_template = True is old behavior but less acurate - channel_from_template = False is slower but more accurate - * radius_um: float, default: 50 - In case channel_from_template=False, this is the radius to get the true peak - * peak_sign, default: "neg" - In case channel_from_template=False, this is the peak sign. + peak_sign : "neg" | "pos" | "both", default: "neg" + The peak sign to use when looking for the template extremum channel. + spike_retriever_kwargs : dict + Arguments to control the spike retriever behavior. See + `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use method_kwargs : dict, default: dict() @@ -42,26 +35,30 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): extension_name = "spike_locations" depend_on = ["templates"] nodepipeline_variables = ["spike_locations"] + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs + if "spike_retriver_kwargs" in self.params: + self.params['peak_sign'] = self.params['spike_retriver_kwargs'].get('peak_sign', 'neg') + self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( self, ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + peak_sign="neg", + spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, ): - spike_retriver_kwargs_ = dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg", - ) - if spike_retriver_kwargs is not None: - spike_retriver_kwargs_.update(spike_retriver_kwargs) + if spike_retriever_kwargs is None: + spike_retriever_kwargs = {} return super()._set_params( ms_before=ms_before, ms_after=ms_after, - spike_retriver_kwargs=spike_retriver_kwargs_, + peak_sign=peak_sign, + spike_retriever_kwargs=spike_retriever_kwargs, method=method, method_kwargs=method_kwargs, ) @@ -71,17 +68,15 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] + peak_sign = self.params["peak_sign"] extremum_channels_indices = get_template_extremum_channel( self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) - retriever = SpikeRetriever( - sorting, - recording, - extremum_channel_inds=extremum_channels_indices, - **self.params["spike_retriver_kwargs"], - ) + retriever_kwargs = {"channel_from_template": True, + "extremum_channel_inds": extremum_channels_indices, + **self.params["spike_retriever_kwargs"]} + retriever = SpikeRetriever(sorting, recording, **retriever_kwargs) nodes = get_localization_pipeline_nodes( recording, retriever, diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 46a39d23ea..f457bd9250 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,8 +8,9 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass", peak_sign="both"), + dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass"), dict(method="monopolar_triangulation"), dict(method="grid_convolution"), diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 8a2b64b6d2..431818c501 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -556,7 +556,7 @@ def create_sorting_analyzer_with_existing_templates( sa.extensions["spike_locations"].params = dict( ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, ) From e4dc370acc95cd9d932b76b2068372581f094bc7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Jan 2026 12:52:49 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/spike_locations.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 991cc899ac..8accf4de3b 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -16,10 +16,10 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - peak_sign : "neg" | "pos" | "both", default: "neg" + peak_sign : "neg" | "pos" | "both", default: "neg" The peak sign to use when looking for the template extremum channel. spike_retriever_kwargs : dict - Arguments to control the spike retriever behavior. See + Arguments to control the spike retriever behavior. See `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use @@ -40,7 +40,7 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs if "spike_retriver_kwargs" in self.params: - self.params['peak_sign'] = self.params['spike_retriver_kwargs'].get('peak_sign', 'neg') + self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg") self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( @@ -73,9 +73,11 @@ def _get_pipeline_nodes(self): self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) - retriever_kwargs = {"channel_from_template": True, - "extremum_channel_inds": extremum_channels_indices, - **self.params["spike_retriever_kwargs"]} + retriever_kwargs = { + "channel_from_template": True, + "extremum_channel_inds": extremum_channels_indices, + **self.params["spike_retriever_kwargs"], + } retriever = SpikeRetriever(sorting, recording, **retriever_kwargs) nodes = get_localization_pipeline_nodes( recording,