diff --git a/src/spikeinterface/metrics/template/metrics.py b/src/spikeinterface/metrics/template/metrics.py index a1af1de348..db1037c925 100644 --- a/src/spikeinterface/metrics/template/metrics.py +++ b/src/spikeinterface/metrics/template/metrics.py @@ -6,7 +6,7 @@ from spikeinterface.core.analyzer_extension_core import BaseMetric -def get_trough_and_peak_idx(template): +def get_trough_and_peak_idx(template, peak_sign="neg"): """ Return the indices into the input template of the detected trough (minimum of template) and peak (maximum of template, after trough). @@ -25,6 +25,15 @@ def get_trough_and_peak_idx(template): The index of the peak """ assert template.ndim == 1 + + # If peak_sign is 'pos', invert the template + if peak_sign == "pos": + template = -template + elif peak_sign == "both": + max_idx = np.abs(template).argmax() + if template[max_idx] > 0: + template = -template + trough_idx = np.argmin(template) peak_idx = trough_idx + np.argmax(template[trough_idx:]) return trough_idx, peak_idx @@ -107,6 +116,7 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id if trough_idx is None or peak_idx is None: trough_idx, peak_idx = get_trough_and_peak_idx(template_single) + # Edge case: template is flat if peak_idx == 0: return np.nan @@ -114,19 +124,19 @@ def get_half_width(template_single, sampling_frequency, trough_idx=None, peak_id # threshold is half of peak height (assuming baseline is 0) threshold = 0.5 * trough_val - (cpre_idx,) = np.where(template_single[:trough_idx] < threshold) - (cpost_idx,) = np.where(template_single[trough_idx:] < threshold) + # Find where the template crosses the threshold before and after the trough + threshold_crossings = np.where(np.diff(template_single >= threshold))[0] + crossings_before_trough = threshold_crossings[threshold_crossings < trough_idx] + crossings_after_trough = threshold_crossings[threshold_crossings >= trough_idx] - if len(cpre_idx) == 0 or len(cpost_idx) == 0: + if len(crossings_before_trough) == 0 or len(crossings_after_trough) == 0: hw = np.nan - else: - # last occurence of template lower than thr, before peak - cross_pre_pk = cpre_idx[0] - 1 - # first occurence of template lower than peak, after peak - cross_post_pk = cpost_idx[-1] + 1 + trough_idx + last_crossing_before_trough = crossings_before_trough[-1] + first_crossing_after_trough = crossings_after_trough[0] + + hw = (first_crossing_after_trough - last_crossing_before_trough) / sampling_frequency - hw = (cross_post_pk - cross_pre_pk) / sampling_frequency return hw @@ -163,11 +173,12 @@ def get_repolarization_slope(template_single, sampling_frequency, trough_idx=Non if trough_idx == 0: return np.nan - (rtrn_idx,) = np.nonzero(template_single[trough_idx:] >= 0) - if len(rtrn_idx) == 0: + # Find where the template crosses the baseline (0) after the trough + baseline_crossings = np.where(np.diff(template_single[trough_idx:] >= 0))[0] + if len(baseline_crossings) == 0: return np.nan # first time after trough, where template is at baseline - return_to_base_idx = rtrn_idx[0] + trough_idx + return_to_base_idx = baseline_crossings[0] + trough_idx + 1 if return_to_base_idx - trough_idx < 3: return np.nan @@ -218,6 +229,9 @@ def get_recovery_slope(template_single, sampling_frequency, peak_idx=None, **kwa max_idx = int(peak_idx + ((recovery_window_ms / 1000) * sampling_frequency)) max_idx = np.min([max_idx, template_single.shape[0]]) + if max_idx - peak_idx < 3: + return np.nan + res = scipy.stats.linregress(times[peak_idx:max_idx], template_single[peak_idx:max_idx]) return res.slope @@ -315,6 +329,7 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) The sampling frequency of the template **kwargs: Required kwargs: - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") + - peak_sign: whether expected peaks are negative, positive, or both ("neg", "pos", "both") - min_channels: the minimum number of channels above or below to compute velocity - min_r2: the minimum r2 to accept the velocity fit - column_range: the range in um in the x-direction to consider channels for velocity @@ -327,11 +342,13 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) The velocity below the max channel """ assert "depth_direction" in kwargs, "depth_direction must be given as kwarg" + assert "peak_sign" in kwargs, "peak_sign must be given as kwarg" assert "min_channels" in kwargs, "min_channels must be given as kwarg" assert "min_r2" in kwargs, "min_r2 must be given as kwarg" assert "column_range" in kwargs, "column_range must be given as kwarg" depth_direction = kwargs["depth_direction"] + peak_sign = kwargs["peak_sign"] min_channels_for_velocity = kwargs["min_channels"] min_r2 = kwargs["min_r2"] column_range = kwargs["column_range"] @@ -340,6 +357,14 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs) template, channel_locations = transform_column_range(template, channel_locations, column_range, depth_direction) template, channel_locations = sort_template_and_locations(template, channel_locations, depth_direction) + # If peak_sign is 'pos', invert the template + if peak_sign == "pos": + template = -template + elif peak_sign == "both": + peak_value = template.flat[np.abs(template).argmax()] + if peak_value > 0: + template = -template + # find location of max channel max_sample_idx, max_channel_idx = np.unravel_index(np.argmin(template), template.shape) max_peak_time = max_sample_idx / sampling_frequency * 1000 @@ -454,9 +479,10 @@ def get_spread(template, channel_locations, sampling_frequency, **kwargs) -> flo sampling_frequency : float The sampling frequency of the template **kwargs: Required kwargs: - - depth_direction: the direction to compute velocity above and below ("x", "y", or "z") - - spread_threshold: the threshold to compute the spread - - column_range: the range in um in the x-direction to consider channels for velocity + - depth_direction: the direction to compute spread ("x", "y", or "z") + - spread_threshold: the threshold (0-1) to compute spread + - spread_smooth_um: the smoothing in um to apply to the amplitude profile before computing spread + - column_range: the range in um in the x-direction to consider channels for spread Returns ------- @@ -666,6 +692,7 @@ def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **m channel_locations_multi = tmp_data["channel_locations_multi"] sampling_frequency = tmp_data["sampling_frequency"] metric_params["depth_direction"] = tmp_data["depth_direction"] + metric_params["peak_sign"] = tmp_data["peak_sign"] for unit_index, unit_id in enumerate(unit_ids): channel_locations = channel_locations_multi[unit_index] template = templates_multi[unit_index] @@ -678,11 +705,7 @@ def _get_velocity_fits_metric_function(sorting_analyzer, unit_ids, tmp_data, **m class VelocityFits(BaseMetric): metric_name = "velocity_fits" metric_function = _get_velocity_fits_metric_function - metric_params = { - "min_channels": 3, - "min_r2": 0.2, - "column_range": None, - } + metric_params = {"min_channels": 3, "min_r2": 0.2, "column_range": None} metric_columns = {"velocity_above": float, "velocity_below": float} metric_descriptions = { "velocity_above": "Velocity of the spike propagation above the max channel in um/ms", diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py index 85ef9e22cb..fcdf4ec2ce 100644 --- a/src/spikeinterface/metrics/template/template_metrics.py +++ b/src/spikeinterface/metrics/template/template_metrics.py @@ -70,7 +70,7 @@ class ComputeTemplateMetrics(BaseMetricExtension): metric_params : dict of dicts or None, default: None Dictionary with parameters for template metrics calculation. Default parameters can be obtained with: `si.metrics.template_metrics.get_default_template_metrics_params()` - peak_sign : {"neg", "pos"}, default: "neg" + peak_sign : {"neg", "pos", "both"}, default: "neg" Whether to use the positive ("pos") or negative ("neg") peaks to estimate extremum channels. upsampling_factor : int, default: 10 The upsampling factor to upsample the templates @@ -209,8 +209,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids): template_upsampled = resample_poly(template_single, up=upsampling_factor, down=1) else: template_upsampled = template_single - sampling_frequency_up = sampling_frequency - trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled) + trough_idx, peak_idx = get_trough_and_peak_idx(template_upsampled, peak_sign=peak_sign) templates_single.append(template_upsampled) troughs[unit_id] = trough_idx @@ -246,6 +245,7 @@ def _prepare_data(self, sorting_analyzer, unit_ids): tmp_data["templates_multi"] = templates_multi tmp_data["channel_locations_multi"] = channel_locations_multi tmp_data["depth_direction"] = self.params["depth_direction"] + tmp_data["peak_sign"] = self.params["peak_sign"] return tmp_data