def _hz_magnitude(value):
    """
    Return `value` as a plain number expressed in Hz.

    Objects exposing a ``rescale`` method (e.g. `pq.Quantity`) are rescaled
    to Hz and stripped of their units; anything else (including None) is
    returned unchanged.
    """
    if hasattr(value, 'rescale'):
        return value.rescale('Hz').magnitude
    return value


def multitaper_psd(signal, n_segments=1, len_segment=None,
                   frequency_resolution=None, overlap=0.5, fs=1,
                   nw=4, num_tapers=None, peak_resolution=None, axis=-1):
    """
    Estimates power spectrum density (PSD) of a given 'neo.AnalogSignal'
    using Multitaper method

    The PSD is obtained through the following steps:

    1. Cut the given data into several overlapping segments. The degree of
       overlap can be specified by parameter `overlap` (default is 0.5,
       i.e. segments are overlapped by the half of their length).
       The number and the length of the segments are determined according
       to the parameters `n_segments`, `len_segment` or
       `frequency_resolution`. By default, the data is treated as a single
       segment;

    2. Calculate 'num_tapers' approximately independent estimates of the
       spectrum by multiplying each segment with the discrete prolate
       spheroidal functions (also known as Slepian functions) and
       calculating the one-sided power spectrum of each tapered segment;

    3. Average the approximately independent estimates of each segment to
       decrease overall variance of the estimates;

    4. Average the obtained estimates over all segments.

    Parameters
    ----------
    signal : neo.AnalogSignal or pq.Quantity or np.ndarray
        Time series data of which PSD is estimated. When `signal` is
        np.ndarray or pq.Quantity, the sampling frequency should be given
        through keyword argument `fs` and the data should be passed with
        time on the last axis, e.g. as (n_channels, n_samples).
        A `neo.AnalogSignal` is reordered internally so that time becomes
        the last axis.
    n_segments : int, optional
        Number of segments. The length of segments is adjusted so that
        overlapping segments cover the entire stretch of the given data.
        This parameter is ignored if `len_segment` or
        `frequency_resolution` is given.
        Default: 1.
    len_segment : int, optional
        Length of segments. This parameter is ignored if
        `frequency_resolution` is given. If None, it will be determined
        from other parameters.
        Default: None.
    frequency_resolution : pq.Quantity or float, optional
        Desired frequency resolution of the obtained PSD estimate in terms
        of the interval between adjacent frequency bins. When given as a
        `float`, it is taken as frequency in Hz.
        If None, it will be determined from other parameters.
        Default: None.
    overlap : float, optional
        Overlap between segments represented as a float number between 0
        (no overlap) and 1 (complete overlap, which is not allowed because
        the segments would not advance).
        Default: 0.5 (half-overlapped).
    fs : pq.Quantity or float, optional
        Specifies the sampling frequency of the input time series. When
        given as a `float`, it is taken as frequency in Hz. Ignored if
        `signal` has a `sampling_rate` attribute.
        Default: 1.0.
    nw : float, optional
        Time bandwidth product. Overridden if `peak_resolution` is given.
        Default: 4.0.
    num_tapers : int, optional
        Number of tapers used in step 2 to obtain the estimate of the PSD.
        If None, floor(2*nw) - 1 is chosen. Ignored if `peak_resolution`
        is given.
        Default: None.
    peak_resolution : pq.Quantity or float, optional
        Quantity in Hz determining the time bandwidth product and the
        number of tapers: ``nw = n_per_seg / fs * peak_resolution / 2`` and
        ``num_tapers = floor(2*nw) - 1``. Hence a small value yields few
        tapers and a large value yields many tapers.
        When given as a `float`, it is taken as frequency in Hz.
        Default: None.
    axis : int, optional
        Axis of the data whose length is used to validate `n_segments`,
        `len_segment` and `frequency_resolution`. The segmentation itself
        always runs along the last axis, so only the default is meaningful.
        Default: last axis (-1).

    Notes
    -----
    1. There is a parameter hierarchy regarding `frequency_resolution`,
       `len_segment` and `n_segments`: each parameter is ignored if one
       listed before it is passed.

    2. There is a parameter hierarchy regarding nw, num_tapers and
       peak_resolution. If peak_resolution is provided, it determines both
       nw and the num_tapers. Specifying num_tapers has an effect only if
       peak_resolution is not provided.

    3. The returned spectrum is one-sided: every bin that has a mirrored
       negative-frequency partner is doubled. The zero-frequency bin and,
       for an even segment length, the Nyquist bin have no such partner
       and are therefore not doubled.

    Returns
    -------
    freqs : np.ndarray or pq.Quantity
        Frequencies associated with power estimate in `psd`. A
        `pq.Quantity` in Hz if `signal` is a `pq.Quantity`.
    psd : np.ndarray or pq.Quantity
        PSD estimate of the time series in `signal`, with frequency on the
        last axis. A `pq.Quantity` in squared signal units per Hz if
        `signal` is a `pq.Quantity`.

    Raises
    ------
    ValueError
        If the number of tapers (given or derived) is not a positive
        number.

        If `peak_resolution` or `frequency_resolution` is not positive.

        If `frequency_resolution` is too high for the given data size.

        If `frequency_resolution` is None and `len_segment` is not a
        positive number.

        If `frequency_resolution` is None and `len_segment` is greater than
        the length of data at `axis`.

        If both `frequency_resolution` and `len_segment` are None and
        `n_segments` is not a positive number.

        If both `frequency_resolution` and `len_segment` are None and
        `n_segments` is greater than the length of data at `axis`.

        If the resulting segment is shorter than one sample or `overlap`
        is so large that consecutive segments would not advance.

    TypeError
        If `peak_resolution` is None and `num_tapers` is not an int.
    """
    data = np.asarray(signal)
    if hasattr(signal, 'sampling_rate'):
        # Signals that know their own sampling rate override `fs`
        fs = signal.sampling_rate
        if isinstance(signal, neo.AnalogSignal):
            # AnalogSignal keeps time on the first axis; move it to the last
            data = np.rollaxis(data, 0, data.ndim)
    fs = _hz_magnitude(fs)

    # Number of data points in the time series (time is on the last axis)
    length_signal = data.shape[-1]

    # Determine length per segment - n_per_seg
    if frequency_resolution is not None:
        freq_res = _hz_magnitude(frequency_resolution)
        if freq_res <= 0:
            raise ValueError("frequency_resolution must be positive")
        n_per_seg = int(fs / freq_res)
        if n_per_seg > data.shape[axis]:
            raise ValueError("frequency_resolution is too high for the given "
                             "data size")
    elif len_segment is not None:
        if len_segment <= 0:
            raise ValueError("len_seg must be a positive number")
        if data.shape[axis] < len_segment:
            raise ValueError("len_seg must be shorter than the data length")
        n_per_seg = len_segment
    else:
        if n_segments <= 0:
            raise ValueError("n_segments must be a positive number")
        if data.shape[axis] < n_segments:
            raise ValueError("n_segments must be smaller than the data length")
        # when only *n_segments* is given, *n_per_seg* is determined by
        # solving the following equation:
        #  n_segments * n_per_seg - (n_segments-1) * overlap * n_per_seg =
        #      data.shape[-1]
        #  (summed segment lengths) - (total overlap) = (data length)
        n_per_seg = int(data.shape[axis] /
                        (n_segments - overlap * (n_segments - 1)))

    if n_per_seg < 1:
        raise ValueError("segment length must be at least one sample")

    n_overlap = int(n_per_seg * overlap)
    # Number of samples by which consecutive segments are shifted
    n_step = n_per_seg - n_overlap
    if n_step <= 0:
        # Would otherwise end in a division by zero / a non-advancing loop
        raise ValueError("overlap must be smaller than 1")
    n_segments = int((length_signal - n_overlap) / n_step)

    # Determine time-halfbandwidth product and number of tapers
    peak_resolution = _hz_magnitude(peak_resolution)
    if peak_resolution is not None:
        if peak_resolution <= 0:
            raise ValueError("peak_resolution must be positive")
        nw = n_per_seg / fs * peak_resolution / 2
        num_tapers = int(np.floor(2 * nw) - 1)
    elif num_tapers is None:
        num_tapers = int(np.floor(2 * nw) - 1)
    elif isinstance(num_tapers, (int, np.integer)):
        num_tapers = int(num_tapers)
    else:
        raise TypeError("num_tapers must be integer")
    if num_tapers <= 0:
        raise ValueError("num_tapers must be positive")

    # Generate frequencies of PSD estimate
    freqs = np.fft.rfftfreq(n_per_seg, d=1 / fs)
    n_freqs = len(freqs)

    # Slepian functions are identical for all segments, so compute them once
    # (sym=False used for spectral analysis)
    slepian_fcts = scipy.signal.windows.dpss(M=n_per_seg,
                                             NW=nw,
                                             Kmax=num_tapers,
                                             sym=False)

    # Bins 1 .. n_paired-1 have a mirrored negative-frequency partner. For an
    # even segment length the last rfft bin is the unpaired Nyquist bin.
    n_paired = n_freqs - 1 if n_per_seg % 2 == 0 else n_freqs

    # The last segment ends at n_segments * n_step + n_overlap, which never
    # exceeds length_signal, hence no zero-padding of the data is required.
    psd_estimates = np.zeros((n_segments,) + data.shape[:-1] + (n_freqs,))

    for i in range(n_segments):
        start = i * n_step
        # Insert a taper axis before the time axis so that broadcasting
        # multiplies every channel with every taper
        tapered_signal = (data[..., np.newaxis, start:start + n_per_seg]
                          * slepian_fcts)

        # Power of the Fourier transform of each tapered segment, converted
        # to a one-sided spectrum
        spectrum_estimates = np.abs(np.fft.rfft(tapered_signal, axis=-1))**2
        spectrum_estimates[..., 1:n_paired] *= 2

        # Average over the approximately independent taper estimates
        psd_estimates[i] = np.mean(spectrum_estimates, axis=-2) / fs

    # Average over segments
    psd = np.mean(psd_estimates, axis=0)

    # Attach proper units to return values
    if hasattr(signal, 'units') and isinstance(signal, pq.quantity.Quantity):
        psd = psd * signal.units * signal.units / pq.Hz
        freqs = freqs * pq.Hz

    return freqs, psd
