Skip to content

Commit

Permalink
Add a function to discretise spiketimes (#454)
Browse files Browse the repository at this point in the history
added `discretise_spiketimes(spiketrains, sampling_rate)` to conversion
  • Loading branch information
Kleinjohann authored Mar 8, 2022
1 parent b263ffe commit 0074a14
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
66 changes: 66 additions & 0 deletions elephant/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,3 +1246,69 @@ def _check_neo_spiketrain(matrix):
if isinstance(matrix, (list, tuple)):
return all(map(_check_neo_spiketrain, matrix))
return False


def discretise_spiketimes(spiketrains, sampling_rate):
"""
Rounds down all spike times in the input spike train(s)
to multiples of the sampling_rate
Parameters
----------
spiketrains : neo.SpikeTrain or list of neo.SpikeTrain
The spiketrain(s) to discretise
sampling_rate : pq.Quantity
The desired sampling rate
Returns
-------
neo.SpikeTrain or list of neo.SpikeTrain
The discretised spiketrain(s)
"""
# spiketrains type check
was_single_spiketrain = False
if isinstance(spiketrains, neo.SpikeTrain):
spiketrains = [spiketrains]
was_single_spiketrain = True
elif isinstance(spiketrains, list):
for st in spiketrains:
if not isinstance(st, (np.ndarray, neo.SpikeTrain)):
raise TypeError(
"spiketrains must be a SpikeTrain, a numpy ndarray, or a "
"list of one of those, not %s." % type(spiketrains))
else:
raise TypeError(
"spiketrains must be a SpikeTrain or a list of SpikeTrain objects,"
" not %s." % type(spiketrains))

if not isinstance(sampling_rate, pq.Quantity):
raise TypeError(
"The 'sampling_rate' must be pq.Quantity.\n"
"Found: %s." % type(sampling_rate))

units = spiketrains[0].times.units
mag_sampling_rate = sampling_rate.rescale(1/units).magnitude.flatten()

new_spiketrains = []
for spiketrain in spiketrains:
mag_t_start = spiketrain.t_start.rescale(units).magnitude.flatten()
mag_times = spiketrain.times.magnitude.flatten()
discrete_times = (mag_times // (1 / mag_sampling_rate)
/ mag_sampling_rate)
mask = discrete_times < mag_t_start

if np.any(mask):
warnings.warn(f'{mask.sum()} spike(s) would be before t_start '
'and are set to t_start instead.')
discrete_times[mask] = mag_t_start

discrete_times *= units
new_spiketrain = spiketrain.duplicate_with_new_data(discrete_times)
new_spiketrain.annotations = spiketrain.annotations
new_spiketrain.sampling_rate = sampling_rate
new_spiketrains.append(new_spiketrain)

if was_single_spiketrain:
new_spiketrains = new_spiketrains[0]

return new_spiketrains
38 changes: 38 additions & 0 deletions elephant/test/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,5 +702,43 @@ def test_binned_spiketrain_rounding(self):
np.arange(120000))


class DiscretiseSpiketrainsTestCase(unittest.TestCase):
def setUp(self):
times = (np.arange(10) + np.random.uniform(size=10)) * pq.ms
self.spiketrains = [neo.SpikeTrain(times, t_stop=10*pq.ms)] * 5

def test_list_of_spiketrains(self):
discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
1 / pq.ms)
for idx in range(len(self.spiketrains)):
np.testing.assert_array_equal(discretised_spiketrains[idx].times,
np.arange(10) * pq.ms)

def test_single_spiketrain(self):
discretised_spiketrain = cv.discretise_spiketimes(self.spiketrains[0],
1 / pq.ms)
np.testing.assert_array_equal(discretised_spiketrain.times,
np.arange(10) * pq.ms)

def test_preserve_t_start(self):
spiketrain = neo.SpikeTrain([0.7, 5.1]*pq.ms,
t_start=0.5*pq.ms, t_stop=10*pq.ms)
with self.assertWarns(UserWarning):
discretised_spiketrain = cv.discretise_spiketimes(spiketrain,
1 / pq.ms)
np.testing.assert_array_equal(discretised_spiketrain.times,
[0.5, 5] * pq.ms)

def test_binning_consistency(self):
discretised_spiketrains = cv.discretise_spiketimes(self.spiketrains,
1 / pq.ms)
bsts = cv.BinnedSpikeTrain(self.spiketrains,
bin_size=1 * pq.ms)
bsts_discretised = cv.BinnedSpikeTrain(discretised_spiketrains,
bin_size=1 * pq.ms)
np.testing.assert_array_equal(bsts.to_array(),
bsts_discretised.to_array())


if __name__ == '__main__':
unittest.main()

0 comments on commit 0074a14

Please sign in to comment.