diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index ae49877ce..ec2547b70 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -105,8 +105,7 @@ jobs:
       - name: Test with pytest
         run: |
-          coverage run --source=elephant -m pytest
-          coveralls --service=github
+          coverage run --source=elephant -m pytest && (coveralls --service=github || echo "Coveralls submission failed")
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -294,8 +293,7 @@ jobs:
      - name: Test with pytest
        run: |
-          mpiexec -n 1 python -m mpi4py -m coverage run --source=elephant -m pytest
-          coveralls --service=github
+          mpiexec -n 1 python -m mpi4py -m coverage run --source=elephant -m pytest && (coveralls --service=github || echo "Coveralls submission failed")
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/elephant/gpfa/gpfa.py b/elephant/gpfa/gpfa.py
index 4d671867e..025d03d35 100644
--- a/elephant/gpfa/gpfa.py
+++ b/elephant/gpfa/gpfa.py
@@ -69,19 +69,18 @@
 :license: Modified BSD, see LICENSE.txt for details.
 """
 
-from __future__ import division, print_function, unicode_literals
-
+from typing import List, Union
 import neo
 import numpy as np
 import quantities as pq
 import sklearn
 
 from elephant.gpfa import gpfa_core, gpfa_util
+from elephant.trials import Trials
+from elephant.utils import trials_to_list_of_spiketrainlist
 
-__all__ = [
-    "GPFA"
-]
+__all__ = ["GPFA"]
 
 
 class GPFA(sklearn.base.BaseEstimator):
@@ -228,9 +227,18 @@ class GPFA(sklearn.base.BaseEstimator):
     ...              'latent_variable'])
     """
 
-    def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
-                 tau_init=100.0 * pq.ms, eps_init=1.0E-3, em_tol=1.0E-8,
-                 em_max_iters=500, freq_ll=5, verbose=False):
+    def __init__(
+        self,
+        bin_size=20 * pq.ms,
+        x_dim=3,
+        min_var_frac=0.01,
+        tau_init=100.0 * pq.ms,
+        eps_init=1.0e-3,
+        em_tol=1.0e-8,
+        em_max_iters=500,
+        freq_ll=5,
+        verbose=False,
+    ):
         # Initialize object
         self.bin_size = bin_size
         self.x_dim = x_dim
@@ -241,11 +249,12 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
         self.em_max_iters = em_max_iters
         self.freq_ll = freq_ll
         self.valid_data_names = (
-            'latent_variable_orth',
-            'latent_variable',
-            'Vsm',
-            'VsmGP',
-            'y')
+            "latent_variable_orth",
+            "latent_variable",
+            "Vsm",
+            "VsmGP",
+            "y",
+        )
         self.verbose = verbose
 
         if not isinstance(self.bin_size, pq.Quantity):
@@ -258,17 +267,53 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
         self.fit_info = dict()
         self.transform_info = dict()
 
-    def fit(self, spiketrains):
+    @staticmethod
+    def _check_training_data(
+        spiketrains: List[List[neo.core.SpikeTrain]],
+    ) -> None:
+        if len(spiketrains) == 0:
+            raise ValueError("Input spiketrains cannot be empty")
+        if not all(
+            isinstance(item, neo.SpikeTrain)
+            for sublist in spiketrains
+            for item in sublist
+        ):
+            raise ValueError(
+                "structure of the spiketrains is not "
+                "correct: 0-axis should be trials, 1-axis "
+                "neo.SpikeTrain and 2-axis spike times."
+            )
+
+    def _format_training_data(
+        self, spiketrains: List[List[neo.core.SpikeTrain]]
+    ) -> np.recarray:
+        seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
+        # Remove inactive units based on training set
+        self.has_spikes_bool = np.hstack(seqs["y"]).any(axis=1)
+        for seq in seqs:
+            seq["y"] = seq["y"][self.has_spikes_bool, :]
+        return seqs
+
+    @trials_to_list_of_spiketrainlist
+    def fit(
+        self,
+        spiketrains: Union[
+            List[List[neo.core.SpikeTrain]],
+            "Trials",
+            List[neo.core.spiketrainlist.SpikeTrainList],
+        ],
+    ) -> "GPFA":
         """
         Fit the model with the given training data.
         Parameters
-        ----------
-        spiketrains : list of list of neo.SpikeTrain
+        ---------- # noqa
+        spiketrains : :class:`elephant.trials.Trials`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or list of list of :class:`neo.core.SpikeTrain`
             Spike train data to be fit to latent variables.
-            The outer list corresponds to trials and the inner list corresponds
-            to the neurons recorded in that trial, such that
-            `spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
+            For list of lists, the outer list corresponds to trials and the
+            inner list corresponds to the neurons recorded in that trial, such
+            that `spiketrains[l][n]` is the spike train of neuron `n` in trial
+            `l`.
             Note that the number and order of `neo.SpikeTrain` objects per
             trial must be fixed such that `spiketrains[l][n]` and
             `spiketrains[k][n]` refer to spike trains of the same neuron
@@ -288,69 +333,74 @@ def fit(self, spiketrains):
             If covariance matrix of input spike data is rank deficient.
         """
-        self._check_training_data(spiketrains)
-        seqs_train = self._format_training_data(spiketrains)
-        # Check if training data covariance is full rank
-        y_all = np.hstack(seqs_train['y'])
-        y_dim = y_all.shape[0]
-
-        if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
-            errmesg = 'Observation covariance matrix is rank deficient.\n' \
-                      'Possible causes: ' \
-                      'repeated units, not enough observations.'
-            raise ValueError(errmesg)
-
-        if self.verbose:
-            print('Number of training trials: {}'.format(len(seqs_train)))
-            print('Latent space dimensionality: {}'.format(self.x_dim))
-            print('Observation dimensionality: {}'.format(
-                self.has_spikes_bool.sum()))
-
-        # The following does the heavy lifting.
-        self.params_estimated, self.fit_info = gpfa_core.fit(
-            seqs_train=seqs_train,
-            x_dim=self.x_dim,
-            bin_width=self.bin_size.rescale('ms').magnitude,
-            min_var_frac=self.min_var_frac,
-            em_max_iters=self.em_max_iters,
-            em_tol=self.em_tol,
-            tau_init=self.tau_init.rescale('ms').magnitude,
-            eps_init=self.eps_init,
-            freq_ll=self.freq_ll,
-            verbose=self.verbose)
-
-        return self
-
-    @staticmethod
-    def _check_training_data(spiketrains):
-        if len(spiketrains) == 0:
-            raise ValueError("Input spiketrains cannot be empty")
-        if not isinstance(spiketrains[0][0], neo.SpikeTrain):
-            raise ValueError("structure of the spiketrains is not correct: "
-                             "0-axis should be trials, 1-axis neo.SpikeTrain"
-                             "and 2-axis spike times")
-
-    def _format_training_data(self, spiketrains):
-        seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
-        # Remove inactive units based on training set
-        self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1)
-        for seq in seqs:
-            seq['y'] = seq['y'][self.has_spikes_bool, :]
-        return seqs
-
-    def transform(self, spiketrains, returned_data=['latent_variable_orth']):
+        if all(
+            isinstance(item, neo.SpikeTrain)
+            for sublist in spiketrains
+            for item in sublist
+        ):
+            self._check_training_data(spiketrains)
+            seqs_train = self._format_training_data(spiketrains)
+            # Check if training data covariance is full rank
+            y_all = np.hstack(seqs_train["y"])
+            y_dim = y_all.shape[0]
+
+            if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
+                errmesg = (
+                    "Observation covariance matrix is rank deficient.\n"
+                    "Possible causes: "
+                    "repeated units, not enough observations."
+                )
+                raise ValueError(errmesg)
+
+            if self.verbose:
+                print("Number of training trials: {}".format(len(seqs_train)))
+                print("Latent space dimensionality: {}".format(self.x_dim))
+                print(
+                    "Observation dimensionality: {}".format(
+                        self.has_spikes_bool.sum()
+                    )
+                )
+
+            # The following does the heavy lifting.
+            self.params_estimated, self.fit_info = gpfa_core.fit(
+                seqs_train=seqs_train,
+                x_dim=self.x_dim,
+                bin_width=self.bin_size.rescale("ms").magnitude,
+                min_var_frac=self.min_var_frac,
+                em_max_iters=self.em_max_iters,
+                em_tol=self.em_tol,
+                tau_init=self.tau_init.rescale("ms").magnitude,
+                eps_init=self.eps_init,
+                freq_ll=self.freq_ll,
+                verbose=self.verbose,
+            )
+            return self
+        else:  # TODO: implement case for continuous data
+            raise ValueError("continuous data is not yet supported")
+
+    @trials_to_list_of_spiketrainlist
+    def transform(
+        self,
+        spiketrains: Union[
+            List[List[neo.core.SpikeTrain]],
+            "Trials",
+            List[neo.core.spiketrainlist.SpikeTrainList],
+        ],
+        returned_data: List[str] = ["latent_variable_orth"],
+    ) -> Union[np.ndarray, dict]:
         """
         Obtain trajectories of neural activity in a low-dimensional latent
         variable space by inferring the posterior mean of the obtained GPFA
         model and applying an orthonormalization on the latent variable space.
 
         Parameters
-        ----------
-        spiketrains : list of list of neo.SpikeTrain
+        ---------- # noqa
+        spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
             Spike train data to be transformed to latent variables.
-            The outer list corresponds to trials and the inner list corresponds
-            to the neurons recorded in that trial, such that
-            `spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
+            For list of lists, the outer list corresponds to trials and the
+            inner list corresponds to the neurons recorded in that trial, such
+            that `spiketrains[l][n]` is the spike train of neuron `n` in trial
+            `l`.
             Note that the number and order of `neo.SpikeTrain` objects per
             trial must be fixed such that `spiketrains[l][n]` and
             `spiketrains[k][n]` refer to spike trains of the same neuron
@@ -378,7 +428,7 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
 
         Returns
         -------
-        np.ndarray or dict
+        :class:`np.ndarray` or dict
             When the length of `returned_data` is one, a single np.ndarray,
             containing the requested data (the first entry in `returned_data`
             keys list), is returned. Otherwise, a dict of multiple np.ndarrays
@@ -411,36 +461,55 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
             If `returned_data` contains keys different from the ones in
             `self.valid_data_names`.
""" - if len(spiketrains[0]) != len(self.has_spikes_bool): - raise ValueError("'spiketrains' must contain the same number of " - "neurons as the training spiketrain data") - invalid_keys = set(returned_data).difference(self.valid_data_names) - if len(invalid_keys) > 0: - raise ValueError("'returned_data' can only have the following " - "entries: {}".format(self.valid_data_names)) - seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) - for seq in seqs: - seq['y'] = seq['y'][self.has_spikes_bool, :] - seqs, ll = gpfa_core.exact_inference_with_ll(seqs, - self.params_estimated, - get_ll=True) - self.transform_info['log_likelihood'] = ll - self.transform_info['num_bins'] = seqs['T'] - Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs) - self.transform_info['Corth'] = Corth - if len(returned_data) == 1: - return seqs[returned_data[0]] - return {x: seqs[x] for x in returned_data} - - def fit_transform(self, spiketrains, returned_data=[ - 'latent_variable_orth']): + if all( + isinstance(item, neo.SpikeTrain) + for sublist in spiketrains + for item in sublist + ): + if len(spiketrains[0]) != len(self.has_spikes_bool): + raise ValueError( + "'spiketrains' must contain the same number of " + "neurons as the training spiketrain data" + ) + invalid_keys = set(returned_data).difference(self.valid_data_names) + if len(invalid_keys) > 0: + raise ValueError( + "'returned_data' can only have the following " + f"entries: {self.valid_data_names}" + ) + seqs = gpfa_util.get_seqs(spiketrains, self.bin_size) + for seq in seqs: + seq["y"] = seq["y"][self.has_spikes_bool, :] + seqs, ll = gpfa_core.exact_inference_with_ll( + seqs, self.params_estimated, get_ll=True + ) + self.transform_info["log_likelihood"] = ll + self.transform_info["num_bins"] = seqs["T"] + Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs) + self.transform_info["Corth"] = Corth + if len(returned_data) == 1: + return seqs[returned_data[0]] + return {x: seqs[x] for x in returned_data} + else: # TODO: implement case for continuous data + raise ValueError + + @trials_to_list_of_spiketrainlist + def fit_transform( + self, + spiketrains: Union[ + List[List[neo.core.SpikeTrain]], + "Trials", + List[neo.core.spiketrainlist.SpikeTrainList], + ], + returned_data: str = ["latent_variable_orth"], + ) -> "GPFA": """ Fit the model with `spiketrains` data and apply the dimensionality reduction on `spiketrains`. Parameters - ---------- - spiketrains : list of list of neo.SpikeTrain + ---------- # noqa + spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials` Refer to the :func:`GPFA.fit` docstring. returned_data : list of str @@ -465,13 +534,21 @@ def fit_transform(self, spiketrains, returned_data=[ self.fit(spiketrains) return self.transform(spiketrains, returned_data=returned_data) - def score(self, spiketrains): + @trials_to_list_of_spiketrainlist + def score( + self, + spiketrains: Union[ + List[List[neo.core.SpikeTrain]], + "Trials", + List[neo.core.spiketrainlist.SpikeTrainList], + ], + ) -> "GPFA": """ Returns the log-likelihood of the given data under the fitted model Parameters - ---------- - spiketrains : list of list of neo.SpikeTrain + ---------- # noqa + spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials` Spike train data to be scored. 
             The outer list corresponds to trials and the inner list corresponds
             to the neurons recorded in that trial, such that
@@ -487,4 +564,4 @@ def score(self, spiketrains):
             Log-likelihood of the given spiketrains under the fitted model.
         """
         self.transform(spiketrains)
-        return self.transform_info['log_likelihood']
+        return self.transform_info["log_likelihood"]
diff --git a/elephant/test/test_gpfa.py b/elephant/test/test_gpfa.py
index 3a7c7a096..32d468ef0 100644
--- a/elephant/test/test_gpfa.py
+++ b/elephant/test/test_gpfa.py
@@ -14,21 +14,23 @@
 from numpy.testing import assert_array_equal, assert_array_almost_equal
 
 from elephant.spike_train_generation import StationaryPoissonProcess
-
+from elephant.trials import TrialsFromLists
 try:
     import sklearn
+
+    HAVE_SKLEARN = True
+except ModuleNotFoundError:
+    HAVE_SKLEARN = False
+
+if HAVE_SKLEARN:
     from elephant.gpfa import gpfa_util
     from elephant.gpfa import GPFA
     from sklearn.model_selection import cross_val_score
 
-    HAVE_SKLEARN = True
-except ImportError:
-    HAVE_SKLEARN = False
-
 
 @unittest.skipUnless(HAVE_SKLEARN, 'requires sklearn')
 class GPFATestCase(unittest.TestCase):
-    def setUp(self):
+    @classmethod
+    def setUpClass(cls):
         def gen_gamma_spike_train(k, theta, t_max):
             x = []
             for i in range(int(3 * t_max / (k * theta))):
@@ -43,8 +45,8 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)):
                 s = np.concatenate([s, s_i + np.sum(durs[:i])])
             return s
 
-        self.n_iters = 10
-        self.bin_size = 20 * pq.ms
+        cls.n_iters = 10
+        cls.bin_size = 20 * pq.ms
 
         # generate data1
         rates_a = (2, 10, 2, 2)
@@ -52,7 +54,7 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)):
         durs = (2.5, 2.5, 2.5, 2.5)
         np.random.seed(0)
         n_trials = 100
-        self.data0 = []
+        cls.data0 = []
         for trial in range(n_trials):
             n1 = neo.SpikeTrain(gen_test_data(rates_a, durs), units=1 * pq.s,
                                 t_start=0 * pq.s, t_stop=10 * pq.s)
@@ -70,14 +72,14 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)):
                                 t_start=0 * pq.s, t_stop=10 * pq.s)
             n8 = neo.SpikeTrain(gen_test_data(rates_b, durs), units=1 * pq.s,
                                 t_start=0 * pq.s, t_stop=10 * pq.s)
-            self.data0.append([n1, n2, n3, n4, n5, n6, n7, n8])
-        self.x_dim = 4
+            cls.data0.append([n1, n2, n3, n4, n5, n6, n7, n8])
+        cls.x_dim = 4
 
-        self.data1 = self.data0[:20]
+        cls.data1 = cls.data0[:20]
 
         # generate data2
         np.random.seed(27)
-        self.data2 = []
+        cls.data2 = []
         n_trials = 10
         n_channels = 20
         for trial in range(n_trials):
@@ -85,7 +87,7 @@ def gen_test_data(rates, durs, shapes=(1, 1, 1, 1)):
             spike_times = [StationaryPoissonProcess(rate=rate * pq.Hz,
                            t_stop=1000.0 * pq.ms).generate_spiketrain()
                            for rate in rates]
-            self.data2.append(spike_times)
+            cls.data2.append(spike_times)
 
     def test_data1(self):
         gpfa = GPFA(x_dim=self.x_dim, em_max_iters=self.n_iters)
@@ -217,6 +219,42 @@ def test_logdet(self):
         logdet_ground_truth = np.log(np.linalg.det(matrix))
         assert_array_almost_equal(logdet_fast, logdet_ground_truth)
 
+    def test_trial_object_gpfa_fit(self):
+        gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim,
+                                 em_max_iters=self.n_iters)
+        gpfa_list_of_lists = GPFA(bin_size=self.bin_size, x_dim=self.x_dim,
+                                  em_max_iters=self.n_iters)
+
+        trials = TrialsFromLists(self.data1)
+        gpfa_trial_object.fit(trials)
+        gpfa_list_of_lists.fit(self.data1)
+
+        assert_array_almost_equal(gpfa_trial_object.params_estimated['gamma'],
+                                  gpfa_list_of_lists.params_estimated['gamma'])
+
+    def test_trial_object_gpfa_transform(self):
+        gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim,
+                                 em_max_iters=self.n_iters)
+        gpfa_list_of_lists = GPFA(bin_size=self.bin_size,
+                                  x_dim=self.x_dim,
+                                  em_max_iters=self.n_iters)
+
+        trials = TrialsFromLists(self.data1)
+        gpfa_trial_object.fit(trials)
+        gpfa_trial_object.transform(trials)
+        gpfa_list_of_lists.fit(self.data1)
+        gpfa_list_of_lists.transform(self.data1)
+
+        assert_array_almost_equal(gpfa_trial_object.transform_info['Corth'],
+                                  gpfa_list_of_lists.transform_info['Corth'])
+
+    def test_trial_object_zero_trials(self):
+        gpfa_trial_object = GPFA(bin_size=self.bin_size, x_dim=self.x_dim,
+                                 em_max_iters=self.n_iters)
+
+        trials = TrialsFromLists([])
+        with self.assertRaises(ValueError):
+            gpfa_trial_object.fit(trials)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py
index fcfb92263..cc927d53e 100644
--- a/elephant/test/test_utils.py
+++ b/elephant/test/test_utils.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """
-Unit tests for the synchrofact detection app
+Unit tests for elephant.utils
 """
 
 import unittest
@@ -12,6 +12,11 @@
 from elephant import utils
 from numpy.testing import assert_array_equal
 
+from elephant.spike_train_generation import StationaryPoissonProcess
+from elephant.trials import TrialsFromLists
+from neo.core.spiketrainlist import SpikeTrainList
+from neo.core import SpikeTrain
+
 
 class TestUtils(unittest.TestCase):
 
@@ -56,5 +61,80 @@ def test_round_binning_errors(self):
                            [0, 0, 0, 0])
 
 
+class DecoratorTest:
+    """
+    This class is used as a mock for testing the decorator.
+    """
+    @utils.trials_to_list_of_spiketrainlist
+    def method_to_decorate(self, trials=None, trials_obj=None):
+        # This is just a mock implementation for testing purposes
+        if trials_obj:
+            return trials_obj
+        else:
+            return trials
+
+
+class TestTrialsToListOfSpiketrainlist(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.n_channels = 10
+        cls.n_trials = 5
+        cls.list_of_list_of_spiketrains = [
+            StationaryPoissonProcess(rate=5 * pq.Hz, t_stop=1000.0 * pq.ms
+                                     ).generate_n_spiketrains(cls.n_channels)
+            for _ in range(cls.n_trials)]
+        cls.trial_object = TrialsFromLists(cls.list_of_list_of_spiketrains)
+
+    def test_decorator_applied(self):
+        # Test that the decorator is applied correctly
+        self.assertTrue(hasattr(
+            DecoratorTest.method_to_decorate, '__wrapped__'
+        ))
+
+    def test_decorator_return_with_trials_input_as_arg(self):
+        # Test if decorator takes in trial-object and returns
+        # list of spiketrainlists
+        new_class = DecoratorTest()
+        list_of_spiketrainlists = new_class.method_to_decorate(
+            self.trial_object)
+        self.assertEqual(len(list_of_spiketrainlists), self.n_trials)
+        for spiketrainlist in list_of_spiketrainlists:
+            self.assertIsInstance(spiketrainlist, SpikeTrainList)
+
+    def test_decorator_return_with_list_of_lists_input_as_arg(self):
+        # Test if decorator takes in list of lists of spiketrains
+        # and does not change input
+        new_class = DecoratorTest()
+        list_of_list_of_spiketrains = new_class.method_to_decorate(
+            self.list_of_list_of_spiketrains)
+        self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials)
+        for list_of_spiketrains in list_of_list_of_spiketrains:
+            self.assertIsInstance(list_of_spiketrains, list)
+            for spiketrain in list_of_spiketrains:
+                self.assertIsInstance(spiketrain, SpikeTrain)
+
+    def test_decorator_return_with_trials_input_as_kwarg(self):
+        # Test if decorator takes in trial-object and returns
+        # list of spiketrainlists
+        new_class = DecoratorTest()
+        list_of_spiketrainlists = new_class.method_to_decorate(
+            trials_obj=self.trial_object)
+        self.assertEqual(len(list_of_spiketrainlists),
+                         self.n_trials)
+        for spiketrainlist in list_of_spiketrainlists:
+            self.assertIsInstance(spiketrainlist, SpikeTrainList)
+
+    def test_decorator_return_with_list_of_lists_input_as_kwarg(self):
+        # Test if decorator takes in list of lists of spiketrains
+        # and does not change input
+        new_class = DecoratorTest()
+        list_of_list_of_spiketrains = new_class.method_to_decorate(
+            trials_obj=self.list_of_list_of_spiketrains)
+        self.assertEqual(len(list_of_list_of_spiketrains), self.n_trials)
+        for list_of_spiketrains in list_of_list_of_spiketrains:
+            self.assertIsInstance(list_of_spiketrains, list)
+            for spiketrain in list_of_spiketrains:
+                self.assertIsInstance(spiketrain, SpikeTrain)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/elephant/utils.py b/elephant/utils.py
index 361a1c103..b4ddfee22 100644
--- a/elephant/utils.py
+++ b/elephant/utils.py
@@ -16,11 +16,12 @@
 import warnings
 from functools import wraps
 
-import neo
 from neo.core.spiketrainlist import SpikeTrainList
 import numpy as np
 import quantities as pq
 
+from elephant.trials import Trials
+
 
 __all__ = [
     "deprecated_alias",
@@ -29,7 +30,7 @@
     "get_common_start_stop_times",
     "check_neo_consistency",
     "check_same_units",
-    "round_binning_errors"
+    "round_binning_errors",
 ]
 
 
@@ -37,11 +38,15 @@
 logger = logging.getLogger(__file__)
 log_handler = logging.StreamHandler()
 log_handler.setFormatter(
-    logging.Formatter(f"[%(asctime)s] {__name__[__name__.rfind('.')+1::]} -"
-                      " %(levelname)s: %(message)s"))
+    logging.Formatter(
+        f"[%(asctime)s] {__name__[__name__.rfind('.')+1::]} -"
+        " %(levelname)s: %(message)s"
+    )
+)
 logger.addHandler(log_handler)
 logger.propagate = False
 
+
 def is_binary(array):
     """
     Parameters
@@ -83,6 +88,7 @@ def deprecated_alias(**aliases):
     ...     pass
 
     """
+
     def deco(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
@@ -98,10 +104,12 @@ def _rename_kwargs(func_name, kwargs, aliases):
     for old, new in aliases.items():
         if old in kwargs:
             if new in kwargs:
-                raise TypeError(f"{func_name} received both '{old}' and "
-                                f"'{new}'")
-            warnings.warn(f"'{old}' is deprecated; use '{new}'",
-                          DeprecationWarning)
+                raise TypeError(
+                    f"{func_name} received both '{old}' and " f"'{new}'"
+                )
+            warnings.warn(
+                f"'{old}' is deprecated; use '{new}'", DeprecationWarning
+            )
             kwargs[new] = kwargs.pop(old)
 
 
@@ -159,22 +167,25 @@ def get_common_start_stop_times(neo_objects):
         If there is no shared interval ``[t_start, t_stop]`` across the input
         neo objects.
     """
-    if hasattr(neo_objects, 't_start') and hasattr(neo_objects, 't_stop'):
+    if hasattr(neo_objects, "t_start") and hasattr(neo_objects, "t_stop"):
         return neo_objects.t_start, neo_objects.t_stop
     try:
         t_start = max(elem.t_start for elem in neo_objects)
         t_stop = min(elem.t_stop for elem in neo_objects)
     except AttributeError:
-        raise AttributeError("Input neo objects must have 't_start' and "
-                             "'t_stop' attributes")
+        raise AttributeError(
+            "Input neo objects must have 't_start' and " "'t_stop' attributes"
+        )
     if t_stop < t_start:
-        raise ValueError(f"t_stop ({t_stop}) is smaller than t_start "
-                         f"({t_start})")
+        raise ValueError(
+            f"t_stop ({t_stop}) is smaller than t_start " f"({t_start})"
+        )
     return t_start, t_stop
 
 
-def check_neo_consistency(neo_objects, object_type, t_start=None,
-                          t_stop=None, tolerance=1e-8):
+def check_neo_consistency(
+    neo_objects, object_type, t_start=None, t_stop=None, tolerance=1e-8
+):
     """
     Checks that all input neo objects share the same units, t_start, and
     t_stop.
@@ -213,9 +224,11 @@ def check_neo_consistency(neo_objects, object_type, t_start=None,
         tolerance = 0
     for neo_obj in neo_objects:
         if not isinstance(neo_obj, object_type):
-            raise TypeError("The input must be a list of "
-                            f"{object_type.__name__}. Got "
-                            f"{type(neo_obj).__name__}")
+            raise TypeError(
+                "The input must be a list of "
+                f"{object_type.__name__}. Got "
+                f"{type(neo_obj).__name__}"
+            )
         if neo_obj.units != units:
             raise ValueError("The input must have the same units.")
         if t_start is None and abs(neo_obj.t_start.item() - start) > tolerance:
@@ -252,13 +265,17 @@ def check_same_units(quantities, object_type=pq.Quantity):
         raise TypeError(f"The input must be a list of {object_type.__name__}")
     for quantity in quantities:
         if not isinstance(quantity, object_type):
-            raise TypeError("The input must be a list of "
-                            f"{object_type.__name__}. Got "
-                            f"{type(quantity).__name__}")
+            raise TypeError(
+                "The input must be a list of "
+                f"{object_type.__name__}. Got "
+                f"{type(quantity).__name__}"
+            )
         if quantity.units != units:
-            raise ValueError("The input quantities must have the same units, "
-                             "which is achieved with object.rescale('ms') "
-                             "operation.")
+            raise ValueError(
+                "The input quantities must have the same units, "
+                "which is achieved with object.rescale('ms') "
+                "operation."
+            )
 
 
 def round_binning_errors(values, tolerance=1e-8):
@@ -298,18 +315,22 @@ def round_binning_errors(values, tolerance=1e-8):
     if isinstance(values, np.ndarray):
         num_corrections = correction_mask.sum()
         if num_corrections > 0:
-            logger.warning(f'Correcting {num_corrections} rounding errors by '
-                           'shifting the affected spikes into the following '
-                           'bin. You can set tolerance=None to disable this '
-                           'behaviour.')
+            logger.warning(
+                f"Correcting {num_corrections} rounding errors by "
+                "shifting the affected spikes into the following "
+                "bin. You can set tolerance=None to disable this "
+                "behaviour."
+            )
         values[correction_mask] += 0.5
         return values.astype(np.int32)
 
     if correction_mask:
-        logger.warning('Correcting a rounding error in the calculation '
-                       'of the number of bins by incrementing the value by 1. '
-                       'You can set tolerance=None to disable this '
-                       'behaviour.')
+        logger.warning(
+            "Correcting a rounding error in the calculation "
+            "of the number of bins by incrementing the value by 1. "
+            "You can set tolerance=None to disable this "
+            "behaviour."
+        )
         values += 0.5
     return int(values)
 
@@ -325,7 +346,7 @@ def get_cuda_capability_major():
         CUDA capability major version.
     """
     cuda_success = 0
-    for libname in ('libcuda.so', 'libcuda.dylib', 'cuda.dll'):
+    for libname in ("libcuda.so", "libcuda.dylib", "cuda.dll"):
         try:
             cuda = ctypes.CDLL(libname)
         except OSError:
@@ -346,9 +367,9 @@ def get_cuda_capability_major():
 
     cc_major = ctypes.c_int()
     cc_minor = ctypes.c_int()
-    cuda.cuDeviceComputeCapability(ctypes.byref(cc_major),
-                                   ctypes.byref(cc_minor),
-                                   device)
+    cuda.cuDeviceComputeCapability(
+        ctypes.byref(cc_major), ctypes.byref(cc_minor), device
+    )
     return cc_major.value
 
 
@@ -364,6 +385,7 @@ def get_opencl_capability():
     """
     try:
        import pyopencl
+
        platforms = pyopencl.get_platforms()
 
        if len(platforms) == 0:
@@ -372,3 +394,55 @@
         return True
     except ImportError:
         return False
+
+
+def trials_to_list_of_spiketrainlist(method):
+    """
+    Decorator to convert `Trials` object to a list of `SpikeTrainList` before
+    calling the wrapped method.
+
+    Parameters
+    ----------
+    method : callable
+        The method to be decorated.
+
+    Returns
+    -------
+    callable
+        The decorated method.
+
+    Examples
+    --------
+    The decorator can be used as follows:
+
+    >>> @trials_to_list_of_spiketrainlist
+    ... def process_data(self, spiketrains):
+    ...     return None
+    """
+
+    @wraps(method)
+    def wrapper(*args, **kwargs):
+        new_args = tuple(
+            [
+                arg.get_spiketrains_from_trial_as_list(idx)
+                for idx in range(arg.n_trials)
+            ]
+            if isinstance(arg, Trials)
+            else arg
+            for arg in args
+        )
+        new_kwargs = {
+            key: (
+                [
+                    value.get_spiketrains_from_trial_as_list(idx)
+                    for idx in range(value.n_trials)
+                ]
+                if isinstance(value, Trials)
+                else value
+            )
+            for key, value in kwargs.items()
+        }
+
+        return method(*new_args, **new_kwargs)
+
+    return wrapper
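
For reference, a minimal sketch of the call pattern this diff enables, restricted to APIs that appear above (TrialsFromLists, StationaryPoissonProcess.generate_n_spiketrains, GPFA.fit_transform); the rates, trial counts, and variable names are illustrative only:

    import quantities as pq

    from elephant.gpfa import GPFA
    from elephant.spike_train_generation import StationaryPoissonProcess
    from elephant.trials import TrialsFromLists

    # 20 trials of 10 Poisson spike trains each, mirroring the test fixtures.
    data = [
        StationaryPoissonProcess(
            rate=5 * pq.Hz, t_stop=1000.0 * pq.ms
        ).generate_n_spiketrains(10)
        for _ in range(20)
    ]

    gpfa = GPFA(bin_size=20 * pq.ms, x_dim=3, em_max_iters=10)

    # A Trials object and the equivalent list of lists take the same code
    # path: the trials_to_list_of_spiketrainlist decorator unwraps a Trials
    # argument into a list of SpikeTrainList before fit/transform run.
    trajectories = gpfa.fit_transform(TrialsFromLists(data))

Both input forms should therefore yield the same fitted model, which is what the test_trial_object_gpfa_fit and test_trial_object_gpfa_transform tests above assert via the estimated 'gamma' parameters and 'Corth'.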