Skip to content
6 changes: 1 addition & 5 deletions doc/modules/postprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,7 @@ with center of mass (:code:`method="center_of_mass"` - fast, but less accurate),
input="spike_locations",
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg"
),
peak_sign="neg",
method="center_of_mass"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, recording, gt_sorting, params, gt_positions, channel_from_tem
self.templates_params[key] = self.params[key]

if not self.channel_from_template:
self.params["spike_retriver_kwargs"] = {"channel_from_template": False}
self.params["spike_retriever_kwargs"] = {"channel_from_template": False}
else:
## TODO
pass
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class SpikeRetriever(PeakSource):
radius_um : float, default: 50
The radius to find the real max channel.
Used only when channel_from_template=False
peak_sign : "neg" | "pos", default: "neg"
peak_sign : "neg" | "pos" | "both", default: "neg"
Peak sign to find the max channel.
Used only when channel_from_template=False
include_spikes_in_margin : bool, default False
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/postprocessing/localization_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def compute_monopolar_triangulation(
----------
sorting_analyzer_or_templates : SortingAnalyzer | Templates
A SortingAnalyzer or Templates object
unit_ids: str | int | None
unit_ids : str | int | None
A list of unit_id to restrci the computation
method : "least_square" | "minimize_with_log_penality", default: "least_square"
optimizer : "least_square" | "minimize_with_log_penality", default: "least_square"
The optimizer to use
radius_um : float, default: 75
For channel sparsity
Expand Down
56 changes: 25 additions & 31 deletions src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from __future__ import annotations

import numpy as np

from spikeinterface.core.job_tools import _shared_job_kwargs_doc
from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.template_tools import get_template_extremum_channel
from spikeinterface.core.node_pipeline import SpikeRetriever
Expand All @@ -19,18 +16,11 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension):
The left window, before a peak, in milliseconds
ms_after : float, default: 0.5
The right window, after a peak, in milliseconds
spike_retriver_kwargs : dict
A dictionary to control the behavior for getting the maximum channel for each spike
This dictionary contains:

* channel_from_template: bool, default: True
For each spike is the maximum channel computed from template or re estimated at every spikes
channel_from_template = True is old behavior but less acurate
channel_from_template = False is slower but more accurate
* radius_um: float, default: 50
In case channel_from_template=False, this is the radius to get the true peak
* peak_sign, default: "neg"
In case channel_from_template=False, this is the peak sign.
peak_sign : "neg" | "pos" | "both", default: "neg"
The peak sign to use when looking for the template extremum channel.
spike_retriever_kwargs : dict
Arguments to control the spike retriever behavior. See
`spikeinterface.sortingcomponents.peak_localization.SpikeRetriever`.
method : "center_of_mass" | "monopolar_triangulation" | "grid_convolution", default: "center_of_mass"
The localization method to use
method_kwargs : dict, default: dict()
Expand All @@ -45,26 +35,30 @@ class ComputeSpikeLocations(BaseSpikeVectorExtension):
extension_name = "spike_locations"
depend_on = ["templates"]
nodepipeline_variables = ["spike_locations"]
need_backward_compatibility_on_load = True

def _handle_backward_compatibility_on_load(self):
# For backwards compatibility - this renames spike_retriver_kwargs to spike_retriever_kwargs
if "spike_retriver_kwargs" in self.params:
self.params["peak_sign"] = self.params["spike_retriver_kwargs"].get("peak_sign", "neg")
self.params["spike_retriever_kwargs"] = self.params.pop("spike_retriver_kwargs")

def _set_params(
self,
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=None,
peak_sign="neg",
spike_retriever_kwargs=None,
method="center_of_mass",
method_kwargs={},
):
spike_retriver_kwargs_ = dict(
channel_from_template=True,
radius_um=50,
peak_sign="neg",
)
if spike_retriver_kwargs is not None:
spike_retriver_kwargs_.update(spike_retriver_kwargs)
if spike_retriever_kwargs is None:
spike_retriever_kwargs = {}
return super()._set_params(
ms_before=ms_before,
ms_after=ms_after,
spike_retriver_kwargs=spike_retriver_kwargs_,
peak_sign=peak_sign,
spike_retriever_kwargs=spike_retriever_kwargs,
method=method,
method_kwargs=method_kwargs,
)
Expand All @@ -74,17 +68,17 @@ def _get_pipeline_nodes(self):

recording = self.sorting_analyzer.recording
sorting = self.sorting_analyzer.sorting
peak_sign = self.params["spike_retriver_kwargs"]["peak_sign"]
peak_sign = self.params["peak_sign"]
extremum_channels_indices = get_template_extremum_channel(
self.sorting_analyzer, peak_sign=peak_sign, outputs="index"
)

retriever = SpikeRetriever(
sorting,
recording,
channel_from_template=True,
extremum_channel_inds=extremum_channels_indices,
)
retriever_kwargs = {
"channel_from_template": True,
"extremum_channel_inds": extremum_channels_indices,
**self.params["spike_retriever_kwargs"],
}
retriever = SpikeRetriever(sorting, recording, **retriever_kwargs)
nodes = get_localization_pipeline_nodes(
recording,
retriever,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ class TestSpikeLocationsExtension(AnalyzerExtensionCommonTestSuite):
@pytest.mark.parametrize(
"params",
[
dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=True)),
dict(method="center_of_mass", spike_retriver_kwargs=dict(channel_from_template=False)),
dict(method="center_of_mass", peak_sign="both"),
dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=True)),
dict(method="center_of_mass", spike_retriever_kwargs=dict(channel_from_template=False)),
dict(method="center_of_mass"),
dict(method="monopolar_triangulation"),
dict(method="grid_convolution"),
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def create_sorting_analyzer_with_existing_templates(
sa.extensions["spike_locations"].params = dict(
ms_before=0.5,
ms_after=0.5,
spike_retriver_kwargs=None,
spike_retriever_kwargs=None,
method="center_of_mass",
method_kwargs={},
)
Expand Down