diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..b45ff75934 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -285,11 +285,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate), input="spike_locations", ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg" - ), + peak_sign="neg", method="center_of_mass" ) diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 546beff6bb..024a88592f 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -25,7 +25,7 @@ def __init__(self, recording, gt_sorting, params, gt_positions, channel_from_tem self.templates_params[key] = self.params[key] if not self.channel_from_template: - self.params["spike_retriver_kwargs"] = {"channel_from_template": False} + self.params["spike_retriever_kwargs"] = {"channel_from_template": False} else: ## TODO pass diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 1609f11d17..d197ccc5b2 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -155,7 +155,7 @@ class SpikeRetriever(PeakSource): radius_um : float, default: 50 The radius to find the real max channel. Used only when channel_from_template=False - peak_sign : "neg" | "pos", default: "neg" + peak_sign : "neg" | "pos" | "both", default: "neg" Peak sign to find the max channel. Used only when channel_from_template=False include_spikes_in_margin : bool, default False diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 3f34a106e2..49a65e0bf1 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -46,9 +46,9 @@ def compute_monopolar_triangulation( ---------- sorting_analyzer_or_templates : SortingAnalyzer | Templates A SortingAnalyzer or Templates object - unit_ids: str | int | None + unit_ids : str | int | None A list of unit_id to restrci the computation - method : "least_square" | "minimize_with_log_penality", default: "least_square" + optimizer : "least_square" | "minimize_with_log_penality", default: "least_square" The optimizer to use radius_um : float, default: 75 For channel sparsity diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index d4e226aa99..8accf4de3b 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -1,8 +1,5 @@ from __future__ import annotations -import numpy as np - -from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.sortinganalyzer import register_result_extension from spikeinterface.core.template_tools import get_template_extremum_channel from spikeinterface.core.node_pipeline import SpikeRetriever @@ -19,18 +16,11 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): The left window, before a peak, in milliseconds ms_after : float, default: 0.5 The right window, after a peak, in milliseconds - spike_retriver_kwargs : dict - A dictionary to control the behavior for getting the maximum channel for each spike - This dictionary contains: - - * channel_from_template: bool, default: True - For each spike is the maximum channel computed from template or re estimated at every spikes - channel_from_template = True is old behavior but less acurate - channel_from_template = False is slower but more accurate - * radius_um: float, default: 50 - In case channel_from_template=False, this is the radius to get the true peak - * peak_sign, default: "neg" - In case channel_from_template=False, this is the peak sign. + peak_sign : "neg" | "pos" | "both", default: "neg" + The peak sign to use when looking for the template extremum channel. + spike_retriever_kwargs : dict + Arguments to control the spike retriever behavior. See + `spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`. method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass" The localization method to use method_kwargs : dict, default: dict() @@ -45,26 +35,30 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension): extension_name = "spike_locations" depend_on = ["templates"] nodepipeline_variables = ["spike_locations"] + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs + if "spike_retriver_kwargs" in self.params: + self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg") + self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs") def _set_params( self, ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + peak_sign="neg", + spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, ): - spike_retriver_kwargs_ = dict( - channel_from_template=True, - radius_um=50, - peak_sign="neg", - ) - if spike_retriver_kwargs is not None: - spike_retriver_kwargs_.update(spike_retriver_kwargs) + if spike_retriever_kwargs is None: + spike_retriever_kwargs = {} return super()._set_params( ms_before=ms_before, ms_after=ms_after, - spike_retriver_kwargs=spike_retriver_kwargs_, + peak_sign=peak_sign, + spike_retriever_kwargs=spike_retriever_kwargs, method=method, method_kwargs=method_kwargs, ) @@ -74,17 +68,17 @@ def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording sorting = self.sorting_analyzer.sorting - peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"] + peak_sign = self.params["peak_sign"] extremum_channels_indices = get_template_extremum_channel( self.sorting_analyzer, peak_sign=peak_sign, outputs="index" ) - retriever = SpikeRetriever( - sorting, - recording, - channel_from_template=True, - extremum_channel_inds=extremum_channels_indices, - ) + retriever_kwargs = { + "channel_from_template": True, + "extremum_channel_inds": extremum_channels_indices, + **self.params["spike_retriever_kwargs"], + } + retriever = SpikeRetriever(sorting, recording, **retriever_kwargs) nodes = get_localization_pipeline_nodes( recording, retriever, diff --git a/src/spikeinterface/postprocessing/tests/test_spike_locations.py b/src/spikeinterface/postprocessing/tests/test_spike_locations.py index 46a39d23ea..f457bd9250 100644 --- a/src/spikeinterface/postprocessing/tests/test_spike_locations.py +++ b/src/spikeinterface/postprocessing/tests/test_spike_locations.py @@ -8,8 +8,9 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( "params", [ - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)), - dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)), + dict(method="center_of_mass", peak_sign="both"), + dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=True)), + dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=False)), dict(method="center_of_mass"), dict(method="monopolar_triangulation"), dict(method="grid_convolution"), diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 8a2b64b6d2..431818c501 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -556,7 +556,7 @@ def create_sorting_analyzer_with_existing_templates( sa.extensions["spike_locations"].params = dict( ms_before=0.5, ms_after=0.5, - spike_retriver_kwargs=None, + spike_retriever_kwargs=None, method="center_of_mass", method_kwargs={}, )