Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
method="locally_exclusive",
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
radius_um=75.0,
),
"select_kwargs": dict(),
Expand All @@ -139,7 +139,7 @@
method="locally_exclusive",
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
radius_um=50,
),
"select_kwargs": dict(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run_peaks(recording, job_kwargs):
method_kwargs=dict(
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
noise_levels=noise_levels,
),
job_kwargs=job_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/matching/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def __init__(
return_output=True,
templates=None,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
jitter_ms=0.1,
detect_threshold=5,
noise_levels=None,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
templates,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
detect_threshold=5,
noise_levels=None,
detection_radius_um=100.0,
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(
svd_model,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
detect_threshold=5,
noise_levels=None,
detection_radius_um=100.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
templates,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.5,
exclude_sweep_ms=0.8,
peak_shift_ms=0.2,
detect_threshold=5,
noise_levels=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_dataset_and_peaks(cache_folder):
noise_levels=get_noise_levels(recording, return_in_uV=False),
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
return_output=True,
)
extract_dense_waveforms = ExtractDenseWaveforms(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class ByChannelPeakDetector(PeakDetector):
Sign of the peak
detect_threshold: float, default: 5
Threshold, in median absolute deviations (MAD), to use to detect peaks
exclude_sweep_ms: float, default: 0.1
exclude_sweep_ms: float, default: 1.0
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 0.1ms preceding and following the peak
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 1.0ms preceding and following the peak
noise_levels: array or None, default: None
Estimated noise levels to use, if already computed
If not provide then it is estimated from a random snippet of the data
Expand All @@ -42,7 +42,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
return_output=True,
):
Expand Down Expand Up @@ -116,10 +116,10 @@ class ByChannelTorchPeakDetector(ByChannelPeakDetector):
Sign of the peak
detect_threshold: float, default: 5
Threshold, in median absolute deviations (MAD), to use to detect peaks
exclude_sweep_ms: float, default: 0.1
exclude_sweep_ms: float, default: 1.0
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 0.1ms preceding and following the peak
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 1.0ms preceding and following the peak
noise_levels: array or None, default: None
Estimated noise levels to use, if already computed.
If not provide then it is estimated from a random snippet of the data
Expand All @@ -134,7 +134,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
device=None,
return_tensor=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
radius_um=50,
noise_levels=None,
return_output=True,
Expand Down Expand Up @@ -81,7 +81,8 @@ def __init__(
self.neighbours_mask = self.channel_distance <= radius_um

def get_trace_margin(self):
return self.exclude_sweep_size
# the +1 in the border is important because we need peak in the border
return self.exclude_sweep_size + 1

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
assert HAVE_NUMBA, "You need to install numba"
Expand All @@ -104,88 +105,71 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
if HAVE_NUMBA:
import numba

@numba.jit(nopython=True, parallel=False, nogil=True)
def detect_peaks_numba_locally_exclusive_on_chunk(
traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask
traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask,
):
num_chans = traces.shape[1]
num_samples = traces.shape[0]


do_pos = peak_sign in ("pos", "both")
do_neg = peak_sign in ("neg", "both")

# first find peaks
peak_mask = np.zeros(traces.shape, dtype="bool")
for s in range(1, num_samples - 1):
for chan_ind in range(num_chans):
if do_neg:
if (traces[s, chan_ind] <= -abs_thresholds[chan_ind]) and \
(traces[s, chan_ind] < traces[s-1, chan_ind]) and \
(traces[s, chan_ind] <= traces[s+1, chan_ind]):
peak_mask[s, chan_ind] = True

if do_pos :
if (traces[s, chan_ind] >= abs_thresholds[chan_ind]) and \
(traces[s, chan_ind] > traces[s-1, chan_ind]) and \
(traces[s, chan_ind] >= traces[s+1, chan_ind]):
peak_mask[s, chan_ind] = True

samples_inds, chan_inds = np.nonzero(peak_mask)

npeaks = samples_inds.size
keep_peak = np.ones(npeaks, dtype="bool")
next_start = 0
for i in range(npeaks):

if (samples_inds[i] < exclude_sweep_size + 1) or (samples_inds[i]>= (num_samples - exclude_sweep_size - 1)):
keep_peak[i] = False
continue

for j in range(next_start, npeaks):
if i == j:
continue

# if medians is not None:
# traces = traces - medians

traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :]

if peak_sign in ("pos", "both"):
peak_mask = traces_center > abs_thresholds[None, :]
peak_mask = _numba_detect_peak_pos(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
)

if peak_sign in ("neg", "both"):
if peak_sign == "both":
peak_mask_pos = peak_mask.copy()
if samples_inds[i] + exclude_sweep_size < samples_inds[j]:
break

peak_mask = traces_center < -abs_thresholds[None, :]
peak_mask = _numba_detect_peak_neg(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
)
if samples_inds[i] - exclude_sweep_size > samples_inds[j]:
next_start = j
continue

#search for neighbors
if neighbours_mask[chan_inds[i], chan_inds[j]]:
# if inside spatial zone
if abs(samples_inds[i] - samples_inds[j]) <= exclude_sweep_size:
value_i = abs(traces[samples_inds[i], chan_inds[i]]) / abs_thresholds[chan_inds[i]]
value_j = abs(traces[samples_inds[j], chan_inds[j]]) / abs_thresholds[chan_inds[j]]

if ((value_j >= value_i) & (samples_inds[i] > samples_inds[j])) | ((value_j > value_i) & (samples_inds[i] <= samples_inds[j])):
keep_peak[i] = False
break

if peak_sign == "both":
peak_mask = peak_mask | peak_mask_pos
samples_inds, chan_inds = samples_inds[keep_peak], chan_inds[keep_peak]

# Find peaks and correct for time shift
peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask)
peak_sample_ind += exclude_sweep_size
return samples_inds, chan_inds

return peak_sample_ind, peak_chan_ind

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_pos(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
num_chans = traces_center.shape[1]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[0]):
if not peak_mask[s, chan_ind]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind, neighbour]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] >= traces_center[s, neighbour]
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] > traces[s + i, neighbour]
peak_mask[s, chan_ind] &= (
traces_center[s, chan_ind] >= traces[exclude_sweep_size + s + i + 1, neighbour]
)
if not peak_mask[s, chan_ind]:
break
if not peak_mask[s, chan_ind]:
break
return peak_mask

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_neg(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
num_chans = traces_center.shape[1]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[0]):
if not peak_mask[s, chan_ind]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind, neighbour]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] <= traces_center[s, neighbour]
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] < traces[s + i, neighbour]
peak_mask[s, chan_ind] &= (
traces_center[s, chan_ind] <= traces[exclude_sweep_size + s + i + 1, neighbour]
)
if not peak_mask[s, chan_ind]:
break
if not peak_mask[s, chan_ind]:
break
return peak_mask


class LocallyExclusiveTorchPeakDetector(ByChannelTorchPeakDetector):
Expand All @@ -205,7 +189,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
device=None,
radius_um=50,
Expand Down Expand Up @@ -275,7 +259,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
radius_um=50,
noise_levels=None,
opencl_context_kwargs={},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
ms_before,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
radius_um=50,
random_chunk_kwargs={"num_chunks_per_segment": 5},
weight_method={},
Expand Down
Loading
Loading