Support SpikeTrainList object and deprecate RecordingChannel in Elephant (#447)

* Handle SpikeTrainList as input for BinnedSpikeTrain

* Add corresponding regression test

* PEP 8; avoid universal try-except clause

* Fixed SpikeTrainList issue in spike train synchrony class

* Fixed SpikeTrainList issue in SPADE

* Fix PEP 8 issue

* This removes support for RecordingChannelGroup in CSD methods.

* Changed requirements to Neo >=0.10.0

* Fix spiketrain lists in instantaneous rate.
mdenker authored Mar 14, 2022
1 parent 0074a14 commit 5be9d7f
Showing 10 changed files with 106 additions and 50 deletions.
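
Before the individual diffs, a minimal usage sketch of what the conversion change enables (spike times and bin size are arbitrary illustration values; the SpikeTrainList import path is the one used in the diff below):

import quantities as pq
import neo
from neo.core.spiketrainlist import SpikeTrainList

from elephant.conversion import BinnedSpikeTrain

# Two toy spike trains on a common time basis.
st1 = neo.SpikeTrain([1, 2, 3] * pq.ms, t_start=0 * pq.ms, t_stop=10 * pq.ms)
st2 = neo.SpikeTrain([4, 5, 6] * pq.ms, t_start=0 * pq.ms, t_stop=10 * pq.ms)

# In Neo 0.10, Segment.spiketrains is a SpikeTrainList rather than a plain
# list; with this commit it can be binned directly.
stl = SpikeTrainList([st1, st2])
bst = BinnedSpikeTrain(stl, bin_size=1 * pq.ms)
print(bst.to_array())  # identical to binning the plain list [st1, st2]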
17 changes: 10 additions & 7 deletions elephant/conversion.py
@@ -1223,14 +1223,14 @@ def __init__(self, t_start, t_stop, bin_size, units, sparse_matrix,
self.tolerance = tolerance


def _check_neo_spiketrain(matrix):
def _check_neo_spiketrain(query):
"""
Checks if given input contains neo.SpikeTrain objects
Parameters
----------
matrix
Object to test for `neo.SpikeTrain`s
query
Object to test for `neo.SpikeTrain` objects
Returns
-------
@@ -1240,11 +1240,14 @@ def _check_neo_spiketrain(matrix):
"""
# Check for single spike train
if isinstance(matrix, neo.SpikeTrain):
if isinstance(query, neo.SpikeTrain):
return True
# Check for list or tuple
if isinstance(matrix, (list, tuple)):
return all(map(_check_neo_spiketrain, matrix))
# Check for list, tuple, or SpikeTrainList
try:
return all(map(_check_neo_spiketrain, query))
except TypeError:
pass

return False


51 changes: 28 additions & 23 deletions elephant/current_source_density.py
@@ -62,7 +62,7 @@


@deprecated_alias(coords='coordinates')
def estimate_csd(lfp, coordinates=None, method=None,
def estimate_csd(lfp, coordinates='coordinates', method=None,
process_estimate=True, **kwargs):
"""
Function call to compute the current source density (CSD) from
@@ -72,12 +72,16 @@ def estimate_csd(lfp, coordinates=None, method=None,
Parameters
----------
lfp : neo.AnalogSignal
positions of electrodes can be added as neo.RecordingChannel
coordinate or sent externally as a func argument (See coords)
coordinates : [Optional] corresponding spatial coordinates of the
electrodes.
Defaults to None
Otherwise looks for ChannelIndex coordinate
Positions of electrodes can be added as an array annotation
coordinates : array-like Quantity or string
Specifies the corresponding spatial coordinates of the electrodes.
Coordinates can be supplied directly as an M x N array-like Quantity
with a dimension of space, where M is the number of signals in 'lfp'
and N is the dimensionality of the method.
Alternatively, if coordinates is a string, the function fetches the
coordinates, supplied in the same format, from the annotation of 'lfp'
with that name.
Default: 'coordinates'
method : string
Pick a method corresponding to the setup, in this implementation
For Laminar probe style (1D), use 'KCSD1D' or 'StandardCSD',
@@ -114,17 +118,19 @@ def estimate_csd(lfp, coordinates=None, method=None,
"""
if not isinstance(lfp, neo.AnalogSignal):
raise TypeError('Parameter `lfp` must be a neo.AnalogSignal object')
if coordinates is None:
coordinates = lfp.channel_index.coordinates
else:
scaled_coords = []
for coord in coordinates:
try:
scaled_coords.append(coord.rescale(pq.mm))
except AttributeError:
raise AttributeError('No units given for electrode spatial \
coordinates')
coordinates = scaled_coords
if isinstance(coordinates, str):
coordinates = lfp.annotations[coordinates]

# Scale all coordinates to mm as common basis
scaled_coords = []
for coord in coordinates:
try:
scaled_coords.append(coord.rescale(pq.mm))
except AttributeError:
raise AttributeError('No units given for electrode spatial \
coordinates')
coordinates = scaled_coords

if method is None:
raise ValueError('Must specify a method of CSD implementation')
if len(coordinates) != lfp.shape[1]:
@@ -249,7 +255,8 @@ def generate_lfp(csd_profile, x_positions, y_positions=None, z_positions=None,
-------
LFP : neo.AnalogSignal
The potentials created by the csd profile at the electrode positions.
The electrode positions are attached as RecordingChannel's coordinate.
The electrode positions are attached as an annotation named
'coordinates'.
"""

def integrate_1D(x0, csd_x, csd, h):
@@ -327,10 +334,8 @@ def integrate_3D(x, y, z, csd, xlin, ylin, zlin, X, Y, Z):
pots /= 4 * np.pi * sigma
ele_pos = np.vstack((x_positions, y_positions, z_positions)).T
ele_pos = ele_pos * pq.mm
ch = neo.ChannelIndex(index=range(len(pots)))

asig = neo.AnalogSignal(np.expand_dims(pots, axis=0),
sampling_rate=pq.kHz, units='mV')
ch.coordinates = ele_pos
ch.analogsignals.append(asig)
ch.create_relationship()
asig.annotate(coordinates=ele_pos)
return asig
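
For orientation, a minimal migration sketch of the new coordinate handling (the probe geometry and signal values are made up for illustration): electrode positions now travel as a 'coordinates' array annotation on the AnalogSignal, which is also the new default of the coordinates argument of estimate_csd, instead of being read from a ChannelIndex/RecordingChannelGroup:

import numpy as np
import quantities as pq
import neo

# A made-up 1D laminar probe: 16 contacts spaced 0.1 mm apart, 1 s of LFP.
n_channels = 16
lfp = neo.AnalogSignal(np.random.randn(1000, n_channels) * pq.mV,
                       sampling_rate=1000 * pq.Hz)

# Previously the positions were attached via a neo.ChannelIndex; now they are
# carried as an array annotation named 'coordinates'.
lfp.annotate(coordinates=np.arange(n_channels).reshape(-1, 1) * 0.1 * pq.mm)

# estimate_csd(lfp, method=...) reads this annotation by default, because the
# `coordinates` argument now defaults to the annotation name 'coordinates'.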
3 changes: 2 additions & 1 deletion elephant/spade.py
@@ -97,6 +97,7 @@
from itertools import chain, combinations

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from scipy import sparse
@@ -614,7 +615,7 @@ def concepts_mining(spiketrains, bin_size, winlen, min_spikes=2, min_occ=2,
"report has to assume of the following values:" +
" 'a', '#' and '3d#,' got {} instead".format(report))
# if spiketrains is list of neo.SpikeTrain convert to conv.BinnedSpikeTrain
if isinstance(spiketrains, list) and \
if isinstance(spiketrains, (list, SpikeTrainList)) and \
isinstance(spiketrains[0], neo.SpikeTrain):
spiketrains = conv.BinnedSpikeTrain(
spiketrains, bin_size=bin_size, tolerance=None)
6 changes: 5 additions & 1 deletion elephant/spike_train_synchrony.py
@@ -349,7 +349,11 @@ def delete_synchrofacts(self, threshold, in_place=False, mode='delete'):
# replace link to spiketrain in segment
new_index = self._get_spiketrain_index(
segment.spiketrains, st)
segment.spiketrains[new_index] = new_st
# Todo: Simplify following lines once Neo SpikeTrainList
# implements indexed assignment of entries (i.e., stl[i]=st)
spiketrainlist = list(segment.spiketrains)
spiketrainlist[new_index] = new_st
segment.spiketrains = spiketrainlist
except ValueError:
# st is not in this segment even though it points to it
warnings.warn(f"The SpikeTrain at index {idx} of the "
3 changes: 2 additions & 1 deletion elephant/statistics.py
@@ -70,6 +70,7 @@
import warnings

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
import scipy.stats
@@ -769,7 +770,7 @@ def optimal_kernel(st):
if kernel == 'auto':
kernel = optimal_kernel(spiketrains)
spiketrains = [spiketrains]
elif not isinstance(spiketrains, (list, tuple)):
elif not isinstance(spiketrains, (list, tuple, SpikeTrainList)):
raise TypeError(
"'spiketrains' must be a list of neo.SpikeTrain's or a single "
"neo.SpikeTrain. Found: '{}'".format(type(spiketrains)))
23 changes: 23 additions & 0 deletions elephant/test/test_conversion.py
@@ -9,6 +9,7 @@
import unittest

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from numpy.testing import (assert_array_almost_equal, assert_array_equal)
@@ -185,6 +186,28 @@ def test_bin_edges_empty_binned_spiketrain(self):
assert_array_equal(bst.spike_indices, [[]]) # no binned spikes
self.assertEqual(bst.get_num_of_spikes(), 0)

def test_regression_431(self):
"""
Addresses issue #431, where a SpikeTrainList object was not
correctly handled by the BinnedSpikeTrain constructor.
"""
st1 = neo.SpikeTrain(
times=np.array([1, 2, 3]) * pq.ms,
t_start=0 * pq.ms, t_stop=10 * pq.ms)
st2 = neo.SpikeTrain(
times=np.array([4, 5, 6]) * pq.ms,
t_start=0 * pq.ms, t_stop=10 * pq.ms)
real_list = [st1, st2]
spiketrainlist = SpikeTrainList([st1, st2])

real_list_binary = cv.BinnedSpikeTrain(real_list, bin_size=1*pq.ms)
spiketrainlist_binary = cv.BinnedSpikeTrain(
spiketrainlist, bin_size=1 * pq.ms)

assert_array_equal(
real_list_binary.to_array(), spiketrainlist_binary.to_array())


class BinnedSpikeTrainTestCase(unittest.TestCase):
def setUp(self):
18 changes: 3 additions & 15 deletions elephant/test/test_kcsd.py
@@ -32,11 +32,7 @@ def setUp(self):
temp_signals.append(self.pots[ii])
self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
sampling_rate=1000 * pq.Hz)
chidx = neo.ChannelIndex(range(len(self.pots)))
chidx.analogsignals.append(self.an_sigs)
chidx.coordinates = self.ele_pos * pq.mm

chidx.create_relationship()
self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm)

def test_kcsd1d_estimate(self, cv_params={}):
self.test_params.update(cv_params)
@@ -86,11 +82,7 @@ def setUp(self):
temp_signals.append(self.pots[ii])
self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
sampling_rate=1000 * pq.Hz)
chidx = neo.ChannelIndex(range(len(self.pots)))
chidx.analogsignals.append(self.an_sigs)
chidx.coordinates = self.ele_pos * pq.mm

chidx.create_relationship()
self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm)

def test_kcsd2d_estimate(self, cv_params={}):
self.test_params.update(cv_params)
@@ -149,11 +141,7 @@ def setUp(self):
temp_signals.append(self.pots[ii])
self.an_sigs = neo.AnalogSignal(np.array(temp_signals).T * pq.mV,
sampling_rate=1000 * pq.Hz)
chidx = neo.ChannelIndex(range(len(self.pots)))
chidx.analogsignals.append(self.an_sigs)
chidx.coordinates = self.ele_pos * pq.mm

chidx.create_relationship()
self.an_sigs.annotate(coordinates=self.ele_pos * pq.mm)

def test_kcsd3d_estimate(self, cv_params={}):
self.test_params.update(cv_params)
30 changes: 30 additions & 0 deletions elephant/test/test_spade.py
@@ -10,6 +10,7 @@
import random

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq
from numpy.testing.utils import assert_array_equal
@@ -180,6 +181,35 @@ def test_spade_msip(self):
# check the lags
assert_array_equal(lags_msip, self.lags_msip)

# Testing with multiple patterns input passed as a SpikeTrainList
def test_spade_msip_spiketrainlist(self):
output_msip = spade.spade(
SpikeTrainList(self.msip), self.bin_size, self.winlen,
approx_stab_pars=dict(
n_subsets=self.n_subset,
stability_thresh=self.stability_thresh),
n_surr=self.n_surr, alpha=self.alpha,
psr_param=self.psr_param,
stat_corr='no',
output_format='patterns')['patterns']
elements_msip = []
occ_msip = []
lags_msip = []
# collecting spade output
for out in output_msip:
elements_msip.append(out['neurons'])
occ_msip.append(list(out['times'].magnitude))
lags_msip.append(list(out['lags'].magnitude))
elements_msip = sorted(elements_msip, key=len)
occ_msip = sorted(occ_msip, key=len)
lags_msip = sorted(lags_msip, key=len)
# check neurons in the patterns
assert_array_equal(elements_msip, self.elements_msip)
# check the occurrence times of the patterns
assert_array_equal(occ_msip, self.occ_msip)
# check the lags
assert_array_equal(lags_msip, self.lags_msip)

def test_parameters(self):
"""
Test under different configuration of parameters than the default one
3 changes: 2 additions & 1 deletion elephant/utils.py
@@ -16,6 +16,7 @@
from functools import wraps

import neo
from neo.core.spiketrainlist import SpikeTrainList
import numpy as np
import quantities as pq

@@ -188,7 +189,7 @@ def check_neo_consistency(neo_objects, object_type, t_start=None,
ValueError
If input object units, t_start, or t_stop do not match across trials.
"""
if not isinstance(neo_objects, (list, tuple)):
if not isinstance(neo_objects, (list, tuple, SpikeTrainList)):
neo_objects = [neo_objects]
try:
units = neo_objects[0].units
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,4 +1,4 @@
neo>=0.9.0,<0.10.0
neo>=0.10.0
numpy>=1.18.1
quantities>=0.12.1
scipy<1.7.0