diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index f59c849649..53882b05b8 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -121,7 +121,7 @@ method="locally_exclusive", peak_sign="neg", detect_threshold=8.0, - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, radius_um=75.0, ), "select_kwargs": dict(), @@ -139,7 +139,7 @@ method="locally_exclusive", peak_sign="neg", detect_threshold=8.0, - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, radius_um=50, ), "select_kwargs": dict(), diff --git a/src/spikeinterface/sortingcomponents/clustering/tests/test_clustering.py b/src/spikeinterface/sortingcomponents/clustering/tests/test_clustering.py index 290ba9f244..12ceb95e58 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tests/test_clustering.py +++ b/src/spikeinterface/sortingcomponents/clustering/tests/test_clustering.py @@ -42,7 +42,7 @@ def run_peaks(recording, job_kwargs): method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, noise_levels=noise_levels, ), job_kwargs=job_kwargs, diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 71f0bd02d6..ea6ce11f3e 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -626,7 +626,7 @@ def __init__( return_output=True, templates=None, peak_sign="neg", - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, jitter_ms=0.1, detect_threshold=5, noise_levels=None, diff --git a/src/spikeinterface/sortingcomponents/matching/nearest.py b/src/spikeinterface/sortingcomponents/matching/nearest.py index 7942e3db4e..44389cc503 100644 --- a/src/spikeinterface/sortingcomponents/matching/nearest.py +++ b/src/spikeinterface/sortingcomponents/matching/nearest.py @@ -33,7 +33,7 @@ def __init__( templates, return_output=True, peak_sign="neg", - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, detect_threshold=5, noise_levels=None, detection_radius_um=100.0, @@ -158,7 +158,7 @@ def __init__( svd_model, return_output=True, peak_sign="neg", - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, detect_threshold=5, noise_levels=None, detection_radius_um=100.0, diff --git a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py index 1171ebdf81..d3ae787a4b 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc_peeler.py @@ -95,7 +95,7 @@ def __init__( templates, return_output=True, peak_sign="neg", - exclude_sweep_ms=0.5, + exclude_sweep_ms=0.8, peak_shift_ms=0.2, detect_threshold=5, noise_levels=None, diff --git a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py index 7422a0e91c..89f97c9a72 100644 --- a/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py @@ -30,7 +30,7 @@ def setup_dataset_and_peaks(cache_folder): noise_levels=get_noise_levels(recording, return_in_uV=False), peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, return_output=True, ) extract_dense_waveforms = ExtractDenseWaveforms( diff --git a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py index 258d548449..732ada21bc 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/by_channel.py @@ -28,10 +28,10 @@ class ByChannelPeakDetector(PeakDetector): Sign of the peak detect_threshold: float, default: 5 Threshold, in median absolute deviations (MAD), to use to detect peaks - exclude_sweep_ms: float, default: 0.1 + exclude_sweep_ms: float, default: 1.0 Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size - For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak + For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold, + and no larger peaks are located during the 1.0ms preceding and following the peak noise_levels: array or None, default: None Estimated noise levels to use, if already computed If not provide then it is estimated from a random snippet of the data @@ -42,7 +42,7 @@ def __init__( recording, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, noise_levels=None, return_output=True, ): @@ -116,10 +116,10 @@ class ByChannelTorchPeakDetector(ByChannelPeakDetector): Sign of the peak detect_threshold: float, default: 5 Threshold, in median absolute deviations (MAD), to use to detect peaks - exclude_sweep_ms: float, default: 0.1 + exclude_sweep_ms: float, default: 1.0 Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size - For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold, - and no larger peaks are located during the 0.1ms preceding and following the peak + For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold, + and no larger peaks are located during the 1.0ms preceding and following the peak noise_levels: array or None, default: None Estimated noise levels to use, if already computed. If not provide then it is estimated from a random snippet of the data @@ -134,7 +134,7 @@ def __init__( recording, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, noise_levels=None, device=None, return_tensor=False, diff --git a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py index 0cf5e56d0c..151722bb94 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py @@ -49,7 +49,7 @@ def __init__( recording, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, radius_um=50, noise_levels=None, return_output=True, @@ -81,7 +81,8 @@ def __init__( self.neighbours_mask = self.channel_distance <= radius_um def get_trace_margin(self): - return self.exclude_sweep_size + # the +1 in the border is important because we need peak in the border + return self.exclude_sweep_size + 1 def compute(self, traces, start_frame, end_frame, segment_index, max_margin): assert HAVE_NUMBA, "You need to install numba" @@ -104,88 +105,71 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): if HAVE_NUMBA: import numba + @numba.jit(nopython=True, parallel=False, nogil=True) def detect_peaks_numba_locally_exclusive_on_chunk( - traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask + traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask, ): + num_chans = traces.shape[1] + num_samples = traces.shape[0] + + + do_pos = peak_sign in ("pos", "both") + do_neg = peak_sign in ("neg", "both") + + # first find peaks + peak_mask = np.zeros(traces.shape, dtype="bool") + for s in range(1, num_samples - 1): + for chan_ind in range(num_chans): + if do_neg: + if (traces[s, chan_ind] <= -abs_thresholds[chan_ind]) and \ + (traces[s, chan_ind] < traces[s-1, chan_ind]) and \ + (traces[s, chan_ind] <= traces[s+1, chan_ind]): + peak_mask[s, chan_ind] = True + + if do_pos : + if (traces[s, chan_ind] >= abs_thresholds[chan_ind]) and \ + (traces[s, chan_ind] > traces[s-1, chan_ind]) and \ + (traces[s, chan_ind] >= traces[s+1, chan_ind]): + peak_mask[s, chan_ind] = True + + samples_inds, chan_inds = np.nonzero(peak_mask) + + npeaks = samples_inds.size + keep_peak = np.ones(npeaks, dtype="bool") + next_start = 0 + for i in range(npeaks): + + if (samples_inds[i] < exclude_sweep_size + 1) or (samples_inds[i]>= (num_samples - exclude_sweep_size - 1)): + keep_peak[i] = False + continue + + for j in range(next_start, npeaks): + if i == j: + continue - # if medians is not None: - # traces = traces - medians - - traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :] - - if peak_sign in ("pos", "both"): - peak_mask = traces_center > abs_thresholds[None, :] - peak_mask = _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ) - - if peak_sign in ("neg", "both"): - if peak_sign == "both": - peak_mask_pos = peak_mask.copy() + if samples_inds[i] + exclude_sweep_size < samples_inds[j]: + break - peak_mask = traces_center < -abs_thresholds[None, :] - peak_mask = _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ) + if samples_inds[i] - exclude_sweep_size > samples_inds[j]: + next_start = j + continue + + #search for neighbors + if neighbours_mask[chan_inds[i], chan_inds[j]]: + # if inside spatial zone + if abs(samples_inds[i] - samples_inds[j]) <= exclude_sweep_size: + value_i = abs(traces[samples_inds[i], chan_inds[i]]) / abs_thresholds[chan_inds[i]] + value_j = abs(traces[samples_inds[j], chan_inds[j]]) / abs_thresholds[chan_inds[j]] + + if ((value_j >= value_i) & (samples_inds[i] > samples_inds[j])) | ((value_j > value_i) & (samples_inds[i] <= samples_inds[j])): + keep_peak[i] = False + break - if peak_sign == "both": - peak_mask = peak_mask | peak_mask_pos + samples_inds, chan_inds = samples_inds[keep_peak], chan_inds[keep_peak] - # Find peaks and correct for time shift - peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask) - peak_sample_ind += exclude_sweep_size + return samples_inds, chan_inds - return peak_sample_ind, peak_chan_ind - @numba.jit(nopython=True, parallel=False) - def _numba_detect_peak_pos( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ): - num_chans = traces_center.shape[1] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[0]): - if not peak_mask[s, chan_ind]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind, neighbour]: - continue - for i in range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] >= traces_center[s, neighbour] - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] > traces[s + i, neighbour] - peak_mask[s, chan_ind] &= ( - traces_center[s, chan_ind] >= traces[exclude_sweep_size + s + i + 1, neighbour] - ) - if not peak_mask[s, chan_ind]: - break - if not peak_mask[s, chan_ind]: - break - return peak_mask - - @numba.jit(nopython=True, parallel=False) - def _numba_detect_peak_neg( - traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask - ): - num_chans = traces_center.shape[1] - for chan_ind in range(num_chans): - for s in range(peak_mask.shape[0]): - if not peak_mask[s, chan_ind]: - continue - for neighbour in range(num_chans): - if not neighbours_mask[chan_ind, neighbour]: - continue - for i in range(exclude_sweep_size): - if chan_ind != neighbour: - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] <= traces_center[s, neighbour] - peak_mask[s, chan_ind] &= traces_center[s, chan_ind] < traces[s + i, neighbour] - peak_mask[s, chan_ind] &= ( - traces_center[s, chan_ind] <= traces[exclude_sweep_size + s + i + 1, neighbour] - ) - if not peak_mask[s, chan_ind]: - break - if not peak_mask[s, chan_ind]: - break - return peak_mask class LocallyExclusiveTorchPeakDetector(ByChannelTorchPeakDetector): @@ -205,7 +189,7 @@ def __init__( recording, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, noise_levels=None, device=None, radius_um=50, @@ -275,7 +259,7 @@ def __init__( recording, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, radius_um=50, noise_levels=None, opencl_context_kwargs={}, diff --git a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py index 78c436fc0d..9118839fd6 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py @@ -46,7 +46,7 @@ def __init__( ms_before, peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, radius_um=50, random_chunk_kwargs={"num_chunks_per_segment": 5}, weight_method={}, diff --git a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py index 230518a8c4..af7e0e15df 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/tests/test_peak_detection.py @@ -90,7 +90,7 @@ def pca_model_folder_path(recording, job_kwargs, tmp_path): n_components = 3 n_peaks = 100 # Heuristic for extracting around 1k waveforms per channel peak_selection_params = dict(method="uniform", select_per_channel=True, n_peaks=n_peaks) - detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1) + detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0) TemporalPCADenoising.fit( recording=recording, model_folder_path=model_folder_path, @@ -261,7 +261,7 @@ def test_detect_peaks_by_channel(recording, job_kwargs, torch_job_kwargs): peaks_by_channel_np = detect_peaks( recording, method="by_channel", - method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0), job_kwargs=job_kwargs, ) @@ -272,7 +272,7 @@ def test_detect_peaks_by_channel(recording, job_kwargs, torch_job_kwargs): method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, ), job_kwargs=torch_job_kwargs, ) @@ -285,18 +285,39 @@ def test_detect_peaks_locally_exclusive(recording, job_kwargs, torch_job_kwargs) peaks_by_channel_np = detect_peaks( recording, method="by_channel", - method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0), job_kwargs=job_kwargs, ) peaks_local_numba = detect_peaks( recording, method="locally_exclusive", - method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0), job_kwargs=job_kwargs, ) assert len(peaks_by_channel_np) > len(peaks_local_numba) + DEBUG = True + if DEBUG: + import matplotlib.pyplot as plt + + peaks = peaks_local_numba + labels = ["locally_exclusive numba", ] + + fig, ax = plt.subplots() + chan_offset = 500 + traces = recording.get_traces().copy() + traces += np.arange(traces.shape[1])[None, :] * chan_offset + ax.plot(traces, color="k") + + for count, peaks in enumerate([peaks_local_numba, ]): + sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] + ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, label=labels[count]) + + ax.legend() + plt.show() + + if HAVE_TORCH: peaks_local_torch = detect_peaks( recording, @@ -304,7 +325,7 @@ def test_detect_peaks_locally_exclusive(recording, job_kwargs, torch_job_kwargs) method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, ), job_kwargs=torch_job_kwargs, ) @@ -317,7 +338,7 @@ def test_detect_peaks_locally_exclusive(recording, job_kwargs, torch_job_kwargs) method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, ), job_kwargs=job_kwargs, ) @@ -328,7 +349,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) peaks_by_channel_np = detect_peaks( recording, method="locally_exclusive", - method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0), job_kwargs=job_kwargs, ) @@ -344,7 +365,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, prototype=prototype, ms_before=1.0, ), @@ -359,7 +380,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) method_kwargs=dict( peak_sign="both", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=1.0, prototype=prototype, ms_before=1.0, ), @@ -429,133 +450,6 @@ def test_peak_sign_consistency(recording, job_kwargs, detection_class): assert all_peaks.size > 0 -# def test_peak_detection_with_pipeline(recording, job_kwargs, torch_job_kwargs, tmp_path): -# extract_dense_waveforms = ExtractDenseWaveforms(recording, ms_before=1.0, ms_after=1.0, return_output=False) - -# pipeline_nodes = [ -# extract_dense_waveforms, -# PeakToPeakFeature(recording, all_channels=False, parents=[extract_dense_waveforms]), -# LocalizeCenterOfMass(recording, radius_um=50.0, parents=[extract_dense_waveforms]), -# ] -# peaks, ptp, peak_locations = detect_peaks( -# recording, -# method="locally_exclusive", -# method_kwargs=dict( -# peak_sign="neg", -# detect_threshold=5, -# exclude_sweep_ms=0.1, -# ), -# pipeline_nodes=pipeline_nodes, -# job_kwargs=job_kwargs, -# ) -# assert peaks.shape[0] == ptp.shape[0] -# assert peaks.shape[0] == peak_locations.shape[0] -# assert "x" in peak_locations.dtype.fields - -# # same pipeline but saved to npy -# folder = tmp_path / "peak_detection_folder" -# if folder.is_dir(): -# shutil.rmtree(folder) -# peaks2, ptp2, peak_locations2 = detect_peaks( -# recording, -# method="locally_exclusive", -# method_kwargs=dict( -# peak_sign="neg", -# detect_threshold=5, -# exclude_sweep_ms=0.1, -# ), -# pipeline_nodes=pipeline_nodes, -# gather_mode="npy", -# folder=folder, -# names=["peaks", "ptps", "peak_locations"], -# job_kwargs=job_kwargs, -# ) -# peak_file = folder / "peaks.npy" -# assert peak_file.is_file() -# peaks3 = np.load(peak_file) -# assert np.array_equal(peaks, peaks2) -# assert np.array_equal(peaks2, peaks3) - -# ptp_file = folder / "ptps.npy" -# assert ptp_file.is_file() -# ptp3 = np.load(ptp_file) -# assert np.array_equal(ptp, ptp2) -# assert np.array_equal(ptp2, ptp3) - -# peak_location_file = folder / "peak_locations.npy" -# assert peak_location_file.is_file() -# peak_locations3 = np.load(peak_location_file) -# assert np.array_equal(peak_locations, peak_locations2) -# assert np.array_equal(peak_locations2, peak_locations3) - -# if HAVE_TORCH: -# peaks_torch, ptp_torch, peak_locations_torch = detect_peaks( -# recording, -# method="locally_exclusive_torch", -# method_kwargs=dict( -# peak_sign="neg", -# detect_threshold=5, -# exclude_sweep_ms=0.1, -# ), -# pipeline_nodes=pipeline_nodes, -# job_kwargs=torch_job_kwargs, -# ) -# assert peaks_torch.shape[0] == ptp_torch.shape[0] -# assert peaks_torch.shape[0] == peak_locations_torch.shape[0] -# assert "x" in peak_locations_torch.dtype.fields - -# if HAVE_PYOPENCL: -# peaks_cl, ptp_cl, peak_locations_cl = detect_peaks( -# recording, -# method="locally_exclusive_cl", -# method_kwargs=dict( -# peak_sign="neg", -# detect_threshold=5, -# exclude_sweep_ms=0.1, -# ), -# pipeline_nodes=pipeline_nodes, -# job_kwargs=job_kwargs, -# ) -# assert peaks_cl.shape[0] == ptp_cl.shape[0] -# assert peaks_cl.shape[0] == peak_locations_cl.shape[0] -# assert "x" in peak_locations_cl.dtype.fields - -# # DEBUG -# DEBUG = False -# if DEBUG: -# import matplotlib.pyplot as plt -# import spikeinterface.widgets as sw -# from probeinterface.plotting import plot_probe - -# sample_inds, chan_inds, amplitudes = peaks["sample_index"], peaks["channel_index"], peaks["amplitude"] -# chan_offset = 500 -# traces = recording.get_traces() -# traces += np.arange(traces.shape[1])[None, :] * chan_offset -# fig, ax = plt.subplots() -# ax.plot(traces, color="k") -# ax.scatter(sample_inds, chan_inds * chan_offset + amplitudes, color="r") -# plt.show() - -# fig, ax = plt.subplots() -# probe = recording.get_probe() -# plot_probe(probe, ax=ax) -# ax.scatter(peak_locations["x"], peak_locations["y"], color="k", s=1, alpha=0.5) -# # MEArec is "yz" in 2D -# # import MEArec - -# # recgen = MEArec.load_recordings( -# # recordings=local_path, -# # return_h5_objects=True, -# # check_suffix=False, -# # load=["recordings", "spiketrains", "channel_positions"], -# # load_waveforms=False, -# # ) -# # soma_positions = np.zeros((len(recgen.spiketrains), 3), dtype="float32") -# # for i, st in enumerate(recgen.spiketrains): -# # soma_positions[i, :] = st.annotations["soma_position"] -# # ax.scatter(soma_positions[:, 1], soma_positions[:, 2], color="g", s=20, marker="*") -# plt.show() - if __name__ == "__main__": recording, sorting = make_dataset() @@ -568,12 +462,14 @@ def test_peak_sign_consistency(recording, job_kwargs, detection_class): pca_model_folder_path_main = pca_model_folder_path(recording, job_kwargs_main, tmp_dir_main) peak_detector_kwargs_main = peak_detector_kwargs(recording) - test_iterative_peak_detection(recording, job_kwargs_main, pca_model_folder_path_main, peak_detector_kwargs_main) + # test_iterative_peak_detection(recording, job_kwargs_main, pca_model_folder_path_main, peak_detector_kwargs_main) - test_peak_sign_consistency(recording, torch_job_kwargs_main, LocallyExclusiveTorchPeakDetector) + # test_peak_sign_consistency(recording, torch_job_kwargs_main, LocallyExclusiveTorchPeakDetector) # test_peak_detection_with_pipeline(recording, job_kwargs_main, torch_job_kwargs_main, tmp_path) # test_detect_peaks_locally_exclusive_matched_filtering( # recording, # job_kwargs_main, # ) + + test_detect_peaks_locally_exclusive(recording, job_kwargs_main, torch_job_kwargs_main) diff --git a/src/spikeinterface/sortingcomponents/peak_localization/tests/test_peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization/tests/test_peak_localization.py index a04b21f4aa..9db0588f6d 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/tests/test_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/tests/test_peak_localization.py @@ -16,7 +16,7 @@ def test_localize_peaks(): peaks = detect_peaks( recording, method="locally_exclusive", - method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + method_kwargs=dict(peak_sign="neg", detect_threshold=5, exclude_sweep_ms=1.0), job_kwargs=job_kwargs, ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py index ecf8773f76..bb0da231d9 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_selection.py @@ -22,7 +22,7 @@ def test_select_peaks(): method_kwargs=dict( peak_sign="neg", detect_threshold=5, - exclude_sweep_ms=0.1, + exclude_sweep_ms=0.8, noise_levels=noise_levels, ), job_kwargs=dict( diff --git a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py index e52ace9e26..286c103d22 100644 --- a/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py +++ b/src/spikeinterface/sortingcomponents/waveforms/tests/test_temporal_pca.py @@ -38,7 +38,7 @@ def model_path_of_trained_pca(folder_to_save_pca_model, generated_recording, chu n_components = 3 n_peaks = 100 # Heuristic for extracting around 1k waveforms per channel peak_selection_params = dict(method="uniform", select_per_channel=True, n_peaks=n_peaks) - detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1) + detect_peaks_params = dict(method="by_channel", peak_sign="neg", detect_threshold=5) TemporalPCAProjection.fit( recording=recording, model_folder_path=model_folder_path,