From 5be9d7fbfedda46e13baf332ef569d3cd072b10c Mon Sep 17 00:00:00 2001 From: Michael Denker Date: Mon, 14 Mar 2022 08:08:54 +0100 Subject: [PATCH] Support SpikeTrainList object and deprecate RecordingChannel in Elephant (#447) * Handle SpikeTrainList as input for BinnedSpikeTrain * Add corresponding regression test * PEP 8; avoid universal try-except clause * Fixed SpikeTrainList issue in spike train synchrony class * Fixed SpikeTrainList issue in SPADE * Pep8 issue * This removes support of RrcordingChannelGroup in CSD methods. * Changed requirements to Neo >=0.10.0 * Fix spiketrain lists in instantaneous rate. --- elephant/conversion.py | 17 ++++++---- elephant/current_source_density.py | 51 ++++++++++++++++-------------- elephant/spade.py | 3 +- elephant/spike_train_synchrony.py | 6 +++- elephant/statistics.py | 3 +- elephant/test/test_conversion.py | 23 ++++++++++++++ elephant/test/test_kcsd.py | 18 ++--------- elephant/test/test_spade.py | 30 ++++++++++++++++++ elephant/utils.py | 3 +- requirements/requirements.txt | 2 +- 10 files changed, 106 insertions(+), 50 deletions(-) diff --git a/elephant/conversion.py b/elephant/conversion.py index a0aae2a28..adb070865 100644 --- a/elephant/conversion.py +++ b/elephant/conversion.py @@ -1223,14 +1223,14 @@ def __init__(self, t_start, t_stop, bin_size, units, sparse_matrix, self.tolerance = tolerance -def _check_neo_spiketrain(matrix): +def _check_neo_spiketrain(query): """ Checks if given input contains neo.SpikeTrain objects Parameters ---------- - matrix - Object to test for `neo.SpikeTrain`s + query + Object to test for `neo.SpikeTrain` objects Returns ------- @@ -1240,11 +1240,14 @@ def _check_neo_spiketrain(matrix): """ # Check for single spike train - if isinstance(matrix, neo.SpikeTrain): + if isinstance(query, neo.SpikeTrain): return True - # Check for list or tuple - if isinstance(matrix, (list, tuple)): - return all(map(_check_neo_spiketrain, matrix)) + # Check for list, tuple, or SpikeTrainList + try: + return all(map(_check_neo_spiketrain, query)) + except TypeError: + pass + return False diff --git a/elephant/current_source_density.py b/elephant/current_source_density.py index 686c1fde9..6638d9e8e 100644 --- a/elephant/current_source_density.py +++ b/elephant/current_source_density.py @@ -62,7 +62,7 @@ @deprecated_alias(coords='coordinates') -def estimate_csd(lfp, coordinates=None, method=None, +def estimate_csd(lfp, coordinates='coordinates', method=None, process_estimate=True, **kwargs): """ Function call to compute the current source density (CSD) from @@ -72,12 +72,16 @@ def estimate_csd(lfp, coordinates=None, method=None, Parameters ---------- lfp : neo.AnalogSignal - positions of electrodes can be added as neo.RecordingChannel - coordinate or sent externally as a func argument (See coords) - coordinates : [Optional] corresponding spatial coordinates of the - electrodes. - Defaults to None - Otherwise looks for ChannelIndex coordinate + Positions of electrodes can be added as an array annotation + coordinates : array-like Quantity or string + Specifies the corresponding spatial coordinates of the electrodes. + Coordinates can be directly supplied by a NxM array-like Quantity + with dimension of space, where M is the number of signals in 'lfp', + and N is equal to the dimensionality of the method. + Alternatively, if coordinates is a string, the function will fetch the + coordinates, supplied in the same format, as annotation of 'lfp' by that + name. + Default: 'coordinates' method : string Pick a method corresponding to the setup, in this implementation For Laminar probe style (1D), use 'KCSD1D' or 'StandardCSD', @@ -114,17 +118,19 @@ def estimate_csd(lfp, coordinates=None, method=None, """ if not isinstance(lfp, neo.AnalogSignal): raise TypeError('Parameter `lfp` must be a neo.AnalogSignal object') - if coordinates is None: - coordinates = lfp.channel_index.coordinates - else: - scaled_coords = [] - for coord in coordinates: - try: - scaled_coords.append(coord.rescale(pq.mm)) - except AttributeError: - raise AttributeError('No units given for electrode spatial \ - coordinates') - coordinates = scaled_coords + if isinstance(coordinates, str): + coordinates = lfp.annotations[coordinates] + + # Scale all coordinates to mm as common basis + scaled_coords = [] + for coord in coordinates: + try: + scaled_coords.append(coord.rescale(pq.mm)) + except AttributeError: + raise AttributeError('No units given for electrode spatial \ + coordinates') + coordinates = scaled_coords + if method is None: raise ValueError('Must specify a method of CSD implementation') if len(coordinates) != lfp.shape[1]: @@ -249,7 +255,8 @@ def generate_lfp(csd_profile, x_positions, y_positions=None, z_positions=None, ------- LFP : neo.AnalogSignal The potentials created by the csd profile at the electrode positions. - The electrode positions are attached as RecordingChannel's coordinate. + The electrode positions are attached as an annotation named + 'coordinates'. """ def integrate_1D(x0, csd_x, csd, h): @@ -327,10 +334,8 @@ def integrate_3D(x, y, z, csd, xlin, ylin, zlin, X, Y, Z): pots /= 4 * np.pi * sigma ele_pos = np.vstack((x_positions, y_positions, z_positions)).T ele_pos = ele_pos * pq.mm - ch = neo.ChannelIndex(index=range(len(pots))) + asig = neo.AnalogSignal(np.expand_dims(pots, axis=0), sampling_rate=pq.kHz, units='mV') - ch.coordinates = ele_pos - ch.analogsignals.append(asig) - ch.create_relationship() + asig.annotate(coordinates=ele_pos) return asig diff --git a/elephant/spade.py b/elephant/spade.py index 34ecf51de..6f95e55cc 100644 --- a/elephant/spade.py +++ b/elephant/spade.py @@ -97,6 +97,7 @@ from itertools import chain, combinations import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq from scipy import sparse @@ -614,7 +615,7 @@ def concepts_mining(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2, "report has to assume of the following values:" + " 'a', '#' and '3d#,' got {} instead".format(report)) # if spiketrains is list of neo.SpikeTrain convert to conv.BinnedSpikeTrain - if isinstance(spiketrains, list) and \ + if isinstance(spiketrains, (list, SpikeTrainList)) and \ isinstance(spiketrains[0], neo.SpikeTrain): spiketrains = conv.BinnedSpikeTrain( spiketrains, bin_size=bin_size, tolerance=None) diff --git a/elephant/spike_train_synchrony.py b/elephant/spike_train_synchrony.py index c4f0ed505..8cde2d758 100644 --- a/elephant/spike_train_synchrony.py +++ b/elephant/spike_train_synchrony.py @@ -349,7 +349,11 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'): # replace link to spiketrain in segment new_index = self._get_spiketrain_index( segment.spiketrains, st) - segment.spiketrains[new_index] = new_st + # Todo: Simplify following lines once Neo SpikeTrainList + # implments indexed assignment of entries (i.e., stl[i]=st) + spiketrainlist = list(segment.spiketrains) + spiketrainlist[new_index] = new_st + segment.spiketrains = spiketrainlist except ValueError: # st is not in this segment even though it points to it warnings.warn(f"The SpikeTrain at index {idx} of the " diff --git a/elephant/statistics.py b/elephant/statistics.py index 170aad6c5..1fa076f5e 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -70,6 +70,7 @@ import warnings import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq import scipy.stats @@ -769,7 +770,7 @@ def optimal_kernel(st): if kernel == 'auto': kernel = optimal_kernel(spiketrains) spiketrains = [spiketrains] - elif not isinstance(spiketrains, (list, tuple)): + elif not isinstance(spiketrains, (list, tuple, SpikeTrainList)): raise TypeError( "'spiketrains' must be a list of neo.SpikeTrain's or a single " "neo.SpikeTrain. Found: '{}'".format(type(spiketrains))) diff --git a/elephant/test/test_conversion.py b/elephant/test/test_conversion.py index d0fde16b7..e63d6690d 100644 --- a/elephant/test/test_conversion.py +++ b/elephant/test/test_conversion.py @@ -9,6 +9,7 @@ import unittest import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq from numpy.testing import (assert_array_almost_equal, assert_array_equal) @@ -185,6 +186,28 @@ def test_bin_edges_empty_binned_spiketrain(self): assert_array_equal(bst.spike_indices, [[]]) # no binned spikes self.assertEqual(bst.get_num_of_spikes(), 0) + def test_regression_431(self): + """ + Addresses issue 431 + This unittest addresses an issue where a SpikeTrainList obejct was not + correctly handled by the constructor + """ + st1 = neo.SpikeTrain( + times=np.array([1, 2, 3]) * pq.ms, + t_start=0 * pq.ms, t_stop=10 * pq.ms) + st2 = neo.SpikeTrain( + times=np.array([4, 5, 6]) * pq.ms, + t_start=0 * pq.ms, t_stop=10 * pq.ms) + real_list = [st1, st2] + spiketrainlist = SpikeTrainList([st1, st2]) + + real_list_binary = cv.BinnedSpikeTrain(real_list, bin_size=1*pq.ms) + spiketrainlist_binary = cv.BinnedSpikeTrain( + spiketrainlist, bin_size=1 * pq.ms) + + assert_array_equal( + real_list_binary.to_array(), spiketrainlist_binary.to_array()) + class BinnedSpikeTrainTestCase(unittest.TestCase): def setUp(self): diff --git a/elephant/test/test_kcsd.py b/elephant/test/test_kcsd.py index 6a8527f30..ee98b96c1 100644 --- a/elephant/test/test_kcsd.py +++ b/elephant/test/test_kcsd.py @@ -32,11 +32,7 @@ def setUp(self): temp_signals.append(self.pots[ii]) self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz) - chidx = neo.ChannelIndex(range(len(self.pots))) - chidx.analogsignals.append(self.an_sigs) - chidx.coordinates = self.ele_pos * pq.mm - - chidx.create_relationship() + self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd1d_estimate(self, cv_params={}): self.test_params.update(cv_params) @@ -86,11 +82,7 @@ def setUp(self): temp_signals.append(self.pots[ii]) self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz) - chidx = neo.ChannelIndex(range(len(self.pots))) - chidx.analogsignals.append(self.an_sigs) - chidx.coordinates = self.ele_pos * pq.mm - - chidx.create_relationship() + self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd2d_estimate(self, cv_params={}): self.test_params.update(cv_params) @@ -149,11 +141,7 @@ def setUp(self): temp_signals.append(self.pots[ii]) self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV, sampling_rate=1000 * pq.Hz) - chidx = neo.ChannelIndex(range(len(self.pots))) - chidx.analogsignals.append(self.an_sigs) - chidx.coordinates = self.ele_pos * pq.mm - - chidx.create_relationship() + self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm) def test_kcsd3d_estimate(self, cv_params={}): self.test_params.update(cv_params) diff --git a/elephant/test/test_spade.py b/elephant/test/test_spade.py index 299bc62b7..a85c1ed56 100644 --- a/elephant/test/test_spade.py +++ b/elephant/test/test_spade.py @@ -10,6 +10,7 @@ import random import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq from numpy.testing.utils import assert_array_equal @@ -180,6 +181,35 @@ def test_spade_msip(self): # check the lags assert_array_equal(lags_msip, self.lags_msip) + # Testing with multiple patterns input + def test_spade_msip_spiketrainlist(self): + output_msip = spade.spade( + SpikeTrainList(self.msip), self.bin_size, self.winlen, + approx_stab_pars=dict( + n_subsets=self.n_subset, + stability_thresh=self.stability_thresh), + n_surr=self.n_surr, alpha=self.alpha, + psr_param=self.psr_param, + stat_corr='no', + output_format='patterns')['patterns'] + elements_msip = [] + occ_msip = [] + lags_msip = [] + # collecting spade output + for out in output_msip: + elements_msip.append(out['neurons']) + occ_msip.append(list(out['times'].magnitude)) + lags_msip.append(list(out['lags'].magnitude)) + elements_msip = sorted(elements_msip, key=len) + occ_msip = sorted(occ_msip, key=len) + lags_msip = sorted(lags_msip, key=len) + # check neurons in the patterns + assert_array_equal(elements_msip, self.elements_msip) + # check the occurrences time of the patters + assert_array_equal(occ_msip, self.occ_msip) + # check the lags + assert_array_equal(lags_msip, self.lags_msip) + def test_parameters(self): """ Test under different configuration of parameters than the default one diff --git a/elephant/utils.py b/elephant/utils.py index 445b74bc3..3ce1e9206 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -16,6 +16,7 @@ from functools import wraps import neo +from neo.core.spiketrainlist import SpikeTrainList import numpy as np import quantities as pq @@ -188,7 +189,7 @@ def check_neo_consistency(neo_objects, object_type, t_start=None, ValueError If input object units, t_start, or t_stop do not match across trials. """ - if not isinstance(neo_objects, (list, tuple)): + if not isinstance(neo_objects, (list, tuple, SpikeTrainList)): neo_objects = [neo_objects] try: units = neo_objects[0].units diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e89638ae3..25f8e040b 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,4 +1,4 @@ -neo>=0.9.0,<0.10.0 +neo>=0.10.0 numpy>=1.18.1 quantities>=0.12.1 scipy<1.7.0