Feature/multitaper psd estimate (#458)
New Feature: multitaper psd estimate. Tests include a comparison against nitime multitaper.

Co-authored-by: ackurth <[email protected]>
Co-authored-by: Regimantas Jurkus <[email protected]>
Co-authored-by: Aitor Morales-Gregorio <[email protected]>
Co-authored-by: kleinjohann <[email protected]>
Co-authored-by: pbouss <[email protected]>
Co-authored-by: Danylo Ulianych <[email protected]>
Co-authored-by: Alexander Kleinjohann <[email protected]>
Co-authored-by: Andrew Davison <[email protected]>
Co-authored-by: Anno Christopher Kurth <[email protected]>
10 people authored Feb 25, 2022
1 parent 8a284ad commit b263ffe
Showing 2 changed files with 394 additions and 0 deletions.
246 changes: 246 additions & 0 deletions elephant/spectral.py
@@ -270,6 +270,252 @@ def welch_psd(signal, n_segments=8, len_segment=None,
    return freqs, psd


def multitaper_psd(signal, n_segments=1, len_segment=None,
                   frequency_resolution=None, overlap=0.5, fs=1,
                   nw=4, num_tapers=None, peak_resolution=None, axis=-1):
    """
    Estimates the power spectral density (PSD) of a given `neo.AnalogSignal`
    using the multitaper method.

    The PSD is obtained through the following steps:

    1. Cut the given data into several overlapping segments. The degree of
       overlap can be specified by parameter `overlap` (default is 0.5,
       i.e. segments are overlapped by the half of their length).
       The number and the length of the segments are determined according
       to the parameters `n_segments`, `len_segment` or
       `frequency_resolution`. By default (`n_segments=1`), the whole signal
       is treated as a single segment;

    2. Calculate `num_tapers` approximately independent estimates of the
       spectrum by multiplying each segment with the discrete prolate
       spheroidal functions (also known as Slepian functions) and
       calculating the PSD of each tapered segment;

    3. Average the approximately independent estimates within each segment
       to decrease the overall variance of the estimate;

    4. Average the resulting per-segment estimates across all segments.

    Parameters
    ----------
    signal : neo.AnalogSignal or np.ndarray
        Time series data of which PSD is estimated. When `signal` is
        np.ndarray, the sampling frequency should be given through the
        keyword argument `fs`.
        Signal should be passed as (n_channels, n_samples).
    fs : float, optional
        Specifies the sampling frequency of the input time series.
        Default: 1.0.
    n_segments : int, optional
        Number of segments. The length of segments is adjusted so that
        overlapping segments cover the entire stretch of the given data. This
        parameter is ignored if `len_segment` or `frequency_resolution` is
        given.
        Default: 1.
    len_segment : int, optional
        Length of segments. This parameter is ignored if
        `frequency_resolution` is given. If None, it will be determined from
        other parameters.
        Default: None.
    frequency_resolution : pq.Quantity or float, optional
        Desired frequency resolution of the obtained PSD estimate in terms of
        the interval between adjacent frequency bins. When given as a `float`,
        it is taken as frequency in Hz.
        If None, it will be determined from other parameters.
        Default: None.
    overlap : float, optional
        Overlap between segments represented as a float number between 0 (no
        overlap) and 1 (complete overlap).
        Default: 0.5 (half-overlapped).
    nw : float, optional
        Time-bandwidth product.
        Default: 4.0.
    num_tapers : int, optional
        Number of tapers used in step 2 to obtain the PSD estimate. If None,
        floor(2*nw) - 1 is chosen.
        Default: None.
    peak_resolution : pq.Quantity or float, optional
        Quantity in Hz determining the number of tapers used for analysis.
        Fine peak resolution --> low numerical value --> low number of tapers
        Coarse peak resolution --> high numerical value --> high number of
        tapers
        When given as a `float`, it is taken as frequency in Hz.
        Default: None.
    axis : int, optional
        Axis along which the periodogram is computed.
        See Notes [2].
        Default: last axis (-1).

    Notes
    -----
    1. There is a parameter hierarchy regarding n_segments and len_segment.
       The former parameter is ignored if the latter one is passed.
    2. There is a parameter hierarchy regarding nw, num_tapers and
       peak_resolution. If peak_resolution is provided, it determines both nw
       and the num_tapers. Specifying num_tapers has an effect only if
       peak_resolution is not provided.

    Returns
    -------
    freqs : np.ndarray
        Frequencies associated with power estimate in `psd`
    psd : np.ndarray
        PSD estimate of the time series in `signal`

    Raises
    ------
    ValueError
        If `frequency_resolution` or `peak_resolution` is given and is not
        positive.

        If `peak_resolution` is None and `num_tapers` is not a positive number.

        If `frequency_resolution` is too high for the given data size.

        If `frequency_resolution` is None and `len_segment` is not a positive
        number.

        If `frequency_resolution` is None and `len_segment` is greater than the
        length of data at `axis`.

        If both `frequency_resolution` and `len_segment` are None and
        `n_segments` is not a positive number.

        If both `frequency_resolution` and `len_segment` are None and
        `n_segments` is greater than the length of data at `axis`.
    TypeError
        If `peak_resolution` is None and `num_tapers` is not an int.
    """

    # When the input is AnalogSignal, the data is added after rolling the axis
    # for time index to the last
    data = np.asarray(signal)
    if isinstance(signal, neo.AnalogSignal):
        data = np.rollaxis(data, 0, len(data.shape))

    # Number of data points in time series
    if data.ndim == 1:
        length_signal = np.shape(data)[0]
    else:
        length_signal = np.shape(data)[1]

    # If the data is given as AnalogSignal, use its attribute to specify the
    # sampling frequency
    if hasattr(signal, 'sampling_rate'):
        fs = signal.sampling_rate.rescale('Hz').magnitude

    # If fs is a pq.Quantity, get its magnitude in Hz
    if isinstance(fs, pq.quantity.Quantity):
        fs = fs.rescale('Hz').magnitude

    # Determine length per segment - n_per_seg
    if frequency_resolution is not None:
        if frequency_resolution <= 0:
            raise ValueError("frequency_resolution must be positive")
        if isinstance(frequency_resolution, pq.quantity.Quantity):
            dF = frequency_resolution.rescale('Hz').magnitude
        else:
            dF = frequency_resolution
        n_per_seg = int(fs / dF)
        if n_per_seg > data.shape[axis]:
            raise ValueError("frequency_resolution is too high for the given "
                             "data size")
    elif len_segment is not None:
        if len_segment <= 0:
            raise ValueError("len_seg must be a positive number")
        elif data.shape[axis] < len_segment:
            raise ValueError("len_seg must be shorter than the data length")
        n_per_seg = len_segment
    else:
        if n_segments <= 0:
            raise ValueError("n_segments must be a positive number")
        elif data.shape[axis] < n_segments:
            raise ValueError("n_segments must be smaller than the data length")
        # when only *n_segments* is given, *n_per_seg* is determined by solving
        # the following equation:
        #   n_segments * n_per_seg                      (summed segment lengths)
        #   - (n_segments - 1) * overlap * n_per_seg    (total overlap)
        #   = data.shape[axis]                          (data length)
        n_per_seg = int(data.shape[axis] /
                        (n_segments - overlap * (n_segments - 1)))

    n_overlap = int(n_per_seg * overlap)
    n_segments = int((length_signal - n_overlap) / (n_per_seg - n_overlap))

    if isinstance(peak_resolution, pq.quantity.Quantity):
        peak_resolution = peak_resolution.rescale('Hz').magnitude

    # Determine time-halfbandwidth product from given parameters
    if peak_resolution is not None:
        if peak_resolution <= 0:
            raise ValueError("peak_resolution must be positive")
        nw = n_per_seg / fs * peak_resolution / 2
        num_tapers = int(np.floor(2*nw) - 1)

    if num_tapers is None:
        num_tapers = int(np.floor(2*nw) - 1)
    else:
        if not isinstance(num_tapers, int):
            raise TypeError("num_tapers must be integer")
        if num_tapers <= 0:
            raise ValueError("num_tapers must be positive")

    # Generate frequencies of PSD estimate
    freqs = np.fft.rfftfreq(n_per_seg, d=1/fs)

    # Zero-pad signal to fit segment length
    remainder = length_signal % n_per_seg

    if data.ndim == 1:
        data = np.pad(data, pad_width=(0, remainder),
                      mode='constant', constant_values=0)
        # Generate array for storing PSD estimates of segments
        psd_estimates = np.zeros((n_segments, len(freqs)))
    else:
        data = np.pad(data, [(0, 0), (0, remainder)],
                      mode='constant', constant_values=0)
        # Generate array for storing PSD estimates of segments
        psd_estimates = np.zeros((n_segments, data.shape[0], len(freqs)))

    # Determine the step (in samples) between the starts of successive segments
    n_overlap_step = n_per_seg - n_overlap

    for i in range(n_segments):
        # Get slepian functions (sym=False used for spectral analysis)
        slepian_fcts = scipy.signal.windows.dpss(M=n_per_seg,
                                                 NW=nw,
                                                 Kmax=num_tapers,
                                                 sym=False)

        # Calculate approximately independent spectrum estimates
        if data.ndim == 1:
            tapered_signal = (data[i * n_overlap_step:
                                   i * n_overlap_step + n_per_seg]
                              * slepian_fcts)
        else:
            # Use broadcasting to match dim for point-wise multiplication
            tapered_signal = (data[:,
                                   np.newaxis,
                                   i * n_overlap_step:
                                   i * n_overlap_step + n_per_seg]
                              * slepian_fcts)

        # Determine Fourier transform of tapered signal
        spectrum_estimates = np.abs(np.fft.rfft(tapered_signal, axis=-1))**2
        spectrum_estimates[..., 1:] *= 2

        # Average the spectrum estimates over the tapers
        psd_segment = np.mean(spectrum_estimates, axis=-2) / fs

        psd_estimates[i] = psd_segment

    psd = np.mean(np.asarray(psd_estimates), axis=0)

    # Attach proper units to return values
    if isinstance(signal, pq.quantity.Quantity):
        psd = psd * signal.units * signal.units / pq.Hz
        freqs = freqs * pq.Hz

    return freqs, psd
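
Steps 2 and 3 of the docstring can be illustrated on a single segment. The following is a minimal sketch under stated assumptions, not the code from this commit: the function name `multitaper_psd_single_segment` is hypothetical, it accepts only a 1-D NumPy array without segmentation or units, and unlike the committed loop it leaves the Nyquist bin undoubled for even-length input.

```python
import numpy as np
from scipy.signal.windows import dpss


def multitaper_psd_single_segment(x, fs=1.0, nw=4.0, num_tapers=None):
    """One-sided multitaper PSD of a 1-D array (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[-1]
    if num_tapers is None:
        # Same default as in the docstring: floor(2 * nw) - 1 tapers
        num_tapers = int(np.floor(2 * nw) - 1)

    # Tapers have shape (num_tapers, n); with Kmax given, SciPy
    # normalises them to unit L2 norm by default
    tapers = dpss(M=n, NW=nw, Kmax=num_tapers, sym=False)

    # One periodogram per taper; x is broadcast over the taper axis
    spectra = np.abs(np.fft.rfft(x * tapers, axis=-1)) ** 2

    # Fold negative frequencies into a one-sided spectrum. The DC bin is
    # never doubled; the Nyquist bin exists only for even n and is kept as is
    if n % 2 == 0:
        spectra[:, 1:-1] *= 2
    else:
        spectra[:, 1:] *= 2

    freqs = np.fft.rfftfreq(n, d=1 / fs)
    # Average over tapers and convert to a density (per Hz)
    psd = spectra.mean(axis=0) / fs
    return freqs, psd


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    freqs, psd = multitaper_psd_single_segment(
        rng.standard_normal(4096), fs=1000.0)
    print(freqs.shape, psd.shape)
```

With unit-norm tapers, summing `psd` times the bin width `fs / n` gives approximately the signal variance, which is a convenient sanity check on the normalisation.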

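
The segment and taper arithmetic used in the function body can be checked in isolation. This sketch copies the formulas from the code above into two hypothetical helpers (`segment_layout` and `tapers_from_peak_resolution` are illustrative names, not part of elephant) and works on plain numbers only.

```python
import math


def segment_layout(n_samples, n_segments, overlap=0.5):
    """Return (n_per_seg, n_overlap, n_segments_fit) for given inputs."""
    # Solve n_segments * L - (n_segments - 1) * overlap * L = n_samples for L
    n_per_seg = int(n_samples / (n_segments - overlap * (n_segments - 1)))
    n_overlap = int(n_per_seg * overlap)
    # Number of segments that fit once lengths are rounded to integers
    n_segments_fit = int((n_samples - n_overlap) / (n_per_seg - n_overlap))
    return n_per_seg, n_overlap, n_segments_fit


def tapers_from_peak_resolution(n_per_seg, fs, peak_resolution):
    """Return (nw, num_tapers) derived from a peak resolution in Hz."""
    # Segment duration (n_per_seg / fs) times half the peak resolution
    nw = n_per_seg / fs * peak_resolution / 2
    num_tapers = int(math.floor(2 * nw) - 1)
    return nw, num_tapers


if __name__ == "__main__":
    print(segment_layout(1000, 3, overlap=0.5))
    print(tapers_from_peak_resolution(1000, 1000.0, 8.0))
```

For example, a one-second segment sampled at 1000 Hz with a peak resolution of 8 Hz corresponds to a time-bandwidth product of 4 and therefore 7 tapers, matching the default `nw=4`.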

@deprecated_alias(x='signal_i', y='signal_j', num_seg='n_segments',
                  len_seg='len_segment', freq_res='frequency_resolution')
def welch_coherence(signal_i, signal_j, n_segments=8, len_segment=None,