class MultitaperPSDTestCase(unittest.TestCase):
    """Unit tests for `elephant.spectral.multitaper_psd`."""

    @staticmethod
    def _noisy_sinusoid(data_length=5000, sampling_period=0.001,
                        signal_freq=100.0):
        """
        Build an AnalogSignal (in mV) holding a unit-amplitude sinusoid of
        frequency `signal_freq` (Hz) plus standard normal white noise,
        sampled every `sampling_period` seconds.
        """
        noise = np.random.normal(size=data_length)
        times = np.arange(0, data_length * sampling_period, sampling_period)
        sinusoid = np.sin(2 * np.pi * signal_freq * times)
        return n.AnalogSignal(sinusoid + noise,
                              sampling_period=sampling_period * pq.s,
                              units='mV')

    def test_multitaper_psd_errors(self):
        # generate dummy data
        signal = n.AnalogSignal(np.zeros(5000),
                                sampling_period=0.001 * pq.s,
                                units='mV')
        fs = 1000 * pq.Hz
        nw = 3

        # check for invalid parameter values
        # `fs` and `nw` are passed by keyword: positionally they would be
        # bound to `n_segments` and `len_segment`.
        # - number of tapers
        self.assertRaises(ValueError, elephant.spectral.multitaper_psd,
                          signal, fs=fs, nw=nw, num_tapers=-5)
        self.assertRaises(TypeError, elephant.spectral.multitaper_psd,
                          signal, fs=fs, nw=nw, num_tapers=-5.0)
        # - peak resolution
        self.assertRaises(ValueError, elephant.spectral.multitaper_psd,
                          signal, fs=fs, nw=nw, peak_resolution=-1)

    def test_multitaper_psd_behavior(self):
        # generate data by adding white noise and a sinusoid
        sampling_period = 0.001
        signal_freq = 100.0
        data = self._noisy_sinusoid(sampling_period=sampling_period,
                                    signal_freq=signal_freq)

        # consistency between different ways of specifying number of tapers
        freqs1, psd1 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3.5)
        freqs2, psd2 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3.5, num_tapers=6)
        self.assertTrue((psd1 == psd2).all() and (freqs1 == freqs2).all())

        # peak resolution and consistency with data
        freq_res = 1.0 * pq.Hz
        freqs, psd = elephant.spectral.multitaper_psd(
            data, peak_resolution=freq_res)
        self.assertEqual(freqs[psd.argmax()], signal_freq)
        freqs_np, psd_np = elephant.spectral.multitaper_psd(
            data.magnitude.flatten(), fs=1 / sampling_period,
            peak_resolution=freq_res)
        self.assertTrue((freqs == freqs_np).all() and (psd == psd_np).all())

    def test_multitaper_psd_parameter_hierarchy(self):
        # generate data by adding white noise and a sinusoid
        data = self._noisy_sinusoid()

        # Test num_tapers vs nw
        freqs1, psd1 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3, num_tapers=9)
        freqs2, psd2 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3)
        self.assertTrue((freqs1 == freqs2).all() and (psd1 != psd2).all())

        # Test peak_resolution vs nw
        freqs1, psd1 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3, num_tapers=9,
            peak_resolution=1)
        freqs2, psd2 = elephant.spectral.multitaper_psd(
            data, fs=data.sampling_rate, nw=3, num_tapers=9)
        self.assertTrue((freqs1 == freqs2).all() and (psd1 != psd2).all())

    def test_multitaper_psd_against_nitime(self):
        """
        This test assesses the match between this implementation of
        multitaper against nitime (0.8) using a predefined time series
        generated by an autoregressive model.

        Please follow the link below for more details:
        https://gin.g-node.org/INM-6/elephant-data/src/master/unittest/spectral/multitaper_psd
        """
        data_url = r"https://web.gin.g-node.org/INM-6/elephant-data/raw/master/unittest/spectral/multitaper_psd/data"  # noqa

        files_to_download = [
            ("time_series.npy", "ff43797e2ac94613f510b20a31e2e80e"),
            ("psd_nitime.npy", "89d1f53957e66c786049ea425b53c0e8")
        ]

        for filename, checksum in files_to_download:
            download(url=f"{data_url}/{filename}", checksum=checksum)

        time_series = np.load(ELEPHANT_TMP_DIR / 'time_series.npy')
        psd_nitime = np.load(ELEPHANT_TMP_DIR / 'psd_nitime.npy')

        freqs, psd_multitaper = elephant.spectral.multitaper_psd(
            signal=time_series, fs=0.1, nw=4, num_tapers=8)

        np.testing.assert_allclose(psd_multitaper, psd_nitime, rtol=0.3,
                                   atol=0.1)

    def test_multitaper_psd_input_types(self):
        # generate a test data
        sampling_period = 0.001
        data = n.AnalogSignal(np.array(np.random.normal(size=5000)),
                              sampling_period=sampling_period * pq.s,
                              units='mV')

        # outputs from AnalogSignal input are of Quantity type (standard
        # usage)
        freqs_neo, psd_neo = elephant.spectral.multitaper_psd(data)
        self.assertIsInstance(freqs_neo, pq.quantity.Quantity)
        self.assertIsInstance(psd_neo, pq.quantity.Quantity)

        # outputs from Quantity array input are of Quantity type
        freqs_pq, psd_pq = elephant.spectral.multitaper_psd(
            data.magnitude.flatten() * data.units, fs=1 / sampling_period)
        self.assertIsInstance(freqs_pq, pq.quantity.Quantity)
        self.assertIsInstance(psd_pq, pq.quantity.Quantity)

        # outputs from Numpy ndarray input are NOT of Quantity type
        freqs_np, psd_np = elephant.spectral.multitaper_psd(
            data.magnitude.flatten(), fs=1 / sampling_period)
        self.assertNotIsInstance(freqs_np, pq.quantity.Quantity)
        self.assertNotIsInstance(psd_np, pq.quantity.Quantity)

        # check if the results from different input types are identical
        self.assertTrue(
            (freqs_neo == freqs_pq).all() and (
                psd_neo == psd_pq).all())
        self.assertTrue(
            (freqs_neo == freqs_np).all() and (
                psd_neo == psd_np).all())