diff --git a/doc/index.rst b/doc/index.rst index ce4053ca43..659efe85a8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -30,7 +30,7 @@ SpikeInterface is made of several modules to deal with different aspects of the - visualize recordings and spike sorting outputs in several ways (matplotlib, sortingview, jupyter, ephyviewer) - export a report and/or export to phy - curate your sorting with several strategies (ml-based, metrics based, manual, ...) -- offer a powerful Qt-based or we-based viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. +- offer a powerful desktop or web viewer in a separate package `spikeinterface-gui `_ for manual curation that replace phy. - have powerful sorting components to build your own sorter. - have a full motion/drift correction framework (See :ref:`motion_correction`) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 078db82201..0faed14fba 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -93,10 +93,11 @@ with 16 channels: timestamps = np.arange(num_samples) / sampling_frequency + 300 recording.set_times(times=timestamps, segment_index=0) -**Note**: -Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. -Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This -is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. +.. note:: + + Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units (uV), you can apply a gain and an offset. + Many devices have their own gains and offsets necessary to convert their data and these values are handled by SpikeInterface for its extractors. This + is triggered by the :code:`return_in_uV` parameter in :code:`get_traces()`, (see above example), which will return the traces in uV. Read more in our how to guide, :ref:`physical_units`. Sorting @@ -180,8 +181,9 @@ a numpy.array with dtype `[("sample_index", "int64"), ("unit_index", "int64"), ( For computations which are done unit-by-unit, like computing isi-violations per unit, it is better that spikes from a single unit are concurrent in memory. For these other cases, we can re-order the `spike_vector` in different ways: - * order by unit, then segment, then sample - * order by segment, then unit, then sample + +* order by unit, then segment, then sample +* order by segment, then unit, then sample This is done using `sorting.to_reordered_spike_vector()`. The first time a reordering is done, the reordered spiketrain is cached in memory by default. Users should rarely have to worry about these @@ -458,9 +460,11 @@ It represents unsorted waveform cutouts. Some acquisition systems, in fact, allo threshold and only record the times at which a peak was detected and the waveform cut out around the peak. -**NOTE**: while we support this class (mainly for legacy formats), this approach is a bad practice -and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform -template matching to recover spikes! +.. note:: + + While we support this class (mainly for legacy formats), this approach is a bad practice + and is highly discouraged! Most modern spike sorters, in fact, require the raw traces to perform + template matching to recover spikes! Here we assume :code:`snippets` is a :py:class:`~spikeinterface.core.BaseSnippets` object with 16 channels: @@ -548,9 +552,11 @@ Sparsity is defined as the subset of channels on which waveforms (and related in sparsity is not global, but it is unit-specific. Importantly, saving sparse waveforms, especially for high-density probes, dramatically reduces the size of the waveforms extension if computed. -**NOTE** As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first -introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense -waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. +.. note:: + + As of :code:`0.101.0` all :code:`SortingAnalyzer`'s have a default of :code:`sparse=True`. This was first + introduced in :code:`0.99.0` for :code:`WaveformExtractor`'s and will be the default going forward. To obtain dense + waveforms you will need to set :code:`sparse=False` at the creation of the :code:`SortingAnalyzer`. Sparsity can be computed from a :py:class:`~spikeinterface.core.SortingAnalyzer` object with the @@ -854,10 +860,12 @@ The same functions are also available for :py:func:`~spikeinterface.core.select_segment_sorting`). -**Note** :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` -have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the -multi segments concept, the other one is breaking it! -See this example for more detail :ref:`example_segments`. +.. note:: + + :py:func:`~spikeinterface.core.append_recordings` and:py:func:`~spikeinterface.core.concatenate_recordings` + have the same goal, aggregate recording pieces on the time axis but with 2 different strategies! One is keeping the + multi segments concept, the other one is breaking it! + See this example for more detail :ref:`example_segments`. diff --git a/doc/modules/exporters.rst b/doc/modules/exporters.rst index 4eeaa23b81..ab25e8803d 100644 --- a/doc/modules/exporters.rst +++ b/doc/modules/exporters.rst @@ -12,9 +12,11 @@ and behavioral data. It can be used to decode behavior, make tuning curves, comp The :py:func:`~spikeinterface.exporters.to_pynapple_tsgroup` function allows you to convert a SortingAnalyzer to Pynapple's ``TsGroup`` object on the fly. -**Note** : When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. -How this works depends on your acquisition system. You can use the ``get_times`` method on a recording -(``my_recording.get_times()``) to find the time support of your recording. +.. note:: + + When creating the ``TsGroup``, we will use the underlying time support of the SortingAnalyzer. + How this works depends on your acquisition system. You can use the ``get_times`` method on a recording + (``my_recording.get_times()``) to find the time support of your recording. When constructed, if ``attach_unit_metadata`` is set to ``True``, any relevant unit information is propagated to the ``TsGroup``. The ``to_pynapple_tsgroup`` checks if unit locations, quality @@ -54,13 +56,15 @@ The :py:func:`~spikeinterface.exporters.export_to_phy` function allows you to us `Phy template GUI `_ for visual inspection and manual curation of spike sorting results. -**Note** : :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend -on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. -The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. -So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), -then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). -If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) -can be computed directly at export time. +.. note:: + + :py:func:`~spikeinterface.exporters.export_to_phy` speed and the size of the folder will highly depend + on the sparsity of the :code:`SortingAnalyzer` itself or the external specified sparsity. + The Phy viewer enables one to explore PCA projections, spike amplitudes, waveforms and quality of spike sorting results. + So if these pieces of information have already been computed as extensions (see :ref:`modules/postprocessing:Extensions as AnalyzerExtensions`), + then exporting to Phy should be fast (and the user has better control of the parameters for the extensions). + If not pre-computed, then the required extensions (e.g., :code:`spike_amplitudes`, :code:`principal_components`) + can be computed directly at export time. The input of the :py:func:`~spikeinterface.exporters.export_to_phy` is a :code:`SortingAnalyzer` object. @@ -131,12 +135,14 @@ The report includes summary figures of the spike sorting output (e.g. amplitude depth VS amplitude) as well as unit-specific reports, that include waveforms, templates, template maps, ISI distributions, and more. -**Note** : similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the -:py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and -on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots -will be automatically included in the report if the associated extensions are computed in advance. -The function can perform these computations as well, but it is a better practice to compute everything that's needed -beforehand. +.. note:: + + Similarly to :py:func:`~spikeinterface.exporters.export_to_phy` the + :py:func:`~spikeinterface.exporters.export_report` depends on the sparsity of the :code:`SortingAnalyzer` itself and + on which extensions have been computed. For example, :code:`spike_amplitudes` and :code:`correlograms` related plots + will be automatically included in the report if the associated extensions are computed in advance. + The function can perform these computations as well, but it is a better practice to compute everything that's needed + beforehand. Note that every unit will generate a summary unit figure, so the export process can be slow for spike sorting outputs with many units! diff --git a/doc/modules/postprocessing.rst b/doc/modules/postprocessing.rst index 5442b4728c..5c4e29b359 100644 --- a/doc/modules/postprocessing.rst +++ b/doc/modules/postprocessing.rst @@ -163,8 +163,10 @@ Extensions are generally saved in two ways, suitable for two workflows: :code:`sorting_analyzer.compute('waveforms', save=False)`). -**NOTE**: We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until -you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code +.. note:: + + We recommend choosing a workflow and sticking with it. Either keep everything on disk or keep everything in memory until + you'd like to save. A mixture can lead to unexpected behavior. For example, consider the following code .. code:: @@ -257,15 +259,35 @@ spike_amplitudes This extension computes the amplitude of each spike as the value of the traces on the extremum channel at the times of each spike. The extremum channel is computed from the templates. + **NOTE:** computing spike amplitudes is highly recommended before calculating amplitude-based quality metrics, such as :ref:`amp_cutoff` and :ref:`amp_median`. .. code-block:: python - amplitudes = sorting_analyzer.compute(input="spike_amplitudes", peak_sign="neg") + amplitudes = sorting_analyzer.compute(input="spike_amplitudes") For more information, see :py:func:`~spikeinterface.postprocessing.compute_spike_amplitudes` + +.. _postprocessing_amplitude_scalings: + +amplitude_scalings +^^^^^^^^^^^^^^^^^^ + +This extension computes the amplitude scaling of each spike as the value of the linear fit between the template and the +spike waveform. In case of spatio-temporal collisions, a multi-linear fit is performed using the templates of all units +involved in the collision. + +**NOTE:** computing amplitude scalings is highly recommended before calculating amplitude-based quality metrics, such as +:ref:`amp_cutoff` and :ref:`amp_median`. + +.. code-block:: python + + amplitude_scalings = sorting_analyzer.compute(input="amplitude_scalings") + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_amplitude_scalings` + .. _postprocessing_spike_locations: spike_locations @@ -367,7 +389,29 @@ This extension computes the histograms of inter-spike-intervals. The computed ou method="auto" ) -For more information, see :py:func:`~spikeinterface.postprocessing.compute_isi_histograms` +valid_unit_periods +^^^^^^^^^^^^^^^^^^ + +This extension computes the valid unit periods for each unit based on the estimation of false positive rates +(using RP violation - see ::doc:`metrics/qualitymetrics/isi_violations`) and false negative rates +(using amplitude cutoff - see ::doc:`metrics/qualitymetrics/amplitude_cutoff`) computed over chunks of the recording. +The valid unit periods are the periods where both false positive and false negative rates are below specified +thresholds. Periods can be either absolute (in seconds), same for all units, or relative, where +chunks will be unit-specific depending on firing rate (with a target number of spikes per chunk). + +.. code-block:: python + + valid_periods = sorting_analyzer.compute( + input="valid_unit_periods", + period_mode='relative', + target_num_spikes=300, + fp_threshold=0.1, + fn_threshold=0.1, + ) + +For more information, see :py:func:`~spikeinterface.postprocessing.compute_valid_unit_periods`. + + Other postprocessing tools diff --git a/doc/modules/sortingcomponents.rst b/doc/modules/sortingcomponents.rst index e76cf3f99d..5549fd0317 100644 --- a/doc/modules/sortingcomponents.rst +++ b/doc/modules/sortingcomponents.rst @@ -81,7 +81,9 @@ Other variants are also implemented (but less tested or not so useful): * **'by_channel_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for time deduplication * **'locally_exclusive_torch'** (requires :code:`torch`): pytorch implementation (GPU-compatible) that uses max pooling for space-time deduplication -**NOTE**: the torch implementations give slightly different results due to a different implementation. +.. note:: + + The torch implementations give slightly different results due to a different implementation. Peak detection, as many of the other sorting components, can be run in parallel. @@ -274,7 +276,7 @@ handle drift can benefit from drift estimation/correction. Especially for acute Neuropixels-like probes, this is a crucial step. The motion estimation step comes after peak detection and peak localization. Read more about -it in the :ref:`_motion_correction` modules doc, and a more practical guide in the +it in the :ref:`motion_correction` modules doc, and a more practical guide in the :ref:`handle-drift-in-your-recording` How To. Here is an example with non-rigid motion estimation: diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index e078f71ed4..2f6c797ff2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2471,6 +2471,18 @@ def get_any_dependencies(cls, **params): all_dependencies = list(chain.from_iterable([dep.split("|") for dep in all_dependencies])) return all_dependencies + @classmethod + def get_default_params(cls): + """ + Get the default params for the extension. + + Returns + ------- + default_params : dict + The default parameters for the extension. + """ + return get_default_analyzer_extension_params(cls.extension_name) + def load_run_info(self): run_info = None if self.format == "binary_folder": @@ -2698,10 +2710,14 @@ def _save_data(self): for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: del extension_group[ext_data_name] - if isinstance(ext_data, dict): + if isinstance(ext_data, (dict, list)): + # These could be dicts or lists of dicts. The check_json makes sure + # that everything is json serializable + ext_data_ = check_json(ext_data) extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() + name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON() ) + extension_group[ext_data_name].attrs["dict"] = True elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): @@ -2884,6 +2900,7 @@ def set_data(self, ext_data_name, ext_data): "spike_locations": "spikeinterface.postprocessing", "template_similarity": "spikeinterface.postprocessing", "unit_locations": "spikeinterface.postprocessing", + "valid_unit_periods": "spikeinterface.postprocessing", # from metrics "quality_metrics": "spikeinterface.metrics", "template_metrics": "spikeinterface.metrics", diff --git a/src/spikeinterface/metrics/quality/__init__.py b/src/spikeinterface/metrics/quality/__init__.py index 1edcd9221f..f91ed6eefc 100644 --- a/src/spikeinterface/metrics/quality/__init__.py +++ b/src/spikeinterface/metrics/quality/__init__.py @@ -20,4 +20,5 @@ compute_sliding_rp_violations, compute_sd_ratio, compute_synchrony_metrics, + compute_refrac_period_violations, ) diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py index e5cc2aa323..5476aa405a 100644 --- a/src/spikeinterface/metrics/quality/quality_metrics.py +++ b/src/spikeinterface/metrics/quality/quality_metrics.py @@ -49,6 +49,13 @@ class ComputeQualityMetrics(BaseMetricExtension): need_backward_compatibility_on_load = True metric_list = misc_metrics_list + pca_metrics_list + @classmethod + def get_required_dependencies(cls, **params): + if params.get("use_valid_periods", False): + return ["valid_unit_periods"] + else: + return [] + def _handle_backward_compatibility_on_load(self): # For backwards compatibility - this renames qm_params as metric_params if (qm_params := self.params.get("qm_params")) is not None: @@ -70,6 +77,7 @@ def _set_params( metric_params: dict | None = None, delete_existing_metrics: bool = False, metrics_to_compute: list[str] | None = None, + use_valid_periods=False, periods=None, # common extension kwargs peak_sign=None, @@ -86,6 +94,11 @@ def _set_params( pc_metric_names = [m.metric_name for m in pca_metrics_list] metric_names = [m for m in metric_names if m not in pc_metric_names] + if use_valid_periods: + if periods is not None: + raise ValueError("If use_valid_periods is True, periods should not be provided.") + periods = self.sorting_analyzer.get_extension("valid_unit_periods").get_data(outputs="numpy") + return super()._set_params( metric_names=metric_names, metric_params=metric_params, diff --git a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py index 2e87002018..dfd47c4df9 100644 --- a/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/metrics/quality/tests/test_quality_metric_calculator.py @@ -1,5 +1,4 @@ import pytest -from pathlib import Path import numpy as np from spikeinterface.core import ( @@ -172,6 +171,76 @@ def test_empty_units(sorting_analyzer_simple): assert all_nans or all_zeros or all_neg_ones, f"Column {col} failed the empty unit test" +def test_quality_metrics_with_periods(): + """ + Test that quality metrics can be computed using valid unit periods. + """ + from spikeinterface.core.base import unit_period_dtype + + recording, sorting = generate_ground_truth_recording() + sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, format="memory") + + # compute dependencies + sorting_analyzer.compute(["random_spikes", "templates", "amplitude_scalings", "valid_unit_periods"], **job_kwargs) + print(sorting_analyzer) + + # compute quality metrics using valid periods + metrics = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + print(metrics) + + # test with external periods: 1 period per segment from 10 to 90% of recording + num_segments = recording.get_num_segments() + periods = np.zeros(len(sorting.unit_ids) * num_segments, dtype=unit_period_dtype) + for i, unit_id in enumerate(sorting.unit_ids): + unit_index = sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = int(num_samples * 0.1) + period_end = int(num_samples * 0.9) + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_end + periods[idx]["segment_index"] = segment_index + + metrics_ext_periods = compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=False, + periods=periods, + seed=2205, + ) + + # test failure when both periods and use_valid_periods are set + with pytest.raises(ValueError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + periods=periods, + seed=2205, + ) + + # test failure if use valid_periods is True but valid_unit_periods extension is missing + sorting_analyzer.delete_extension("valid_unit_periods") + with pytest.raises(AssertionError): + compute_quality_metrics( + sorting_analyzer, + metric_names=None, + skip_pc_metrics=True, + use_valid_periods=True, + seed=2205, + ) + + if __name__ == "__main__": sorting_analyzer = get_sorting_analyzer() diff --git a/src/spikeinterface/postprocessing/__init__.py b/src/spikeinterface/postprocessing/__init__.py index dca9711ccd..555c9a5d3b 100644 --- a/src/spikeinterface/postprocessing/__init__.py +++ b/src/spikeinterface/postprocessing/__init__.py @@ -44,3 +44,8 @@ ComputeTemplateMetrics, compute_template_metrics, ) + +from .valid_unit_periods import ( + ComputeValidUnitPeriods, + compute_valid_unit_periods, +) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 8f3ffe0617..473798fe7c 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -127,7 +127,11 @@ def _get_pipeline_nodes(self): sparsity = self.params["sparsity"] else: if self.params["max_dense_channels"] is not None: - assert recording.get_num_channels() <= self.params["max_dense_channels"], "" + assert recording.get_num_channels() <= self.params["max_dense_channels"], ( + "Sparsity must be provided when the number of channels is " + f"greater than {self.params['max_dense_channels']}. Alternatively, set max_dense_channels to None " + "to compute amplitude scalings using dense waveforms." + ) sparsity = ChannelSparsity.create_dense(self.sorting_analyzer) sparsity_mask = sparsity.mask diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 2207b98da6..e78c0a6b47 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -4,9 +4,13 @@ import shutil import numpy as np -from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer -from spikeinterface.core import estimate_sparsity +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, + load_sorting_analyzer, + estimate_sparsity, +) +from spikeinterface.core.sortinganalyzer import get_extension_class def get_dataset(): @@ -95,7 +99,19 @@ def get_sorting_analyzer(self, recording, sorting, format="memory", sparsity=Non return sorting_analyzer - def _prepare_sorting_analyzer(self, format, sparse, extension_class): + def _compute_extensions_recursively(self, sorting_analyzer, extension_class, params): + # compute dependencies of the extension class with default params + dependencies = extension_class.get_required_dependencies(**params) + for dependency_name in dependencies: + if "|" in dependency_name: + dependency_name = dependency_name.split("|")[0] + if not sorting_analyzer.has_extension(dependency_name): + # compute dependencies of the dependency + self._compute_extensions_recursively(sorting_analyzer, get_extension_class(dependency_name), {}) + # compute the dependency itself + sorting_analyzer.compute(dependency_name) + + def _prepare_sorting_analyzer(self, format, sparse, extension_class, extension_params=None): # prepare a SortingAnalyzer object with depencies already computed sparsity_ = self.sparsity if sparse else None sorting_analyzer = self.get_sorting_analyzer( @@ -103,10 +119,11 @@ def _prepare_sorting_analyzer(self, format, sparse, extension_class): ) sorting_analyzer.compute("random_spikes", max_spikes_per_unit=20, seed=2205) - for dependency_name in extension_class.depend_on: - if "|" in dependency_name: - dependency_name = dependency_name.split("|")[0] - sorting_analyzer.compute(dependency_name) + # default params for dependencies + params = sorting_analyzer.get_default_extension_params(extension_class.extension_name) + if extension_params is not None: + params.update(extension_params) + self._compute_extensions_recursively(sorting_analyzer, extension_class, params) return sorting_analyzer @@ -126,7 +143,7 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext = sorting_analyzer.compute(extension_class.extension_name, **params, **job_kwargs) assert len(ext.data) > 0 main_data = ext.get_data() - assert len(main_data) > 0 + assert main_data is not None ext = sorting_analyzer.get_extension(extension_class.extension_name) assert ext is not None @@ -146,7 +163,11 @@ def _check_one(self, sorting_analyzer, extension_class, params): ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) for ext_data_name, ext_data_loaded in ext_loaded.data.items(): if isinstance(ext_data_loaded, np.ndarray): - assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + if len(ext_data_loaded) > 0 and isinstance(ext_data_loaded[0], dict): + for i in range(len(ext_data_loaded)): + assert np.array_equal(np.array(ext.data[ext_data_name][i]), np.array(ext_data_loaded[i])) + else: + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) elif isinstance(ext_data_loaded, pd.DataFrame): # skip nan values for col in ext_data_loaded.columns: @@ -169,5 +190,7 @@ def run_extension_tests(self, extension_class, params): for sparse in (True, False): for format in ("memory", "binary_folder", "zarr"): print("sparse", sparse, format) - sorting_analyzer = self._prepare_sorting_analyzer(format, sparse, extension_class) + sorting_analyzer = self._prepare_sorting_analyzer( + format, sparse, extension_class, extension_params=params + ) self._check_one(sorting_analyzer, extension_class, params) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index fc9d3643bc..77bff7a3d8 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -99,9 +99,9 @@ def test_get_projections(self, sparse): some_unit_ids = sorting_analyzer.unit_ids[::2] some_channel_ids = sorting_analyzer.channel_ids[::2] - random_spikes_indices = sorting_analyzer.get_extension("random_spikes").get_data() - all_num_spikes = sorting_analyzer.sorting.count_num_spikes_per_unit() - unit_ids_num_spikes = np.sum(all_num_spikes[unit_id] for unit_id in some_unit_ids) + random_spikes_ext = sorting_analyzer.get_extension("random_spikes") + random_spikes_indices = random_spikes_ext.get_data() + unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) @@ -113,7 +113,7 @@ def test_get_projections(self, sparse): # this should be some spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=some_unit_ids) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == num_chans assert 1 not in spike_unit_index @@ -123,7 +123,7 @@ def test_get_projections(self, sparse): channel_ids=some_channel_ids, unit_ids=some_unit_ids ) assert some_projections.shape[0] == spike_unit_index.shape[0] - assert spike_unit_index.shape[0] == unit_ids_num_spikes + assert spike_unit_index.shape[0] == unit_ids_num_random_spikes assert some_projections.shape[1] == n_components assert some_projections.shape[2] == some_channel_ids.size assert 1 not in spike_unit_index diff --git a/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py new file mode 100644 index 0000000000..6d34264eac --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/test_valid_unit_periods.py @@ -0,0 +1,131 @@ +import pytest +import numpy as np + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite +from spikeinterface.postprocessing import ComputeValidUnitPeriods + + +class TestComputeValidUnitPeriods(AnalyzerExtensionCommonTestSuite): + + @pytest.mark.parametrize( + "params", + [ + dict(period_mode="absolute", period_duration_s_absolute=1.1, minimum_valid_period_duration=1.0), + dict(period_mode="relative", period_target_num_spikes=30, minimum_valid_period_duration=1.0), + ], + ) + def test_extension(self, params): + self.run_extension_tests(ComputeValidUnitPeriods, params) + + def test_user_defined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + np.testing.assert_array_equal(ext_periods, periods) + + def test_user_defined_periods_as_arrays(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods_array = np.zeros((len(unit_ids) * num_segments, 4), dtype="int64") + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods_array[idx, 0] = segment_index + periods_array[idx, 1] = period_start + periods_array[idx, 2] = period_start + period_duration + periods_array[idx, 3] = unit_index + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods + ) + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array, + minimum_valid_period_duration=1, + ) + # check that valid periods correspond to user defined periods + ext_periods = ext.get_data(outputs="numpy") + ext_periods = np.column_stack([ext_periods[field] for field in ext_periods.dtype.names]) + np.testing.assert_array_equal(ext_periods, periods_array) + + # test that dropping segment_index raises because multi-segment + with pytest.raises(ValueError): + ext = sorting_analyzer.compute( + ComputeValidUnitPeriods.extension_name, + method="user_defined", + user_defined_periods=periods_array[:, 1:4], # drop segment_index + minimum_valid_period_duration=1, + ) + + def test_combined_periods(self): + unit_ids = self.sorting.unit_ids + num_segments = self.sorting.get_num_segments() + + # unit periods of unit_period_dtypes + periods = np.zeros(len(unit_ids) * num_segments, dtype=unit_period_dtype) + + # for each unit we 1 valid period per segment + for i, unit_id in enumerate(unit_ids): + unit_index = self.sorting.id_to_index(unit_id) + for segment_index in range(num_segments): + num_samples = self.recording.get_num_samples(segment_index=segment_index) + idx = i * num_segments + segment_index + periods[idx]["unit_index"] = unit_index + period_start = num_samples // 4 + period_duration = num_samples // 2 + periods[idx]["start_sample_index"] = period_start + periods[idx]["end_sample_index"] = period_start + period_duration + periods[idx]["segment_index"] = segment_index + + unit_valid_periods_params = dict( + method="combined", + user_defined_periods=periods, + period_mode="absolute", + period_duration_s_absolute=1.0, + minimum_valid_period_duration=1, + ) + + sorting_analyzer = self._prepare_sorting_analyzer( + "memory", sparse=False, extension_class=ComputeValidUnitPeriods, extension_params=unit_valid_periods_params + ) + ext = sorting_analyzer.compute(ComputeValidUnitPeriods.extension_name, **unit_valid_periods_params) + # check that valid periods correspond to intersection of auto-computed and user defined periods + ext_periods = ext.get_data(outputs="numpy") + assert len(ext_periods) <= len(periods) # should be less or equal than user defined ones diff --git a/src/spikeinterface/postprocessing/valid_unit_periods.py b/src/spikeinterface/postprocessing/valid_unit_periods.py new file mode 100644 index 0000000000..aa0bc0f771 --- /dev/null +++ b/src/spikeinterface/postprocessing/valid_unit_periods.py @@ -0,0 +1,749 @@ +from __future__ import annotations + +import importlib.util +import warnings + +import numpy as np +from typing import Optional +from copy import deepcopy + +from concurrent.futures import ProcessPoolExecutor +import multiprocessing as mp +from threadpoolctl import threadpool_limits +from tqdm.auto import tqdm + +from spikeinterface.core.base import unit_period_dtype +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension +from spikeinterface.metrics.spiketrain import compute_firing_rates + +numba_spec = importlib.util.find_spec("numba") +if numba_spec is not None: + HAVE_NUMBA = True +else: + HAVE_NUMBA = False + + +class ComputeValidUnitPeriods(AnalyzerExtension): + """Compute good time periods per unit based on quality metrics. + + Paraneters + ---------- + method : {"false_positives_and_negatives", "user_defined", "combined"} + Strategy for identifying good periods for each unit. If "false_positives_and_negatives", uses + amplitude cutoff (false negative spike rate) and refractory period violations (false positive spike rate) + to estimate good periods (as periods with fn_rate 0 or period_target_num_spikes > 0 + ), "Either period_duration_s_absolute or period_target_num_spikes must be positive." + assert isinstance(period_target_num_spikes, (int)), "period_target_num_spikes must be an integer." + + # user_defined_periods formatting + self.user_defined_periods = None + if user_defined_periods is not None: + try: + user_defined_periods = np.asarray(user_defined_periods) + except Exception as e: + raise ValueError( + ( + "user_defined_periods must be some (n_periods, 3) [unit, good_period_start, good_period_end] " + "or (n_periods, 4) [unit, segment_index, good_period_start, good_period_end] structure convertible to a numpy array" + ) + ) + + if user_defined_periods.dtype != np.dtype(unit_period_dtype): + if user_defined_periods.ndim != 2 or user_defined_periods.shape[1] not in (3, 4): + raise ValueError( + "user_defined_periods must be of shape (n_periods, 3) [unit_index, good_period_start, good_period_end] or (n_periods, 4) [unit_index, segment_index, good_period_start, good_period_end]" + ) + + if not np.issubdtype(user_defined_periods.dtype, np.integer): + # Try converting to check if they're integer-valued floats + if not np.allclose(user_defined_periods, user_defined_periods.astype(int)): + raise ValueError("All values in user_defined_periods must be integers, in samples.") + user_defined_periods = user_defined_periods.astype(int) + + if user_defined_periods.shape[1] == 3: + if self.sorting_analyzer.get_num_segments() > 1: + raise ValueError( + "For multi-segment recordings, user_defined_periods must include segment_index as column 1." + ) + # add segment index 0 as column 0 if missing + user_defined_periods = np.hstack( + ( + np.zeros((user_defined_periods.shape[0], 1), dtype=int), + user_defined_periods, + ) + ) + # Cast user defined periods to unit_period_dtype + user_defined_periods = np.frombuffer(user_defined_periods, dtype=unit_period_dtype) + + # assert that user-defined periods are not too short + fs = self.sorting_analyzer.sampling_frequency + durations = user_defined_periods["end_sample_index"] - user_defined_periods["start_sample_index"] + min_duration_samples = int(minimum_valid_period_duration * fs) + if np.any(durations < min_duration_samples): + raise ValueError( + f"All user-defined periods must be at least {minimum_valid_period_duration} seconds long." + ) + self.user_defined_periods = user_defined_periods + + params = dict( + method=method, + period_duration_s_absolute=period_duration_s_absolute, + period_target_num_spikes=period_target_num_spikes, + period_mode=period_mode, + relative_margin_size=relative_margin_size, + min_num_periods_relative=min_num_periods_relative, + fp_threshold=fp_threshold, + fn_threshold=fn_threshold, + minimum_n_spikes=minimum_n_spikes, + minimum_valid_period_duration=minimum_valid_period_duration, + refractory_period_ms=refractory_period_ms, + censored_period_ms=censored_period_ms, + num_histogram_bins=num_histogram_bins, + histogram_smoothing_value=histogram_smoothing_value, + amplitudes_bins_min_ratio=amplitudes_bins_min_ratio, + ) + + return params + + def _select_extension_data(self, unit_ids): + new_extension_data = {} + good_periods = self.data["valid_unit_periods"] + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + mask = np.isin(good_periods["unit_index"], unit_indices) + new_extension_data["valid_unit_periods"] = good_periods[mask] + return new_extension_data + + def _merge_extension_data( + self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, censor_ms=None, verbose=False, **job_kwargs + ): + new_extension_data = {} + good_periods = self.data["valid_unit_periods"] + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for merged units + recompute = True + else: + # in case of user-defined periods, just merge periods + recompute = False + + if recompute: + new_periods_centers = deepcopy(self.data.get("period_centers")) + new_periods_fp_per_unit = deepcopy(self.data.get("periods_fp_per_unit")) + new_periods_fn_per_unit = deepcopy(self.data.get("periods_fn_per_unit")) + # remove data of merged units + merged_unit_indices = [] + for unit_ids in merge_unit_groups: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + merged_unit_indices.append(unit_indices) + for unit_id in unit_ids: + for segment_index in range(self.sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].pop(unit_id, None) + new_periods_fp_per_unit[segment_index].pop(unit_id, None) + new_periods_fn_per_unit[segment_index].pop(unit_id, None) + + # remove periods of merged units + good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], np.array(merged_unit_indices))] + # recompute for merged units + good_periods_merged, period_centers, fps, fns = self._compute_valid_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + new_good_periods = np.concatenate((good_periods_valid, good_periods_merged), axis=0) + + # update period centers, fps, fns + for segment_index in range(new_sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].update(period_centers[segment_index]) + new_periods_fp_per_unit[segment_index].update(fps[segment_index]) + new_periods_fn_per_unit[segment_index].update(fns[segment_index]) + + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) + new_extension_data["period_centers"] = period_centers + new_extension_data["periods_fp_per_unit"] = fps + new_extension_data["periods_fn_per_unit"] = fns + else: + # just merge periods + merged_periods = np.array([], dtype=unit_period_dtype) + merged_unit_indices = [] + for unit_ids in merge_unit_groups: + unit_indices = self.sorting_analyzer.sorting.ids_to_indices(unit_ids) + merged_unit_indices.append(unit_indices) + # get periods of all units to be merged + masked_periods = good_periods[np.isin(good_periods["unit_index"], unit_indices)] + if len(masked_periods) == 0: + continue + # merge periods + _merged_periods = merge_overlapping_periods_across_units_and_segments(masked_periods) + merged_periods = np.concatenate((merged_periods, _merged_periods)) + + # get periods of unmerged units + unmerged_mask = ~np.isin(good_periods["unit_index"], np.concatenate(merged_unit_indices)) + unmerged_periods = good_periods[unmerged_mask] + + new_good_periods = np.concatenate((unmerged_periods, merged_periods)) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) + + return new_extension_data + + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_extension_data = {} + good_periods = self.data["valid_unit_periods"] + if self.params["method"] in ("false_positives_and_negatives", "combined"): + # need to recompute for split units + recompute = True + else: + # in case of user-defined periods, we can only duplicate valid periods for the split + recompute = False + + if recompute: + new_periods_centers = deepcopy(self.data.get("period_centers")) + new_periods_fp_per_unit = deepcopy(self.data.get("periods_fp_per_unit")) + new_periods_fn_per_unit = deepcopy(self.data.get("periods_fn_per_unit")) + # remove data of split units + split_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(split_units) + for unit_id in split_units: + for segment_index in range(self.sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].pop(unit_id, None) + new_periods_fp_per_unit[segment_index].pop(unit_id, None) + new_periods_fn_per_unit[segment_index].pop(unit_id, None) + + # remove periods of split units + good_periods_valid = good_periods[~np.isin(good_periods["unit_index"], split_unit_indices)] + # recompute for split units + good_periods_split, period_centers, fps, fns = self._compute_valid_periods( + new_sorting_analyzer, + unit_ids=new_unit_ids, + ) + new_good_periods = np.concatenate((good_periods_valid, good_periods_split)) + # update period centers, fps, fns + for segment_index in range(new_sorting_analyzer.get_num_segments()): + new_periods_centers[segment_index].update(period_centers[segment_index]) + new_periods_fp_per_unit[segment_index].update(fps[segment_index]) + new_periods_fn_per_unit[segment_index].update(fns[segment_index]) + + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) + new_extension_data["period_centers"] = period_centers + new_extension_data["periods_fp_per_unit"] = fps + new_extension_data["periods_fn_per_unit"] = fns + else: + # just duplicate periods to the split units + split_periods = [] + split_unit_indices = self.sorting_analyzer.sorting.ids_to_indices(split_units) + for split_unit_id, new_unit_ids in zip(split_units, new_unit_ids): + unit_index = self.sorting_analyzer.sorting.id_to_index(split_unit_id) + new_unit_indices = new_sorting_analyzer.sorting.ids_to_indices(new_unit_ids) + split_unit_indices.append(unit_index) + # get periods of all units to be merged + masked_periods = good_periods[good_periods["unit_index"] == unit_index] + for new_unit_index in new_unit_indices: + _split_periods = masked_periods.copy() + _split_periods["unit_index"] = new_unit_index + split_periods = np.concatenate((split_periods, _split_periods), axis=0) + if len(masked_periods) == 0: + continue + # merge periods + _split_periods = merge_overlapping_periods_across_units_and_segments(masked_periods) + split_periods.append(_split_periods) + split_periods = np.concatenate(split_periods, axis=0) + # get periods of unmerged units + unsplit_mask = ~np.isin(good_periods["unit_index"], np.array(split_unit_indices)) + unsplit_periods = good_periods[unsplit_mask] + + new_good_periods = np.concatenate((unsplit_periods, split_periods), axis=0) + new_extension_data["valid_unit_periods"] = self._sort_periods(new_good_periods) + + return new_extension_data + + def _run(self, unit_ids=None, verbose=False, **job_kwargs): + valid_unit_periods, period_centers, fps, fns = self._compute_valid_periods( + self.sorting_analyzer, + unit_ids=unit_ids, + **job_kwargs, + ) + self.data["valid_unit_periods"] = valid_unit_periods + if period_centers is not None: + self.data["period_centers"] = period_centers + if fps is not None: + self.data["periods_fp_per_unit"] = fps + if fns is not None: + self.data["periods_fn_per_unit"] = fns + + def _compute_valid_periods(self, sorting_analyzer, unit_ids=None, **job_kwargs): + from spikeinterface import get_global_job_kwargs + + if self.params["method"] == "user_defined": + + # directly use user defined periods + return self.user_defined_periods, None, None, None + + elif self.params["method"] in ["false_positives_and_negatives", "combined"]: + + # dict: unit_id -> list of subperiod, each subperiod is an array of dtype unit_period_dtype with 4 fields + all_periods, all_periods_w_margins, period_centers = compute_subperiods( + sorting_analyzer, + self.params["period_duration_s_absolute"], + self.params["period_target_num_spikes"], + self.params["period_mode"], + self.params["relative_margin_size"], + self.params["min_num_periods_relative"], + unit_ids=unit_ids, + ) + + job_kwargs = fix_job_kwargs(job_kwargs) + n_jobs = job_kwargs["n_jobs"] + progress_bar = job_kwargs["progress_bar"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] + mp_context = job_kwargs["mp_context"] + + # Compute fp and fn for all periods + # Process units in parallel + amp_scalings = sorting_analyzer.get_extension("amplitude_scalings") + all_amplitudes_by_unit = amp_scalings.get_data(outputs="by_unit", concatenated=False) + + init_args = (sorting_analyzer.sorting, all_amplitudes_by_unit, self.params, max_threads_per_worker) + + # Each item is one computation of fp and fn for one period and one unit + items = [(period,) for period in all_periods_w_margins] + job_name = f"computing false positives and negatives" + + # parallel + with ProcessPoolExecutor( + max_workers=n_jobs, + initializer=fp_fn_worker_init, + mp_context=mp.get_context(mp_context), + initargs=init_args, + ) as executor: + results = executor.map(fp_fn_worker_func_wrapper, items) + + if progress_bar: + results = tqdm(results, desc=f"{job_name} (workers: {n_jobs} processes)", total=len(items)) + + all_fps = np.zeros(len(all_periods)) + all_fns = np.zeros(len(all_periods)) + for i, (fp, fn) in enumerate(results): + all_fps[i] = fp + all_fns[i] = fn + + # set NaNs to 1 (they will be exluded anyways) + all_fps[np.isnan(all_fps)] = 1.0 + all_fns[np.isnan(all_fns)] = 1.0 + + # split values by segment and units + # fps and fns are lists of segments with dicts unit_id -> array of shape (n_subperiods) + fps = [] + fns = [] + for segment_index in range(sorting_analyzer.sorting.get_num_segments()): + fp_in_segment = {} + fn_in_segment = {} + segment_mask = all_periods["segment_index"] == segment_index + periods_segment = all_periods[segment_mask] + fps_segment = all_fps[segment_mask] + fns_segment = all_fns[segment_mask] + for unit_index, unit_id in enumerate(sorting_analyzer.unit_ids): + unit_mask = periods_segment["unit_index"] == unit_index + fp_in_segment[unit_id] = fps_segment[unit_mask] + fn_in_segment[unit_id] = fns_segment[unit_mask] + fps.append(fp_in_segment) + fns.append(fn_in_segment) + + good_period_mask = (all_fps < self.params["fp_threshold"]) & (all_fns < self.params["fn_threshold"]) + good_periods = all_periods[good_period_mask] + + # Sort good periods on segment_index, unit_index, start_sample_index + valid_unit_periods = self._sort_periods(good_periods) + + # Combine with user-defined periods if provided + if self.params["method"] == "combined": + user_defined_periods = self.user_defined_periods + valid_unit_periods = self._sort_periods( + np.concatenate((valid_unit_periods, user_defined_periods), axis=0) + ) + valid_unit_periods = merge_overlapping_periods_across_units_and_segments(valid_unit_periods) + + # Remove good periods that are too short + minimum_valid_period_duration = self.params["minimum_valid_period_duration"] + min_valid_period_samples = int(minimum_valid_period_duration * sorting_analyzer.sampling_frequency) + duration_samples = valid_unit_periods["end_sample_index"] - valid_unit_periods["start_sample_index"] + valid_mask = duration_samples >= min_valid_period_samples + valid_unit_periods = valid_unit_periods[valid_mask] + + # Store data: here we have to make sure every dict is JSON serializable, so everything is lists + return valid_unit_periods, period_centers, fps, fns + + def _get_data(self, outputs: str = "by_unit"): + """ + Return extension data. If the extension computes more than one `nodepipeline_variables`, + the `return_data_name` is used to specify which one to return. + + Parameters + ---------- + outputs : "numpy" | "by_unit", default: "by_unit" + How to return the data. + + Returns + ------- + numpy.ndarray | list + The periods in numpy or dictionary by unit format, depending on `outputs`. + If "numpy", returns an array of dtype unit_period_dtype with columns: + unit_index, segment_index, start_sample_index, end_sample_index. + If "by_unit", returns a list (per segment) of dictionaries mapping unit IDs to lists of + (start_sample_index, end_sample_index) tuples. + """ + if outputs == "numpy": + good_periods = self.data["valid_unit_periods"].copy() + else: + # by_unit + unit_ids = self.sorting_analyzer.unit_ids + good_periods = [] + good_periods_array = self.data["valid_unit_periods"] + for segment_index in range(self.sorting_analyzer.get_num_segments()): + segment_mask = good_periods_array["segment_index"] == segment_index + periods_dict = {} + for unit_index in unit_ids: + periods_dict[unit_index] = [] + unit_mask = good_periods_array["unit_index"] == unit_index + good_periods_unit_segment = good_periods_array[segment_mask & unit_mask] + for start, end in good_periods_unit_segment[["start_sample_index", "end_sample_index"]]: + periods_dict[unit_index].append((start, end)) + good_periods.append(periods_dict) + + return good_periods + + def _sort_periods(self, periods): + sort_idx = np.lexsort((periods["start_sample_index"], periods["unit_index"], periods["segment_index"])) + sorted_periods = periods[sort_idx] + return sorted_periods + + def set_data(self, ext_data_name, ext_data): + # cast back lists of dicts (required for dumping) back to arrays + if ext_data_name in ("period_centers", "periods_fp_per_unit", "periods_fn_per_unit"): + ext_data_ = [] + # lists of dicts to lists of dicts with arrays + for segment_dict in ext_data: + segment_dict_arrays = {} + for unit_id, values in segment_dict.items(): + segment_dict_arrays[unit_id] = np.array(values) + ext_data_.append(segment_dict_arrays) + else: + ext_data_ = ext_data + self.data[ext_data_name] = ext_data_ + + +# TODO: deal with margin when returning periods +def compute_subperiods( + sorting_analyzer, + period_duration_s_absolute: float = 10, + period_target_num_spikes: int = 1000, + period_mode: str = "absolute", + relative_margin_size: float = 1.0, + min_num_periods_relative: int = 5, + unit_ids: Optional[list] = None, +) -> dict: + """ + Computes subperiods per unit based on specified size mode. + + Returns + ------- + all_subperiods : dict + Dictionary mapping unit IDs to lists of subperiods (arrays of dtype unit_period_dtype). + """ + sorting = sorting_analyzer.sorting + fs = sorting.sampling_frequency + if unit_ids is None: + unit_ids = sorting.unit_ids + + if period_mode == "absolute": + period_sizes_samples = {u: np.round(period_duration_s_absolute * fs).astype(int) for u in unit_ids} + else: # relative + mean_firing_rates = compute_firing_rates(sorting_analyzer, unit_ids) + period_sizes_samples = { + u: np.round((period_target_num_spikes / mean_firing_rates[u]) * fs).astype(int) for u in unit_ids + } + margin_sizes_samples = {u: np.round(relative_margin_size * period_sizes_samples[u]).astype(int) for u in unit_ids} + + all_subperiods = [] + all_subperiods_w_margins = [] + all_period_centers = [] + for segment_index in range(sorting.get_num_segments()): + n_samples = sorting_analyzer.get_num_samples(segment_index) # int: samples + period_centers = {} + for unit_index, unit_id in enumerate(unit_ids): + period_centers[unit_id] = [] + period_size_samples = period_sizes_samples[unit_id] + margin_size_samples = margin_sizes_samples[unit_id] + # We round the number of subperiods to ensure coverage of the entire recording + # the end of the last period is then clipped or extended to the end of the recording + n_subperiods = round(n_samples / period_size_samples) + if period_mode == "relative" and n_subperiods < min_num_periods_relative: + n_subperiods = min_num_periods_relative # at least min_num_periods_relative subperiods + period_size_samples = n_samples // n_subperiods + margin_size_samples = int(relative_margin_size * period_size_samples) + + # we generate periods starting from 0 up to n_samples, with and without margins, and period centers + starts = np.arange(0, n_samples, period_size_samples) + periods_for_unit = np.zeros(len(starts), dtype=unit_period_dtype) + periods_for_unit_w_margins = np.zeros(len(starts), dtype=unit_period_dtype) + for i, start in enumerate(starts): + end = start + period_size_samples + ext_start = max(0, start - margin_size_samples) + ext_end = min(n_samples, end + margin_size_samples) + center = start + period_size_samples // 2 + period_centers[unit_id].append(center) + periods_for_unit[i]["segment_index"] = segment_index + periods_for_unit[i]["start_sample_index"] = start + periods_for_unit[i]["end_sample_index"] = end + periods_for_unit[i]["unit_index"] = unit_index + periods_for_unit_w_margins[i]["segment_index"] = segment_index + periods_for_unit_w_margins[i]["start_sample_index"] = ext_start + periods_for_unit_w_margins[i]["end_sample_index"] = ext_end + periods_for_unit_w_margins[i]["unit_index"] = unit_index + + all_subperiods.append(periods_for_unit) + all_subperiods_w_margins.append(periods_for_unit_w_margins) + all_period_centers.append(period_centers) + return np.concatenate(all_subperiods), np.concatenate(all_subperiods_w_margins), all_period_centers + + +def merge_overlapping_periods(subperiods): + + segment_indices = np.unique(subperiods["segment_index"]) + assert len(segment_indices) == 1, "Subperiods must belong to the same segment to be merged." + segment_index = segment_indices[0] + unit_indices = np.unique(subperiods["unit_index"]) + assert len(unit_indices) == 1, "Subperiods must belong to the same unit to be merged." + unit_index = unit_indices[0] + + # Sort subperiods by start time for interval merging + sort_idx = np.argsort(subperiods["start_sample_index"]) + sorted_subperiods = subperiods[sort_idx] + + # Merge overlapping/adjacent intervals + merged_starts = [sorted_subperiods[0]["start_sample_index"]] + merged_ends = [sorted_subperiods[0]["end_sample_index"]] + + for i in range(1, len(sorted_subperiods)): + current_start = sorted_subperiods[i]["start_sample_index"] + current_end = sorted_subperiods[i]["end_sample_index"] + + # Merge if overlapping or contiguous (end >= start) + if current_start <= merged_ends[-1]: + merged_ends[-1] = max(merged_ends[-1], current_end) + else: + merged_starts.append(current_start) + merged_ends.append(current_end) + + # Construct output array + n_periods = len(merged_starts) + merged_periods = np.zeros(n_periods, dtype=unit_period_dtype) + merged_periods["segment_index"] = segment_index + merged_periods["start_sample_index"] = merged_starts + merged_periods["end_sample_index"] = merged_ends + merged_periods["unit_index"] = unit_index + + return merged_periods + + +def merge_overlapping_periods_across_units_and_segments(periods): + segments = np.unique(periods["segment_index"]) + units = np.unique(periods["unit_index"]) + merged_periods = [] + for segment_index in segments: + periods_per_segment = periods[periods["segment_index"] == segment_index] + for unit_index in units: + masked_periods = periods_per_segment[(periods_per_segment["unit_index"] == unit_index)] + if len(masked_periods) == 0: + continue + _merged_periods = merge_overlapping_periods(masked_periods) + merged_periods.append(_merged_periods) + if len(merged_periods) == 0: + merged_periods = np.array([], dtype=unit_period_dtype) + else: + merged_periods = np.concatenate(merged_periods, axis=0) + + return merged_periods + + +register_result_extension(ComputeValidUnitPeriods) +compute_valid_unit_periods = ComputeValidUnitPeriods.function_factory() + + +global worker_ctx + + +def fp_fn_worker_init(sorting, all_amplitudes_by_unit, params, max_threads_per_worker): + global worker_ctx + worker_ctx = {} + + # cache spike vector and spiketrains + sorting.precompute_spike_trains() + + worker_ctx["sorting"] = sorting + worker_ctx["all_amplitudes_by_unit"] = all_amplitudes_by_unit + worker_ctx["params"] = params + worker_ctx["max_threads_per_worker"] = max_threads_per_worker + + +def fp_fn_worker_func(period, sorting, all_amplitudes_by_unit, params): + """ + Low level computation of false positives and false negatives for one period and one unit. + """ + from spikeinterface.metrics.quality.misc_metrics import ( + amplitude_cutoff, + _compute_nb_violations_numba, + _compute_rp_contamination_one_unit, + ) + + # period is of dtype unit_period_dtype: 0: segment_index, 1: start_sample_index, 2: end_sample_index, 3: unit_index + period_sample = period[0] + segment_index = period_sample["segment_index"] + start_sample_index = period_sample["start_sample_index"] + end_sample_index = period_sample["end_sample_index"] + unit_index = period_sample["unit_index"] + unit_id = sorting.unit_ids[unit_index] + + amplitudes_unit = all_amplitudes_by_unit[segment_index][unit_id] + spiketrain = sorting.get_unit_spike_train(unit_id, segment_index=segment_index) + + mask = (spiketrain >= start_sample_index) & (spiketrain < end_sample_index) + total_samples_in_period = end_sample_index - start_sample_index + spiketrain_period = spiketrain[mask] + amplitudes_period = amplitudes_unit[mask] + + # compute fp (rp_violations). See _compute_refrac_period_violations in quality metrics + fs = sorting.sampling_frequency + t_c = int(round(params["censored_period_ms"] * fs * 1e-3)) + t_r = int(round(params["refractory_period_ms"] * fs * 1e-3)) + n_v = _compute_nb_violations_numba(spiketrain_period, t_r) + fp = _compute_rp_contamination_one_unit( + n_v, + len(spiketrain_period), + total_samples_in_period, + t_c, + t_r, + ) + + # compute fn (amplitude_cutoffs) + fn = amplitude_cutoff( + amplitudes_period, + params["num_histogram_bins"], + params["histogram_smoothing_value"], + params["amplitudes_bins_min_ratio"], + ) + return fp, fn + + +def fp_fn_worker_func_wrapper(period): + global worker_ctx + with threadpool_limits(limits=worker_ctx["max_threads_per_worker"]): + fp, fn = fp_fn_worker_func( + period, + worker_ctx["sorting"], + worker_ctx["all_amplitudes_by_unit"], + worker_ctx["params"], + ) + return fp, fn diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index 757401d77c..d59193ed8b 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -327,7 +327,6 @@ def _full_update_plot(self, change=None): backend_kwargs = dict(figure=self.figure, axes=None, ax=None) self.plot_matplotlib(data_plot, **backend_kwargs) - self._update_plot() def _update_plot(self, change=None): for ax in self.axes.flatten(): @@ -346,9 +345,6 @@ def _update_plot(self, change=None): self.figure.canvas.flush_events() -import numpy as np - - class RasterWidget(BaseRasterWidget): """ Plots spike train rasters. diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 42f8b93d74..154d6e4ed3 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -66,6 +66,7 @@ def setUpClass(cls): templates=dict(), noise_levels=dict(), spike_amplitudes=dict(), + amplitude_scalings=dict(max_dense_channels=None), # required by valid unit periods unit_locations=dict(), spike_locations=dict(), quality_metrics=dict( @@ -74,6 +75,12 @@ def setUpClass(cls): template_metrics=dict(), correlograms=dict(), template_similarity=dict(), + valid_unit_periods=dict( + period_mode="relative", + period_target_num_spikes=200, + relative_margin_size=0.5, + min_num_periods_relative=5, + ), ) job_kwargs = dict(n_jobs=-1) @@ -687,6 +694,14 @@ def test_plot_motion_info(self): if backend not in self.skip_backends: sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) + def test_plot_valid_unit_periods(self): + possible_backends = list(sw.ValidUnitPeriodsWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_valid_unit_periods( + self.sorting_analyzer_dense, backend=backend, show_only_units_with_valid_periods=False + ) + if __name__ == "__main__": # unittest.main() diff --git a/src/spikeinterface/widgets/unit_valid_periods.py b/src/spikeinterface/widgets/unit_valid_periods.py new file mode 100644 index 0000000000..9ea4bbb899 --- /dev/null +++ b/src/spikeinterface/widgets/unit_valid_periods.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +import numpy as np +from warnings import warn + +from spikeinterface.core import SortingAnalyzer +from .base import BaseWidget, to_attr + + +class ValidUnitPeriodsWidget(BaseWidget): + """ + Plots the valid periods for units based on valid periods extension. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer | None, default: None + The sorting analyzer + segment_index : None or int, default: None + The segment index. If None, uses first segment. + unit_ids : list | None, default: None + List of unit ids to plot. If None, all units are plotted. + show_only_units_with_valid_periods : bool, default: True + If True, only units with valid periods are shown. + """ + + def __init__( + self, + sorting_analyzer: SortingAnalyzer | None = None, + segment_index: int | None = None, + unit_ids: list | None = None, + show_only_units_with_valid_periods: bool = True, + clip_amplitude_scalings: float | None = 5.0, + backend: str | None = None, + **backend_kwargs, + ): + sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) + self.check_extensions(sorting_analyzer, "valid_unit_periods") + valid_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + if valid_periods_ext.params["method"] == "user_defined": + raise ValueError("UnitValidPeriodsWidget cannot be used with 'user_defined' valid periods.") + + if segment_index is None: + nseg = sorting_analyzer.get_num_segments() + if nseg != 1: + raise ValueError("You must provide segment_index=...") + else: + segment_index = 0 + + valid_periods = valid_periods_ext.get_data(outputs="numpy") + if show_only_units_with_valid_periods: + valid_unit_ids = sorting_analyzer.unit_ids[np.unique(valid_periods["unit_index"])] + else: + valid_unit_ids = sorting_analyzer.unit_ids + if unit_ids is not None: + valid_unit_ids = [u for u in unit_ids if u in valid_unit_ids] + + data_plot = dict( + sorting_analyzer=sorting_analyzer, + segment_index=segment_index, + unit_ids=valid_unit_ids, + clip_amplitude_scalings=clip_amplitude_scalings, + ) + + BaseWidget.__init__(self, data_plot, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + num_units = len(dp.unit_ids) + + if backend_kwargs["axes"] is not None: + axes = backend_kwargs["axes"] + if axes.ndim == 1: + axes = axes[:, None] + assert np.asarray(axes).shape == (3, num_units), "Axes shape does not match number of units" + else: + if "figsize" not in backend_kwargs: + backend_kwargs["figsize"] = (2 * num_units, 2 * 3) + backend_kwargs["num_axes"] = num_units * 3 + backend_kwargs["ncols"] = num_units + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + sorting_analyzer = dp.sorting_analyzer + sampling_frequency = sorting_analyzer.sampling_frequency + segment_index = dp.segment_index + good_periods_ext = sorting_analyzer.get_extension("valid_unit_periods") + fp_threshold = good_periods_ext.params["fp_threshold"] + fn_threshold = good_periods_ext.params["fn_threshold"] + fp_per_unit = good_periods_ext.data["periods_fp_per_unit"][segment_index] + fn_per_unit = good_periods_ext.data["periods_fn_per_unit"][segment_index] + period_centers = good_periods_ext.data["period_centers"][segment_index] + good_periods = good_periods_ext.get_data(outputs="numpy") + good_periods = good_periods[good_periods["segment_index"] == segment_index] + + amp_scalings_ext = sorting_analyzer.get_extension("amplitude_scalings") + amp_scalings_by_unit = amp_scalings_ext.get_data(outputs="by_unit")[segment_index] + + for ui, unit_id in enumerate(dp.unit_ids): + fp = fp_per_unit[unit_id] + fn = fn_per_unit[unit_id] + unit_index = list(sorting_analyzer.unit_ids).index(unit_id) + + axs = self.axes[:, ui] + # for simplicity we don't use timestamps here + spiketrain = ( + sorting_analyzer.sorting.get_unit_spike_train(unit_id, segment_index=segment_index) / sampling_frequency + ) + center_bins_s = np.array(period_centers[unit_id]) / sampling_frequency + + axs[0].plot(center_bins_s, fp, ls="", marker="o", color="r") + axs[0].axhline(fp_threshold, color="gray", ls="--") + axs[1].plot(center_bins_s, fn, ls="", marker="o") + axs[1].axhline(fn_threshold, color="gray", ls="--") + amp_scalings_data = amp_scalings_by_unit[unit_id] + if dp.clip_amplitude_scalings is not None: + amp_scalings_data = np.clip(amp_scalings_data, -dp.clip_amplitude_scalings, dp.clip_amplitude_scalings) + axs[2].plot(spiketrain, amp_scalings_data, ls="", marker="o", color="gray", alpha=0.5) + axs[2].axhline(1.0, color="k", ls="--") + # plot valid periods + valid_period_for_units = good_periods[good_periods["unit_index"] == unit_index] + for valid_period in valid_period_for_units: + start_time = valid_period["start_sample_index"] / sorting_analyzer.sampling_frequency + end_time = valid_period["end_sample_index"] / sorting_analyzer.sampling_frequency + axs[2].axvspan(start_time, end_time, alpha=0.3, color="g") + + axs[0].set_xlabel("") + axs[1].set_xlabel("") + axs[2].set_xlabel("Time (s)") + axs[0].set_ylabel("FP Rate") + axs[1].set_ylabel("FN Rate") + axs[2].set_ylabel("Amplitude Scaling") + axs[0].set_title(f"Unit {unit_id}") + + axs[1].sharex(axs[0]) + axs[2].sharex(axs[0]) + + for ax in self.axes.flatten(): + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + self.figure.subplots_adjust(hspace=0.4) + + def plot_ipywidgets(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + import ipywidgets.widgets as widgets + from IPython.display import display + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector + + check_ipywidget_backend() + + self.next_data_plot = data_plot.copy() + + cm = 1 / 2.54 + + width_cm = backend_kwargs["width_cm"] + height_cm = backend_kwargs["height_cm"] + + ratios = [0.15, 0.85] + + with plt.ioff(): + output = widgets.Output() + with output: + # Create figure without axes - let plot_matplotlib create them + self.figure = plt.figure(figsize=((ratios[1] * width_cm) * cm, height_cm * cm)) + plt.show() + + self.unit_selector = UnitSelector(data_plot["unit_ids"]) + self.unit_selector.value = list(data_plot["unit_ids"])[:1] + + self.widget = widgets.AppLayout( + center=self.figure.canvas, + left_sidebar=self.unit_selector, + pane_widths=ratios + [0], + ) + + # a first update + self._full_update_plot() + + self.unit_selector.observe(self._update_plot, names=["value"], type="change") + + if backend_kwargs["display"]: + display(self.widget) + + def _full_update_plot(self, change=None): + self.figure.clear() + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + backend_kwargs = dict(figure=self.figure, axes=None, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + + def _update_plot(self, change=None): + print(f"_update_plot called! change={change}", flush=True) + + for ax in self.axes.flatten(): + ax.clear() + + data_plot = self.next_data_plot + data_plot["unit_ids"] = self.unit_selector.value + + backend_kwargs = dict(figure=None, axes=self.axes, ax=None) + self.plot_matplotlib(data_plot, **backend_kwargs) + + self.figure.canvas.draw() + self.figure.canvas.flush_events() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 6edba67c96..e74ad38053 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -37,6 +37,7 @@ from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget +from .unit_valid_periods import ValidUnitPeriodsWidget widget_list = [ AgreementMatrixWidget, @@ -48,6 +49,7 @@ CrossCorrelogramsWidget, DriftingTemplatesWidget, DriftRasterMapWidget, + ValidUnitPeriodsWidget, ISIDistributionWidget, LocationsWidget, MotionWidget, @@ -128,6 +130,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_drifting_templates = DriftingTemplatesWidget plot_drift_raster_map = DriftRasterMapWidget +plot_valid_unit_periods = ValidUnitPeriodsWidget plot_isi_distribution = ISIDistributionWidget plot_locations = LocationsWidget plot_motion = MotionWidget