Under the Hood
We strongly recommend following the procedure described here when using WaLSAtools, a user-friendly tool that provides all the information you need for your analysis. However, for experts who wish to familiarize themselves with the techniques and codes under the hood, and to inspect, modify, or improve them, some of the main codes are provided below. Please note that all codes and their dependencies are available in the GitHub repository.
Analysis Modules
WaLSAtools is built upon a collection of analysis modules, each designed for a specific aspect of wave analysis. These modules are combined and accessed through the main WaLSAtools
interface, providing a streamlined and user-friendly experience.
Here's a brief overview of the core analysis modules:
WaLSA_speclizer.py
This module provides a collection of spectral analysis techniques, including FFT, Lomb-Scargle, Wavelet, Welch, and EMD/HHT.
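For orientation, here is a minimal, illustrative sketch of calling this module's main entry point directly on a one-dimensional signal. The import path and the synthetic `signal`/`time` arrays are assumptions for the example; in normal use you would access these routines through the main WaLSAtools interface instead.

import numpy as np
from WaLSAtools import WaLSA_speclizer  # import path assumed; see the repository layout

time = np.linspace(0, 100, 1000)          # synthetic: 1000 samples over 100 s
signal = np.sin(2 * np.pi * 0.5 * time)   # a 0.5 Hz test oscillation
power, freqs, sig95, _ = WaLSA_speclizer(signal=signal, time=time, method='fft', nperm=500)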
WaLSA_speclizer.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from astropy.timeseries import LombScargle # type: ignore
from tqdm import tqdm # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_confidence import WaLSA_confidence # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore
from scipy.signal import welch # type: ignore
# -------------------------------------- Main Function ----------------------------------------
def WaLSA_speclizer(signal=None, time=None, method=None,
dominantfreq=False, averagedpower=False, **kwargs):
"""
Main function to prepare data and call the specific spectral analysis method.
Parameters:
signal (array): The input signal (1D or 3D).
time (array): The time array of the signal.
        method (str): The method to use for spectral analysis: 'fft', 'lombscargle', 'wavelet', 'welch', or 'emd'.
**kwargs: Additional parameters for data preparation and analysis methods.
Returns:
Power spectrum, frequencies, and significance (if applicable).
"""
    if not dominantfreq and not averagedpower:  # no cube products requested; single-signal analysis
if method == 'fft':
return getpowerFFT(signal, time=time, **kwargs)
elif method == 'lombscargle':
return getpowerLS(signal, time=time, **kwargs)
elif method == 'wavelet':
return getpowerWavelet(signal, time=time, **kwargs)
elif method == 'welch':
return welch_psd(signal, time=time, **kwargs)
elif method == 'emd':
return getEMD_HHT(signal, time=time, **kwargs)
else:
raise ValueError(f"Unknown method: {method}")
else:
return get_dominant_averaged(signal, time=time, method=method, **kwargs)
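# Illustrative usage of the dispatcher above (names are placeholders):
#   power, freqs, sig, amps = WaLSA_speclizer(signal=ts, time=t, method='fft')
# or, for per-pixel dominant frequencies and averaged power of a 3D cube:
#   domfreq, avgpow, freqs, pmap = WaLSA_speclizer(signal=cube, time=t, method='fft',
#                                                  dominantfreq=True, averagedpower=True)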
# -------------------------------------- FFT ----------------------------------------
def getpowerFFT(signal, time, **kwargs):
"""
Perform FFT analysis on the given signal.
Parameters:
signal (array): The input signal (1D).
time (array): The time array corresponding to the signal.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
nosignificance (bool): If True, skip significance calculation. Default: False.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
amplitude (bool): If True, return the amplitudes of the Fourier transform. Default: False.
silent (bool): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
Power spectrum, frequencies, significance, amplitudes
"""
# Define default values for the optional parameters
defaults = {
'siglevel': 0.95,
'significance': None,
'nperm': 1000,
'nosignificance': False,
'apod': 0.1,
'pxdetrend': 2,
'meandetrend': False,
'polyfit': None,
'meantemporal': False,
'recon': False,
'resample_original': False,
'nodetrendapod': False,
'amplitude': False,
'silent': False
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
params['siglevel'] = 1 - params['siglevel'] # different convention
tdiff = np.diff(time)
cadence = np.median(tdiff)
# Perform detrending and apodization
if not params['nodetrendapod']:
apocube = WaLSA_detrend_apod(
signal,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=params['silent']
)
else:
apocube = signal
nt = len(apocube) # Length of the time series (1D)
# Calculate the frequencies
frequencies = 1. / (cadence * 2) * np.arange(nt // 2 + 1) / (nt // 2)
frequencies = frequencies[1:]
powermap = np.zeros(len(frequencies))
signal = apocube
spec = np.fft.fft(signal)
power = 2 * np.abs(spec[1:len(frequencies) + 1]) ** 2
powermap[:] = power / frequencies[0]
if params['amplitude']:
        amplitudes = spec[1:len(frequencies) + 1]  # complex Fourier amplitudes
else:
amplitudes = None
# Calculate significance if requested
if not params['nosignificance']:
ps_perm = np.zeros((len(frequencies), params['nperm']))
for ip in range(params['nperm']):
            perm_signal = np.random.permutation(signal)  # permute the original signal
            perm_apocube = WaLSA_detrend_apod(
                perm_signal,
                apod=params['apod'],
                meandetrend=params['meandetrend'],
                pxdetrend=params['pxdetrend'],
                polyfit=params['polyfit'],
                meantemporal=params['meantemporal'],
                recon=params['recon'],
                cadence=cadence,
                resample_original=params['resample_original'],
                silent=True
            )
            perm_spec = np.fft.fft(perm_apocube)  # FFT of the detrended/apodized permutation
perm_power = 2 * np.abs(perm_spec[1:len(frequencies) + 1]) ** 2
ps_perm[:, ip] = perm_power
# Correct call to WaLSA_confidence
significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=len(frequencies))
significance = significance / frequencies[0]
else:
significance = None
if not params['silent']:
print("FFT processed.")
return powermap, frequencies, significance, amplitudes
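# Illustrative call (assuming a 1D series `ts` sampled at times `t`):
#   power, freqs, sig95, amps = getpowerFFT(ts, t, siglevel=0.95, nperm=1000, amplitude=True)
# `sig95` is the permutation-based significance level, in the same
# frequency-normalized units as `power`.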
# -------------------------------------- Lomb-Scargle ----------------------------------------
def getpowerLS(signal, time, **kwargs):
"""
Perform Lomb-Scargle analysis on the given signal.
Parameters:
signal (array): The input signal (1D).
time (array): The time array corresponding to the signal.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
dy (array): Errors or observational uncertainties associated with the time series.
fit_mean (bool): If True, include a constant offset as part of the model at each frequency. This improves accuracy, especially for incomplete phase coverage.
center_data (bool): If True, pre-center the data by subtracting the weighted mean of the input data. This is especially important if fit_mean=False.
nterms (int): Number of terms to use in the Fourier fit. Default: 1.
normalization (str): The normalization method for the periodogram. Options: 'standard', 'model', 'log', 'psd'. Default: 'standard'.
nosignificance (bool): If True, skip significance calculation. Default: False.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
silent (bool): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
Power spectrum, frequencies, and significance (if applicable).
"""
# Define default values for the optional parameters
defaults = {
        'siglevel': 0.95,  # matches the documented default; converted to the other convention below
'nperm': 1000,
'nosignificance': False,
'apod': 0.1,
'pxdetrend': 2,
'meandetrend': False,
'polyfit': None,
'meantemporal': False,
'recon': False,
'resample_original': False,
'silent': False, # Ensure 'silent' is included in defaults
'nodetrendapod': False,
'dy': None, # Error or sequence of observational errors associated with times t
'fit_mean': False, # If True include a constant offset as part of the model at each frequency. This can lead to more accurate results, especially in the case of incomplete phase coverage
'center_data': True, # If True pre-center the data by subtracting the weighted mean of the input data. This is especially important if fit_mean = False
'nterms': 1, # Number of terms to use in the Fourier fit
'normalization': 'standard' # The normalization to use for the periodogram. Options are 'standard', 'model', 'log', 'psd'
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
params['siglevel'] = 1 - params['siglevel'] # different convention
tdiff = np.diff(time)
cadence = np.median(tdiff)
# Perform detrending and apodization
if not params['nodetrendapod']:
apocube = WaLSA_detrend_apod(
signal,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=params['silent']
)
else:
apocube = signal
frequencies, power = LombScargle(time, apocube, dy=params['dy'], fit_mean=params['fit_mean'],
center_data=params['center_data'], nterms=params['nterms'],
normalization=params['normalization']).autopower()
# Calculate significance if needed
if not params['nosignificance']:
ps_perm = np.zeros((len(frequencies), params['nperm']))
for ip in range(params['nperm']):
perm_signal = np.random.permutation(signal) # Permuting the original signal
apocubep = WaLSA_detrend_apod(
perm_signal,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
frequencies, perm_power = LombScargle(time, apocubep, dy=params['dy'], fit_mean=params['fit_mean'],
center_data=params['center_data'], nterms=params['nterms'],
normalization=params['normalization']).autopower()
ps_perm[:, ip] = perm_power
significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=len(frequencies))
else:
significance = None
if not params['silent']:
print("Lomb-Scargle processed.")
return power, frequencies, significance
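# Illustrative call (Lomb-Scargle also handles unevenly sampled `t`):
#   power, freqs, sig95 = getpowerLS(ts, t, nterms=1, normalization='standard')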
# -------------------------------------- Wavelet ----------------------------------------
def getpowerWavelet(signal, time, **kwargs):
"""
Perform wavelet analysis using the pycwt package.
Parameters:
signal (array): The input signal (1D).
time (array): The time array corresponding to the signal.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
mother (str): The mother wavelet function to use. Default: 'morlet'.
GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
lag1 (float): Lag-1 autocorrelation. Default: 0.0.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
silent (bool): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
power: The wavelet power spectrum.
periods: Corresponding periods.
sig_slevel: The significance levels.
coi: The cone of influence.
Optionally, if global_power=True:
global_power: Global wavelet power spectrum.
global_conf: Confidence levels for the global wavelet spectrum.
Optionally, if RGWS=True:
rgws_periods: Periods for the refined global wavelet spectrum.
rgws_power: Refined global wavelet power spectrum.
"""
# Define default values for the optional parameters similar to IDL
defaults = {
'siglevel': 0.95,
'mother': 'morlet', # Morlet wavelet as the mother function
'dj': 0.025, # Scale spacing
's0': -1, # Initial scale
'J': -1, # Number of scales
'lag1': 0.0, # Lag-1 autocorrelation
'apod': 0.1, # Tukey window apodization
'silent': False,
'pxdetrend': 2,
'meandetrend': False,
'polyfit': None,
'meantemporal': False,
'recon': False,
'resample_original': False,
'GWS': False, # If True, calculate global wavelet spectrum
'RGWS': False, # If True, calculate refined global wavelet spectrum (excluding COI)
'nperm': 1000, # Number of permutations for significance calculation
'nodetrendapod': False
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
tdiff = np.diff(time)
cadence = np.median(tdiff)
# Perform detrending and apodization
if not params['nodetrendapod']:
apocube = WaLSA_detrend_apod(
signal,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=params['silent']
)
else:
apocube = signal
n = len(apocube)
dt = cadence
# Standardize the signal before the wavelet transform
std_signal = apocube.std()
norm_signal = apocube / std_signal
# Determine the initial scale s0 if not provided
if params['s0'] == -1:
params['s0'] = 2 * dt
# Determine the number of scales J if not provided
if params['J'] == -1:
params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])
# Perform wavelet transform
W, scales, frequencies, coi, _, _ = cwt(
norm_signal,
dt,
dj=params['dj'],
s0=params['s0'],
J=params['J'],
wavelet=params['mother']
)
power = np.abs(W) ** 2 # Wavelet power spectrum
periods = 1 / frequencies # Convert frequencies to periods
# Calculate the significance levels
signif, _ = significance(
1.0, # Normalized variance
dt,
scales,
0, # Ignored for now, used for background noise (e.g., red noise)
params['lag1'],
significance_level=params['siglevel'],
wavelet=params['mother']
)
# Calculate the significance level for each power value
sig_matrix = np.ones([len(scales), n]) * signif[:, None]
sig_slevel = power / sig_matrix
# Calculate global power if requested
if params['GWS']:
global_power = np.mean(power, axis=1) # Average along the time axis
dof = n - scales # the -scale corrects for padding at edges
global_conf, _ = significance(
norm_signal, # time series data
dt, # time step between values
scales, # scale vector
1, # sigtest = 1 for "time-average" test
0.0,
significance_level=params['siglevel'],
dof=dof,
wavelet=params['mother']
)
else:
global_power = None
global_conf = None
# Calculate refined global wavelet spectrum (RGWS) if requested
if params['RGWS']:
isig = sig_slevel < 1.0
power[isig] = np.nan
ipower = np.full_like(power, np.nan)
for i in range(len(coi)):
pcol = power[:, i]
valid_idx = periods < coi[i]
ipower[valid_idx, i] = pcol[valid_idx]
rgws_power = np.nansum(ipower, axis=1)
else:
rgws_power = None
if not params['silent']:
print("Wavelet (" + params['mother'] + ") processed.")
return power, periods, sig_slevel, coi, global_power, global_conf, rgws_power
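# Illustrative call (assuming a 1D series `ts` sampled at times `t`):
#   power, periods, siglvl, coi, gws, gws_conf, rgws = getpowerWavelet(ts, t, GWS=True, RGWS=True)
# Entries of `siglvl` greater than 1 exceed the requested significance level.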
# -------------------------------------- Welch ----------------------------------------
def welch_psd(signal, time, **kwargs):
"""
Calculate Welch Power Spectral Density (PSD) and significance levels.
Parameters:
signal (array): The 1D time series signal.
time (array): The time array corresponding to the signal.
nperseg (int, optional): Length of each segment for analysis. Default: 256.
noverlap (int, optional): Number of points to overlap between segments. Default: 128.
window (str, optional): Type of window function used in the Welch method. Default: 'hann'.
siglevel (float, optional): Significance level for confidence intervals. Default: 0.95.
nperm (int, optional): Number of permutations for significance testing. Default: 1000.
silent (bool, optional): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
frequencies: Frequencies at which the PSD is estimated.
psd: Power Spectral Density values.
significance: Significance levels for the PSD.
"""
# Define default values for the optional parameters similar to FFT
defaults = {
'siglevel': 0.95,
'nperm': 1000, # Number of permutations for significance calculation
'window': 'hann', # Window type for Welch method
'nperseg': 256, # Length of each segment
'silent': False,
'noverlap': 128 # Number of points to overlap between segments
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
tdiff = np.diff(time)
cadence = np.median(tdiff)
# Calculate Welch PSD
frequencies, psd = welch(signal, fs=1.0 / cadence, window=params['window'], nperseg=params['nperseg'], noverlap=params['noverlap'])
# Calculate significance levels using permutation
ps_perm = np.zeros((len(frequencies), params['nperm']))
for ip in range(params['nperm']):
perm_signal = np.random.permutation(signal) # Permuting the original signal
_, perm_psd = welch(perm_signal, fs=1.0 / cadence, window=params['window'], nperseg=params['nperseg'], noverlap=params['noverlap'])
ps_perm[:, ip] = perm_psd
# Calculate significance level for Welch PSD
significance = np.percentile(ps_perm, params['siglevel'] * 100, axis=1)
if not params['silent']:
print("Welch processed.")
return psd, frequencies, significance
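# Illustrative call:
#   psd, freqs, sig95 = welch_psd(ts, t, nperseg=256, noverlap=128, window='hann')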
# -------------------------------------- EMD / EEMD ----------------------------------------
from .PyEMD import EMD, EEMD # type: ignore
from scipy.stats import norm # type: ignore
from scipy.signal import hilbert, find_peaks # type: ignore
from scipy.fft import fft, fftfreq # type: ignore
# Function to apply EMD and return Intrinsic Mode Functions (IMFs)
def apply_emd(signal, time):
emd = EMD()
emd.FIXE_H = 5
imfs = emd.emd(signal, time) # max_imfs=7, emd.FIXE_H = 10
return imfs
# Function to apply EEMD and return Intrinsic Mode Functions (IMFs)
def apply_eemd(signal, time, noise_std=0.2, num_realizations=1000):
eemd = EEMD()
eemd.FIXE_H = 5
eemd.noise_seed(12345)
eemd.noise_width = noise_std
eemd.ensemble_size = num_realizations
imfs = eemd.eemd(signal, time)
return imfs
# Function to generate white noise IMFs for significance testing
def generate_white_noise_imfs(signal_length, time, num_realizations, use_eemd=None, noise_std=0.2):
white_noise_imfs = []
if use_eemd:
eemd = EEMD()
eemd.FIXE_H = 5
eemd.noise_seed(12345)
eemd.noise_width = noise_std
eemd.ensemble_size = num_realizations
for _ in range(num_realizations):
white_noise = np.random.normal(size=signal_length)
imfs = eemd.eemd(white_noise, time)
white_noise_imfs.append(imfs)
else:
emd = EMD()
emd.FIXE_H = 5
for _ in range(num_realizations):
white_noise = np.random.normal(size=signal_length)
imfs = emd.emd(white_noise, time)
white_noise_imfs.append(imfs)
return white_noise_imfs
# Function to calculate the energy of each IMF
def calculate_imf_energy(imfs):
return [np.sum(imf**2) for imf in imfs]
# Function to determine the significance of the IMFs
def test_imf_significance(imfs, white_noise_imfs):
imf_energies = calculate_imf_energy(imfs)
num_imfs = len(imfs)
white_noise_energies = [calculate_imf_energy(imf_set) for imf_set in white_noise_imfs]
significance_levels = []
for i in range(num_imfs):
# Collect the i-th IMF energy from each white noise realization
        white_noise_energy_dist = [energies[i] for energies in white_noise_energies if i < len(energies)]
mean_energy = np.mean(white_noise_energy_dist)
std_energy = np.std(white_noise_energy_dist)
z_score = (imf_energies[i] - mean_energy) / std_energy
significance_level = 1 - norm.cdf(z_score)
significance_levels.append(significance_level)
return significance_levels
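# Worked example of the test above (numbers are illustrative): if an IMF has
# energy 12.0 while the matching white-noise IMFs have mean energy 8.0 and
# standard deviation 2.0, then z = (12.0 - 8.0) / 2.0 = 2.0 and the
# significance level is 1 - norm.cdf(2.0) ≈ 0.023, i.e. significant at 5%.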
# Function to compute instantaneous frequency of IMFs using Hilbert transform
def compute_instantaneous_frequency(imfs, time):
instantaneous_frequencies = []
for imf in imfs:
analytic_signal = hilbert(imf)
instantaneous_phase = np.unwrap(np.angle(analytic_signal))
instantaneous_frequency = np.diff(instantaneous_phase) / (2.0 * np.pi * np.diff(time))
instantaneous_frequency = np.abs(instantaneous_frequency) # Ensure frequencies are positive
instantaneous_frequencies.append(instantaneous_frequency)
return instantaneous_frequencies
# Function to compute HHT power spectrum
def compute_hht_power_spectrum(imfs, instantaneous_frequencies, freq_bins):
power_spectrum = np.zeros_like(freq_bins)
for i in range(len(imfs)):
amplitude = np.abs(hilbert(imfs[i]))
power = amplitude**2
for j in range(len(instantaneous_frequencies[i])):
freq_idx = np.searchsorted(freq_bins, instantaneous_frequencies[i][j])
if freq_idx < len(freq_bins):
power_spectrum[freq_idx] += power[j]
return power_spectrum
# Function to compute significance level for HHT power spectrum
def compute_significance_level(white_noise_imfs, freq_bins, time):
all_power_spectra = []
for imfs in white_noise_imfs:
instantaneous_frequencies = compute_instantaneous_frequency(imfs, time)
power_spectrum = compute_hht_power_spectrum(imfs, instantaneous_frequencies, freq_bins)
all_power_spectra.append(power_spectrum)
# Compute the 95th percentile power spectrum as the significance level
significance_level = np.percentile(all_power_spectra, 95, axis=0)
return significance_level
# Custom rounding function
def custom_round(freq):
if freq < 1:
return round(freq, 1)
else:
return round(freq)
def smooth_power_spectrum(power_spectrum, window_size=5):
return np.convolve(power_spectrum, np.ones(window_size)/window_size, mode='same')
# Function to identify and print significant peaks
def identify_significant_peaks(freq_bins, power_spectrum, significance_level):
peaks, _ = find_peaks(power_spectrum, height=significance_level)
significant_frequencies = freq_bins[peaks]
rounded_frequencies = [custom_round(freq) for freq in significant_frequencies]
    rounded_frequencies_two = [round(freq, 1) for freq in significant_frequencies]
    print("Significant Frequencies (Hz, one decimal):", rounded_frequencies_two)
    print("Significant Frequencies (Hz, custom rounding):", rounded_frequencies)
return significant_frequencies
# Function to generate randomized signals
def generate_randomized_signals(signal, num_realizations):
randomized_signals = []
for _ in range(num_realizations):
randomized_signal = np.random.permutation(signal)
randomized_signals.append(randomized_signal)
return randomized_signals
# Function to calculate the FFT power spectrum for each IMF
def calculate_fft_psd_spectra(imfs, time):
psd_spectra = []
for imf in imfs:
N = len(imf)
T = time[1] - time[0] # Assuming uniform sampling
yf = fft(imf)
xf = fftfreq(N, T)[:N // 2]
psd = (2.0 / N) * (np.abs(yf[:N // 2]) ** 2) / (N * T)
psd_spectra.append((xf, psd))
return psd_spectra
# Function to calculate the FFT power spectrum for randomized signals
def calculate_fft_psd_spectra_randomized(imfs, time, num_realizations):
all_psd_spectra = []
for imf in imfs:
randomized_signals = generate_randomized_signals(imf, num_realizations)
for signal in randomized_signals:
N = len(signal)
T = time[1] - time[0] # Assuming uniform sampling
yf = fft(signal)
psd = (2.0 / N) * (np.abs(yf[:N // 2]) ** 2) / (N * T)
all_psd_spectra.append(psd)
return np.array(all_psd_spectra)
# Function to calculate the 95th percentile confidence level
def calculate_confidence_level(imfs, time, num_realizations):
confidence_levels = []
for imf in imfs:
all_psd_spectra = calculate_fft_psd_spectra_randomized([imf], time, num_realizations)
confidence_level = np.percentile(all_psd_spectra, 95, axis=0)
confidence_levels.append(confidence_level)
return confidence_levels
# Function to calculate the Welch PSD for each IMF
def calculate_welch_psd_spectra(imfs, fs):
psd_spectra = []
for imf in imfs:
f, psd = welch(imf, fs=fs)
psd_spectra.append((f, psd))
return psd_spectra
# Function to calculate the Welch PSD for randomized signals
def calculate_welch_psd_spectra_randomized(imfs, fs, num_realizations):
all_psd_spectra = []
for imf in imfs:
randomized_signals = generate_randomized_signals(imf, num_realizations)
for signal in randomized_signals:
f, psd = welch(signal, fs=fs)
all_psd_spectra.append(psd)
return np.array(all_psd_spectra)
# Function to calculate the 95th percentile confidence level
def calculate_confidence_level_welch(imfs, fs, num_realizations):
confidence_levels = []
for imf in imfs:
all_psd_spectra = calculate_welch_psd_spectra_randomized([imf], fs, num_realizations)
confidence_level = np.percentile(all_psd_spectra, 95, axis=0)
confidence_levels.append(confidence_level)
return confidence_levels
# ************** Main EMD/EEMD routine **************
def getEMD_HHT(signal, time, **kwargs):
"""
Calculate EMD/EEMD and HHT
Parameters:
signal (array): The input signal (1D).
time (array): The time array of the signal.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
EEMD (bool): If True, use Ensemble Empirical Mode Decomposition (EEMD) instead of Empirical Mode Decomposition (EMD). Default: False.
Welch_psd (bool): If True, calculate Welch PSD spectra instead of FFT PSD spectra (for the psd_spectra and psd_confidence_levels). Default: False.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon is set to True, approximate values close to the original are returned for comparison. Default: False.
nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
silent (bool): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
        HHT_power_spectrum: Marginal (Hilbert-Huang) power spectrum.
        HHT_significance_level: Significance levels for the HHT power spectrum.
        HHT_freq_bins: Frequency bins of the HHT power spectrum.
        psd_spectra: FFT (or Welch) PSD spectra of the IMFs.
        psd_confidence_levels: 95th-percentile confidence levels for those spectra.
        imfs: The Intrinsic Mode Functions (IMFs).
        IMF_significance_levels: Significance levels of the IMFs.
        instantaneous_frequencies: Instantaneous frequencies of the IMFs.
"""
# Define default values for the optional parameters similar to FFT
defaults = {
'siglevel': 0.95,
'nperm': 1000, # Number of permutations for significance calculation
'apod': 0.1, # Tukey window apodization
'silent': False,
'pxdetrend': 2,
'meandetrend': False,
'polyfit': None,
'meantemporal': False,
'recon': False,
'resample_original': False,
'EEMD': False, # If True, calculate Ensemble Empirical Mode Decomposition (EEMD)
'nodetrendapod': False,
'Welch_psd': False, # If True, calculate Welch PSD spectra
'significant_imfs': False # If True, return only significant IMFs (and for associated calculations)
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
tdiff = np.diff(time)
cadence = np.median(tdiff)
# Perform detrending and apodization
if not params['nodetrendapod']:
apocube = WaLSA_detrend_apod(
signal,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=params['silent']
)
    else:
        apocube = signal  # no detrending/apodization requested
    n = len(apocube)
dt = cadence
fs = 1 / dt # Sampling frequency
if params['EEMD']:
imfs = apply_eemd(signal, time)
use_eemd = True
else:
imfs = apply_emd(signal, time)
use_eemd = False
white_noise_imfs = generate_white_noise_imfs(len(signal), time, params['nperm'], use_eemd=use_eemd)
IMF_significance_levels = test_imf_significance(imfs, white_noise_imfs)
# Determine significant IMFs
significant_imfs_val = [imf for imf, sig_level in zip(imfs, IMF_significance_levels) if sig_level < params['siglevel']]
if params['significant_imfs']:
imfs = significant_imfs_val
# Compute instantaneous frequencies of IMFs
instantaneous_frequencies = compute_instantaneous_frequency(imfs, time)
if params['Welch_psd']:
# Calculate and plot Welch PSD spectra of (significant) IMFs
psd_spectra = calculate_welch_psd_spectra(imfs, fs)
# Calculate 95th percentile confidence levels for Welch PSD spectra
psd_confidence_levels = calculate_confidence_level_welch(imfs, fs, params['nperm'])
else:
# Calculate and plot FFT PSD spectra of (significant) IMFs
psd_spectra = calculate_fft_psd_spectra(imfs, time)
# Calculate 95th percentile confidence levels for FFT PSD spectra
psd_confidence_levels = calculate_confidence_level(imfs, time, params['nperm'])
# Define frequency bins for HHT power spectrum
max_freq = max([freq.max() for freq in instantaneous_frequencies])
HHT_freq_bins = np.linspace(0, max_freq, 100)
# Compute HHT power spectrum of (significant) IMFs
HHT_power_spectrum = compute_hht_power_spectrum(imfs, instantaneous_frequencies, HHT_freq_bins)
# Compute significance level for HHT power spectrum
HHT_significance_level = compute_significance_level(white_noise_imfs, HHT_freq_bins, time)
if not params['silent']:
if params['EEMD']:
print("EEMD processed.")
else:
print("EMD processed.")
return HHT_power_spectrum, HHT_significance_level, HHT_freq_bins, psd_spectra, psd_confidence_levels, imfs, IMF_significance_levels, instantaneous_frequencies
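# Illustrative call (assuming a 1D series `ts` sampled at times `t`):
#   (hht_power, hht_sig, hht_freqs, psd_spectra, psd_conf,
#    imfs, imf_sig, inst_freqs) = getEMD_HHT(ts, t, EEMD=False, nperm=1000)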
# ----------------------- Dominant Frequency & Averaged Power -------------------------------
def get_dominant_averaged(cube, time, **kwargs):
"""
Analyze a 3D data cube to compute the dominant frequency and averaged power.
Parameters:
cube (3D array): Input data cube, expected in either 'txy' or 'xyt' format.
time (array): Time array of the data cube.
method (str): Analysis method ('fft' or 'wavelet').
format (str): Format of the data cube ('txy' or 'xyt'). Default is 'txy'.
**kwargs: Additional parameters specific to the analysis method.
Returns:
dominantfreq (float): Dominant frequency of the data cube.
averagedpower (float): Averaged power of the data cube.
"""
defaults = {
'method': 'fft',
'format': 'txy',
'silent': False,
'GWS': False, # If True, calculate global wavelet spectrum
'RGWS': True, # If True, calculate refined global wavelet spectrum
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
# Check and adjust format if necessary
if params['format'] == 'xyt':
cube = np.transpose(cube, (2, 0, 1)) # Convert 'xyt' to 'txy'
elif params['format'] != 'txy':
raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")
# Initialize arrays for storing results across spatial coordinates
nt, nx, ny = cube.shape
dominantfreq = np.zeros((nx, ny))
if params['method'] == 'fft':
method_name = 'FFT'
elif params['method'] == 'lombscargle':
method_name = 'Lomb-Scargle'
elif params['method'] == 'wavelet':
method_name = 'Wavelet'
    elif params['method'] == 'welch':
        method_name = 'Welch'
    else:
        raise ValueError("Unsupported method. Choose 'fft', 'lombscargle', 'wavelet', or 'welch'.")
if not params['silent']:
print(f"Processing {method_name} for a 3D cube with format '{params['format']}' and shape {cube.shape}.")
print(f"Calculating Dominant frequencies and/or averaged power spectrum ({method_name}) ....")
# Iterate over spatial coordinates and apply the chosen analysis method
if params['method'] == 'fft':
for x in tqdm(range(nx), desc="Processing x", leave=True):
for y in range(ny):
signal = cube[:, x, y]
fft_power, fft_freqs, _, _ = getpowerFFT(signal, time, silent=True, nosignificance=True, **kwargs)
if x == 0 and y == 0:
powermap = np.zeros((nx, ny, len(fft_freqs)))
powermap[x, y, :] = fft_power
# Determine the dominant frequency for this pixel
dominantfreq[x, y] = fft_freqs[np.argmax(fft_power)]
# Calculate the averaged power over all pixels
averagedpower = np.mean(powermap, axis=(0, 1))
frequencies = fft_freqs
print("\nAnalysis completed.")
elif params['method'] == 'lombscargle':
for x in tqdm(range(nx), desc="Processing x", leave=True):
for y in range(ny):
signal = cube[:, x, y]
ls_power, ls_freqs, _ = getpowerLS(signal, time, silent=True, **kwargs)
if x == 0 and y == 0:
powermap = np.zeros((nx, ny, len(ls_freqs)))
powermap[x, y, :] = ls_power
# Determine the dominant frequency for this pixel
dominantfreq[x, y] = ls_freqs[np.argmax(ls_power)]
# Calculate the averaged power over all pixels
averagedpower = np.mean(powermap, axis=(0, 1))
frequencies = ls_freqs
print("\nAnalysis completed.")
elif params['method'] == 'wavelet':
for x in tqdm(range(nx), desc="Processing x", leave=True):
for y in range(ny):
signal = cube[:, x, y]
if params['GWS']:
_, wavelet_periods, _, _, wavelet_power, _, _ = getpowerWavelet(signal, time, silent=True, **kwargs)
elif params['RGWS']:
_, wavelet_periods, _, _, _, _, wavelet_power = getpowerWavelet(signal, time, silent=True, **kwargs)
if x == 0 and y == 0:
powermap = np.zeros((nx, ny, len(wavelet_periods)))
wavelet_freq = 1. / wavelet_periods
powermap[x, y, :] = wavelet_power
# Determine the dominant frequency for this pixel
dominantfreq[x, y] = wavelet_freq[np.argmax(wavelet_power)]
# Calculate the averaged power over all pixels
averagedpower = np.mean(powermap, axis=(0, 1))
frequencies = wavelet_freq
print("\nAnalysis completed.")
elif params['method'] == 'welch':
for x in tqdm(range(nx), desc="Processing x", leave=True):
for y in range(ny):
signal = cube[:, x, y]
welch_power, welch_freqs, _ = welch_psd(signal, time, silent=True, **kwargs)
if x == 0 and y == 0:
powermap = np.zeros((nx, ny, len(welch_freqs)))
powermap[x, y, :] = welch_power
# Determine the dominant frequency for this pixel
dominantfreq[x, y] = welch_freqs[np.argmax(welch_power)]
# Calculate the averaged power over all pixels
averagedpower = np.mean(powermap, axis=(0, 1))
frequencies = welch_freqs
print("\nAnalysis completed.")
else:
raise ValueError("Unsupported method. Choose 'fft', 'lombscargle', 'wavelet', or 'welch'.")
return dominantfreq, averagedpower, frequencies, powermap
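# Illustrative call for a 3D cube in 'txy' order, i.e. shape (nt, nx, ny):
#   domfreq, avgpower, freqs, powermap = get_dominant_averaged(cube, t, method='fft', format='txy')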
WaLSA_wavelet.py
This module implements the Wavelet Transform and related functionalities.
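For orientation, a minimal sketch of the core transform and its significance test, using the cwt and significance functions defined in this module (the import path, the `signal` array, and the sampling interval `dt` are assumptions for the example):

import numpy as np
from WaLSAtools.WaLSA_wavelet import cwt, significance  # import path assumed

W, scales, freqs, coi, sig_fft, fftfreqs = cwt(signal, dt, dj=1/12, wavelet='morlet')
power = np.abs(W) ** 2  # wavelet power spectrum
signif, fft_theor = significance(np.std(signal) ** 2, dt, scales, 0, 0.0,
                                 significance_level=0.95, wavelet='morlet')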
WaLSA_wavelet.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
# The following code is a modified version of the PyCWT package.
# The original code and licence:
# PyCWT is released under the open source 3-Clause BSD license:
#
# Copyright (c) 2023 Sebastian Krieger, Nabil Freij, and contributors. All rights
# reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#----------------------------------------------------------------------------------
# Modifications by Shahin Jafarzadeh, 2024
#----------------------------------------------------------------------------------
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import numpy as np # type: ignore
from tqdm import tqdm # type: ignore
from scipy.stats import chi2 # type: ignore
# Try to import the Python wrapper for FFTW.
try:
import pyfftw.interfaces.scipy_fftpack as fft # type: ignore
from multiprocessing import cpu_count
# Fast planning, use all available threads.
_FFTW_KWARGS_DEFAULT = {'planner_effort': 'FFTW_ESTIMATE',
'threads': cpu_count()}
def fft_kwargs(signal, **kwargs):
"""Return optimized keyword arguments for FFTW"""
kwargs.update(_FFTW_KWARGS_DEFAULT)
kwargs['n'] = len(signal) # do not pad
return kwargs
# Otherwise, fall back to 2 ** n padded scipy FFTPACK
except ImportError:
import scipy.fftpack as fft # type: ignore
# Can be turned off, e.g. for MKL optimizations
_FFT_NEXT_POW2 = True
def fft_kwargs(signal, **kwargs):
"""Return next higher power of 2 for given signal to speed up FFT"""
if _FFT_NEXT_POW2:
return {'n': int(2 ** np.ceil(np.log2(len(signal))))}
from scipy.signal import lfilter # type: ignore
from os import makedirs
from os.path import exists, expanduser
def find(condition):
"""Returns the indices where ravel(condition) is true."""
res, = np.nonzero(np.ravel(condition))
return res
def ar1(x):
"""
Allen and Smith autoregressive lag-1 autocorrelation coefficient.
    In an AR(1) model
        x(t) - <x> = gamma * (x(t-1) - <x>) + alpha * z(t),
    where <x> is the process mean, gamma and alpha are process
    parameters and z(t) is a Gaussian unit-variance white noise.
Parameters
----------
x : numpy.ndarray, list
Univariate time series
Returns
-------
g : float
Estimate of the lag-one autocorrelation.
a : float
Estimate of the noise variance [var(x) ~= a**2/(1-g**2)]
mu2 : float
Estimated square on the mean of a finite segment of AR(1)
        noise, normalized by the process variance.
References
----------
[1] Allen, M. R. and Smith, L. A. Monte Carlo SSA: detecting
irregular oscillations in the presence of colored noise.
*Journal of Climate*, **1996**, 9(12), 3373-3404.
<http://dx.doi.org/10.1175/1520-0442(1996)009<3373:MCSDIO>2.0.CO;2>
[2] http://www.madsci.org/posts/archives/may97/864012045.Eg.r.html
"""
x = np.asarray(x)
N = x.size
xm = x.mean()
x = x - xm
# Estimates the lag zero and one covariance
c0 = x.transpose().dot(x) / N
c1 = x[0:N-1].transpose().dot(x[1:N]) / (N - 1)
    # According to A. Grinsted's substitutions
B = -c1 * N - c0 * N**2 - 2 * c0 + 2 * c1 - c1 * N**2 + c0 * N
A = c0 * N**2
C = N * (c0 + c1 * N - c1)
D = B**2 - 4 * A * C
if D > 0:
g = (-B - D**0.5) / (2 * A)
else:
        raise Warning('Cannot place an upper bound on the unbiased AR(1). '
                      'Series is too short or trend is too large.')
# According to Allen & Smith (1996), footnote 4
mu2 = -1 / N + (2 / N**2) * ((N - g**N) / (1 - g) -
g * (1 - g**(N - 1)) / (1 - g)**2)
c0t = c0 / (1 - mu2)
a = ((1 - g**2) * c0t) ** 0.5
return g, a, mu2
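# Illustrative use of the AR(1) estimate above, e.g. to set the `lag1`
# parameter for the wavelet significance test:
#   g, a, mu2 = ar1(ts)  # lag-1 autocorrelation, noise std, bias term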
def ar1_spectrum(freqs, ar1=0.):
"""
Lag-1 autoregressive theoretical power spectrum.
Parameters
----------
freqs : numpy.ndarray, list
Frequencies at which to calculate the theoretical power
spectrum.
ar1 : float
Autoregressive lag-1 correlation coefficient.
Returns
-------
Pk : numpy.ndarray
Theoretical discrete Fourier power spectrum of noise signal.
References
----------
[1] http://www.madsci.org/posts/archives/may97/864012045.Eg.r.html
"""
# According to a post from the MadSci Network available at
# http://www.madsci.org/posts/archives/may97/864012045.Eg.r.html,
# the time-series spectrum for an auto-regressive model can be
# represented as
#
# P_k = \frac{E}{\left|1- \sum\limits_{k=1}^{K} a_k \, e^{2 i \pi
# \frac{k f}{f_s} } \right|^2}
#
# which for an AR1 model reduces to
#
freqs = np.asarray(freqs)
Pk = (1 - ar1 ** 2) / np.abs(1 - ar1 * np.exp(-2 * np.pi * 1j * freqs)) \
** 2
return Pk
def rednoise(N, g, a=1.):
"""
Red noise generator using filter.
Parameters
----------
N : int
Length of the desired time series.
g : float
Lag-1 autocorrelation coefficient.
a : float, optional
Noise innovation variance parameter.
Returns
-------
y : numpy.ndarray
Red noise time series.
"""
if g == 0:
        yr = np.random.randn(N, 1) * a
else:
# Twice the decorrelation time.
tau = int(np.ceil(-2 / np.log(np.abs(g))))
yr = lfilter([1, 0], [1, -g], np.random.randn(N + tau, 1) * a)
yr = yr[tau:]
return yr.flatten()
def rect(x, normalize=False):
"""TODO: describe what I do."""
if type(x) in [int, float]:
shape = [x, ]
elif type(x) in [list, dict]:
shape = x
elif type(x) in [np.ndarray, np.ma.core.MaskedArray]:
shape = x.shape
X = np.zeros(shape)
X[0] = X[-1] = 0.5
X[1:-1] = 1
if normalize:
X /= X.sum()
return X
def boxpdf(x):
"""
Forces the probability density function of the input data to have
a boxed distribution.
Parameters
----------
x (array like) :
Input data
Returns
-------
    bX (array like) :
        Boxed data varying between zero and one.
    X, Y (array like) :
        Data lookup table.
"""
x = np.asarray(x)
n = x.size
# Kind of 'unique'
i = np.argsort(x)
d = (np.diff(x[i]) != 0)
j = find(np.concatenate([d, [True]]))
X = x[i][j]
j = np.concatenate([[0], j + 1])
Y = 0.5 * (j[0:-1] + j[1:]) / n
bX = np.interp(x, X, Y)
return bX, X, Y
def get_cache_dir():
"""Returns the location of the cache directory."""
# Sets cache directory according to user home path.
cache_dir = '{}/.cache/pycwt/'.format(expanduser('~'))
    # Creates the cache directory if it does not exist.
if not exists(cache_dir):
makedirs(cache_dir)
# Returns cache directory.
return cache_dir
from scipy.special import gamma, hermitenorm
from scipy.signal import convolve2d
class Morlet(object):
"""Implements the Morlet wavelet class.
Note that the input parameters f and f0 are angular frequencies.
f0 should be more than 0.8 for this function to be correct, its
default value is f0 = 6.
"""
def __init__(self, f0=6):
self._set_f0(f0)
self.name = 'Morlet'
def psi_ft(self, f):
"""Fourier transform of the approximate Morlet wavelet."""
return (np.pi ** -0.25) * np.exp(-0.5 * (f - self.f0) ** 2)
def psi(self, t):
"""Morlet wavelet as described in Torrence and Compo (1998)."""
return (np.pi ** -0.25) * np.exp(1j * self.f0 * t - t ** 2 / 2)
def flambda(self):
"""Fourier wavelength as of Torrence and Compo (1998)."""
return (4 * np.pi) / (self.f0 + np.sqrt(2 + self.f0 ** 2))
def coi(self):
"""e-Folding Time as of Torrence and Compo (1998)."""
return 1. / np.sqrt(2)
def sup(self):
"""Wavelet support defined by the e-Folding time."""
        return 1. / self.coi()
def _set_f0(self, f0):
# Sets the Morlet wave number, the degrees of freedom and the
# empirically derived factors for the wavelet bases C_{\delta},
# gamma, \delta j_0 (Torrence and Compo, 1998, Table 2)
self.f0 = f0 # Wave number
self.dofmin = 2 # Minimum degrees of freedom
if self.f0 == 6:
self.cdelta = 0.776 # Reconstruction factor
self.gamma = 2.32 # Decorrelation factor for time averaging
self.deltaj0 = 0.60 # Factor for scale averaging
else:
self.cdelta = -1
self.gamma = -1
self.deltaj0 = -1
def smooth(self, W, dt, dj, scales):
"""Smoothing function used in coherence analysis.
Parameters
----------
W :
dt :
dj :
scales :
Returns
-------
T :
"""
# The smoothing is performed by using a filter given by the absolute
# value of the wavelet function at each scale, normalized to have a
# total weight of unity, according to suggestions by Torrence &
# Webster (1999) and by Grinsted et al. (2004).
m, n = W.shape
# Filter in time.
k = 2 * np.pi * fft.fftfreq(fft_kwargs(W[0, :])['n'])
k2 = k ** 2
snorm = scales / dt
# Smoothing by Gaussian window (absolute value of wavelet function)
# using the convolution theorem: multiplication by Gaussian curve in
# Fourier domain for each scale, outer product of scale and frequency
F = np.exp(-0.5 * (snorm[:, np.newaxis] ** 2) * k2) # Outer product
smooth = fft.ifft(F * fft.fft(W, axis=1, **fft_kwargs(W[0, :])),
axis=1, # Along Fourier frequencies
**fft_kwargs(W[0, :], overwrite_x=True))
T = smooth[:, :n] # Remove possibly padded region due to FFT
if np.isreal(W).all():
T = T.real
# Filter in scale. For the Morlet wavelet it's simply a boxcar with
# 0.6 width.
wsize = self.deltaj0 / dj * 2
win = rect(int(np.round(wsize)), normalize=True)
T = convolve2d(T, win[:, np.newaxis], 'same') # Scales are "vertical"
return T
class Paul(object):
"""Implements the Paul wavelet class.
Note that the input parameter f is the angular frequency and that
the default order for this wavelet is m=4.
"""
def __init__(self, m=4):
self._set_m(m)
self.name = 'Paul'
# def psi_ft(self, f):
# """Fourier transform of the Paul wavelet."""
# return (2 ** self.m /
# np.sqrt(self.m * np.prod(range(2, 2 * self.m))) *
# f ** self.m * np.exp(-f) * (f > 0))
def psi_ft(self, f): # modified by SJ
"""Fourier transform of the Paul wavelet with limits to prevent underflow."""
expnt = -f
expnt[expnt < -100] = -100 # Apply the threshold to avoid extreme values
return (2 ** self.m /
np.sqrt(self.m * np.prod(range(2, 2 * self.m))) *
f ** self.m * np.exp(expnt) * (f > 0))
def psi(self, t):
"""Paul wavelet as described in Torrence and Compo (1998)."""
return (2 ** self.m * 1j ** self.m * np.prod(range(2, self.m - 1)) /
np.sqrt(np.pi * np.prod(range(2, 2 * self.m + 1))) *
(1 - 1j * t) ** (-(self.m + 1)))
def flambda(self):
"""Fourier wavelength as of Torrence and Compo (1998)."""
return 4 * np.pi / (2 * self.m + 1)
def coi(self):
"""e-Folding Time as of Torrence and Compo (1998)."""
return np.sqrt(2)
def sup(self):
"""Wavelet support defined by the e-Folding time."""
        return 1 / self.coi()
def _set_m(self, m):
# Sets the m derivative of a Gaussian, the degrees of freedom and the
# empirically derived factors for the wavelet bases C_{\delta},
# gamma, \delta j_0 (Torrence and Compo, 1998, Table 2)
self.m = m # Wavelet order
self.dofmin = 2 # Minimum degrees of freedom
if self.m == 4:
self.cdelta = 1.132 # Reconstruction factor
self.gamma = 1.17 # Decorrelation factor for time averaging
self.deltaj0 = 1.50 # Factor for scale averaging
else:
self.cdelta = -1
self.gamma = -1
self.deltaj0 = -1
class DOG(object):
"""Implements the derivative of a Guassian wavelet class.
Note that the input parameter f is the angular frequency and that
for m=2 the DOG becomes the Mexican hat wavelet, which is then
default.
"""
def __init__(self, m=2):
self._set_m(m)
self.name = 'DOG'
def psi_ft(self, f):
"""Fourier transform of the DOG wavelet."""
return (- 1j ** self.m / np.sqrt(gamma(self.m + 0.5)) * f ** self.m *
np.exp(- 0.5 * f ** 2))
def psi(self, t):
"""DOG wavelet as described in Torrence and Compo (1998).
The derivative of a Gaussian of order `n` can be determined using
the probabilistic Hermite polynomials. They are explicitly
written as:
            Hn(x) = 2 ** (-n / 2) * n! * sum ((-1) ** m) *
                    (2 ** 0.5 * x) ** (n - 2 * m) / (m! * (n - 2 * m)!)
        or in the recursive form:
            Hn+1(x) = x * Hn(x) - n * Hn-1(x)
Source: http://www.ask.com/wiki/Hermite_polynomials
"""
p = hermitenorm(self.m)
return ((-1) ** (self.m + 1) * np.polyval(p, t) *
np.exp(-t ** 2 / 2) / np.sqrt(gamma(self.m + 0.5)))
def flambda(self):
"""Fourier wavelength as of Torrence and Compo (1998)."""
return (2 * np.pi / np.sqrt(self.m + 0.5))
def coi(self):
"""e-Folding Time as of Torrence and Compo (1998)."""
return 1 / np.sqrt(2)
def sup(self):
"""Wavelet support defined by the e-Folding time."""
        return 1 / self.coi()
def _set_m(self, m):
# Sets the m derivative of a Gaussian, the degrees of freedom and the
# empirically derived factors for the wavelet bases C_{\delta},
# gamma, \delta j_0 (Torrence and Compo, 1998, Table 2).
self.m = m # m-derivative
self.dofmin = 1 # Minimum degrees of freedom
if self.m == 2:
self.cdelta = 3.541 # Reconstruction factor
self.gamma = 1.43 # Decorrelation factor for time averaging
self.deltaj0 = 1.40 # Factor for scale averaging
elif self.m == 6:
self.cdelta = 1.966
self.gamma = 1.37
self.deltaj0 = 0.97
else:
self.cdelta = -1
self.gamma = -1
self.deltaj0 = -1
class MexicanHat(DOG):
"""Implements the Mexican hat wavelet class.
This class inherits the DOG class using m=2.
"""
def __init__(self):
self.name = 'Mexican Hat'
self._set_m(2)
def cwt(signal, dt, dj=1/12, s0=-1, J=-1, wavelet='morlet', freqs=None, pad=True):
"""Continuous wavelet transform of the signal at specified scales.
Parameters
----------
signal : numpy.ndarray, list
Input signal array.
dt : float
Sampling interval.
dj : float, optional
Spacing between discrete scales. Default value is 1/12.
Smaller values will result in better scale resolution, but
slower calculation and plot.
s0 : float, optional
Smallest scale of the wavelet. Default value is 2*dt.
J : float, optional
Number of scales less one. Scales range from s0 up to
s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default is J = (log2(N * dt / s0)) / dj.
wavelet : instance of Wavelet class, or string
Mother wavelet class. Default is Morlet wavelet.
freqs : numpy.ndarray, optional
Custom frequencies to use instead of the ones corresponding
to the scales described above. Corresponding scales are
calculated using the wavelet Fourier wavelength.
    pad : bool, optional
        Default is True. If set, pad the time series with enough
        zeroes to bring N up to the next higher power of 2. This
        prevents wraparound from the end of the time series to the
        beginning, and also speeds up the FFTs used to do the
        wavelet transform. This will not eliminate all edge effects.
        (added by SJ)
Returns
-------
W : numpy.ndarray
Wavelet transform according to the selected mother wavelet.
Has (J+1) x N dimensions.
sj : numpy.ndarray
Vector of scale indices given by sj = s0 * 2**(j * dj),
j={0, 1, ..., J}.
freqs : array like
Vector of Fourier frequencies (in 1 / time units) that
corresponds to the wavelet scales.
coi : numpy.ndarray
Returns the cone of influence, which is a vector of N
points containing the maximum Fourier period of useful
information at that particular time. Periods greater than
those are subject to edge effects.
fft : numpy.ndarray
Normalized fast Fourier transform of the input signal.
fftfreqs : numpy.ndarray
Fourier frequencies (in 1/time units) for the calculated
FFT spectrum.
Example
-------
>> mother = wavelet.Morlet(6.)
>> wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(signal,
0.25, 0.25, 0.5, 28, mother)
"""
wavelet = _check_parameter_wavelet(wavelet)
# Original signal length
n0 = len(signal)
# If no custom frequencies are set, then set default frequencies
# according to input parameters `dj`, `s0` and `J`. Otherwise, set wavelet
# scales according to Fourier equivalent frequencies.
# Zero-padding if enabled (added + the entire code further modified by SJ)
if pad:
next_power_of_two = int(2 ** np.ceil(np.log2(n0)))
signal_padded = np.zeros(next_power_of_two)
signal_padded[:n0] = signal - np.mean(signal) # Remove the mean and pad with zeros
else:
signal_padded = signal - np.mean(signal) # Remove the mean without padding
N = len(signal_padded) # Length of the padded signal
# Calculate scales and frequencies if not provided
if freqs is None:
# Smallest resolvable scale
# if s0 == -1:
# s0 = 2 * dt / wavelet.flambda()
# # Number of scales
# if J == -1:
# J = int(np.round(np.log2(N * dt / s0) / dj))
if s0 == -1:
s0 = 2 * dt
if J == -1:
J = int((np.log(float(N) * dt / s0) / np.log(2)) / dj)
# The scales as of Mallat 1999
sj = s0 * 2 ** (np.arange(0, J + 1) * dj)
# Fourier equivalent frequencies
freqs = 1 / (wavelet.flambda() * sj)
else:
# The wavelet scales using custom frequencies.
sj = 1 / (wavelet.flambda() * freqs)
# Fourier transform of the (padded) signal
signal_ft = fft.fft(signal_padded, **fft_kwargs(signal_padded))
# Fourier angular frequencies
ftfreqs = 2 * np.pi * fft.fftfreq(N, dt)
# Creates wavelet transform matrix as outer product of scaled transformed
# wavelets and transformed signal according to the convolution theorem.
# (i) Transform scales to column vector for outer product;
# (ii) Calculate 2D matrix [s, f] for each scale s and Fourier angular
# frequency f;
# (iii) Calculate wavelet transform;
sj_col = sj[:, np.newaxis]
psi_ft_bar = ((sj_col * ftfreqs[1] * N) ** .5 *
np.conjugate(wavelet.psi_ft(sj_col * ftfreqs)))
W = fft.ifft(signal_ft * psi_ft_bar, axis=1,
**fft_kwargs(signal_ft, overwrite_x=True))
# Trim the wavelet transform to original signal length if padded
if pad:
W = W[:, :n0] # Trim to the original signal length
# Checks for NaN in transform results and removes them from the scales if
# needed, frequencies and wavelet transform. Trims wavelet transform at
# length `n0`.
sel = np.invert(np.isnan(W).all(axis=1))
if np.any(sel):
sj = sj[sel]
freqs = freqs[sel]
W = W[sel, :]
# Determines the cone-of-influence. Note that it is returned as a function
    # of time in Fourier periods. Uses a triangular Bartlett window with
# non-zero end-points.
coi = (n0 / 2 - np.abs(np.arange(0, n0) - (n0 - 1) / 2))
coi = wavelet.flambda() * wavelet.coi() * dt * coi
return (W[:, :n0], sj, freqs, coi, signal_ft[1:N//2] / N ** 0.5,
ftfreqs[1:N//2] / (2 * np.pi))
def icwt(W, sj, dt, dj=1/12, wavelet='morlet'):
"""Inverse continuous wavelet transform.
Parameters
----------
W : numpy.ndarray
Wavelet transform, the result of the `cwt` function.
sj : numpy.ndarray
Vector of scale indices as returned by the `cwt` function.
dt : float
Sample spacing.
    dj : float, optional
        Spacing between discrete scales as used in the `cwt`
        function. Default value is 1/12.
    wavelet : instance of Wavelet class, or string
        Mother wavelet class. Default is Morlet wavelet.
Returns
-------
iW : numpy.ndarray
Inverse wavelet transform.
Example
-------
>> mother = wavelet.Morlet()
>> wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(var,
0.25, 0.25, 0.5, 28, mother)
>> iwave = wavelet.icwt(wave, scales, 0.25, 0.25, mother)
"""
wavelet = _check_parameter_wavelet(wavelet)
a, b = W.shape
c = sj.size
if a == c:
sj = (np.ones([b, 1]) * sj).transpose()
elif b == c:
sj = np.ones([a, 1]) * sj
    else:
        raise ValueError('Input array dimensions do not match.')
# As of Torrence and Compo (1998), eq. (11)
iW = (dj * np.sqrt(dt) / (wavelet.cdelta * wavelet.psi(0)) *
(np.real(W) / np.sqrt(sj)).sum(axis=0))
return iW
def significance(signal, dt, scales, sigma_test=0, alpha=None,
significance_level=0.95, dof=-1, wavelet='morlet'):
"""Significance test for the one dimensional wavelet transform.
Parameters
----------
signal : array like, float
Input signal array. If a float number is given, then the
variance is assumed to have this value. If an array is
given, then its variance is automatically computed.
dt : float
Sample spacing.
    scales : array like
        Vector of scale indices as returned by the `cwt` function.
sigma_test : int, optional
Sets the type of significance test to be performed.
Accepted values are 0 (default), 1 or 2. See notes below for
further details.
    alpha : float, optional
        Lag-1 autocorrelation, used for the significance levels.
        If not given, it is estimated from the input signal.
significance_level : float, optional
Significance level to use. Default is 0.95.
dof : variant, optional
Degrees of freedom for significance test to be set
according to the type set in sigma_test.
wavelet : instance of Wavelet class, or string
Mother wavelet class. Default is Morlet
Returns
-------
signif : array like
Significance levels as a function of scale.
    fft_theor : array like
        Theoretical red-noise spectrum as a function of period.
Notes
-----
If sigma_test is set to 0, performs a regular chi-square test,
according to Torrence and Compo (1998) equation 18.
    If set to 1, performs a time-average test (equation 23). In
    this case, dof should be set to the number of local wavelet
    spectra that were averaged together. For the global
    wavelet spectrum it would be dof=N, the number of points in
    the time series.
    If set to 2, performs a scale-average test (equations 25 to
    28). In this case dof should be set to a two element vector
    [s1, s2], which gives the range of scales that were averaged
    together. If, for example, the average between scales 2 and
    8 was taken, then dof=[2, 8].
"""
wavelet = _check_parameter_wavelet(wavelet)
try:
n0 = len(signal)
except TypeError:
n0 = 1
J = len(scales) - 1
dj = np.log2(scales[1] / scales[0])
if n0 == 1:
variance = signal
else:
variance = signal.std() ** 2
if alpha is None:
alpha, _, _ = ar1(signal)
period = scales * wavelet.flambda() # Fourier equivalent periods
freq = dt / period # Normalized frequency
dofmin = wavelet.dofmin # Degrees of freedom with no smoothing
Cdelta = wavelet.cdelta # Reconstruction factor
gamma_fac = wavelet.gamma # Time-decorrelation factor
dj0 = wavelet.deltaj0 # Scale-decorrelation factor
# Theoretical discrete Fourier power spectrum of the noise signal
# following Gilman et al. (1963) and Torrence and Compo (1998),
# equation 16.
def pk(k, a, N):
return (1 - a ** 2) / (1 + a ** 2 - 2 * a * np.cos(2 * np.pi * k / N))
fft_theor = pk(freq, alpha, n0)
fft_theor = variance * fft_theor # Including time-series variance
signif = fft_theor
try:
if dof == -1:
dof = dofmin
except ValueError:
pass
if sigma_test == 0: # No smoothing, dof=dofmin, TC98 sec. 4
dof = dofmin
# As in Torrence and Compo (1998), equation 18.
chisquare = chi2.ppf(significance_level, dof) / dof
signif = fft_theor * chisquare
elif sigma_test == 1: # Time-averaged significance
        if len(dof) == 1:
            dof = np.zeros(J + 1) + dof
sel = find(dof < 1)
dof[sel] = 1
# As in Torrence and Compo (1998), equation 23:
dof = dofmin * (1 + (dof * dt / gamma_fac / scales) ** 2) ** 0.5
sel = find(dof < dofmin)
dof[sel] = dofmin # Minimum dof is dofmin
for n, d in enumerate(dof):
chisquare = chi2.ppf(significance_level, d) / d
signif[n] = fft_theor[n] * chisquare
    elif sigma_test == 2:  # Scale-averaged significance
if len(dof) != 2:
raise Exception('DOF must be set to [s1, s2], '
'the range of scale-averages')
if Cdelta == -1:
raise ValueError('Cdelta and dj0 not defined '
'for {} with f0={}'.format(wavelet.name,
wavelet.f0))
s1, s2 = dof
sel = find((scales >= s1) & (scales <= s2))
navg = sel.size
if navg == 0:
raise ValueError('No valid scales between {} and {}.'.format(s1,
s2))
# As in Torrence and Compo (1998), equation 25.
Savg = 1 / sum(1. / scales[sel])
# Power-of-two mid point:
Smid = np.exp((np.log(s1) + np.log(s2)) / 2.)
# As in Torrence and Compo (1998), equation 28.
dof = (dofmin * navg * Savg / Smid) * \
((1 + (navg * dj / dj0) ** 2) ** 0.5)
# As in Torrence and Compo (1998), equation 27.
fft_theor = Savg * sum(fft_theor[sel] / scales[sel])
chisquare = chi2.ppf(significance_level, dof) / dof
# As in Torrence and Compo (1998), equation 26.
signif = (dj * dt / Cdelta / Savg) * fft_theor * chisquare
else:
raise ValueError('sigma_test must be either 0, 1, or 2.')
return signif, fft_theor
def xwt(y1, y2, dt, dj=1/12, s0=-1, J=-1, significance_level=0.95,
wavelet='morlet', normalize=False, no_default_signif=False):
"""Cross wavelet transform (XWT) of two signals.
The XWT finds regions in time frequency space where the time series
show high common power.
Parameters
----------
y1, y2 : numpy.ndarray, list
Input signal array to calculate cross wavelet transform.
dt : float
Sample spacing.
dj : float, optional
Spacing between discrete scales. Default value is 1/12.
Smaller values will result in better scale resolution, but
slower calculation and plot.
s0 : float, optional
Smallest scale of the wavelet. Default value is 2*dt.
    J : float, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default is J = (log2(N * dt / s0)) / dj.
wavelet : instance of a wavelet class, optional
Mother wavelet class. Default is Morlet wavelet.
significance_level : float, optional
Significance level to use. Default is 0.95.
normalize : bool, optional
If set to true, normalizes CWT by the standard deviation of
the signals.
Returns
-------
    xwt : array like
        Cross wavelet transform according to the selected mother
        wavelet.
    coi : array like
        Cone of influence, which is a vector of N points containing
        the maximum Fourier period of useful information at that
        particular time. Periods greater than those are subject to
        edge effects.
    freqs : array like
        Vector of Fourier equivalent frequencies (in 1 / time units)
        that correspond to the wavelet scales.
    signif : array like
        Significance levels as a function of scale.
Notes
-----
Torrence and Compo (1998) state that the percent point function
(PPF) -- inverse of the cumulative distribution function -- of a
chi-square distribution at 95% confidence and two degrees of
freedom is Z2(95%)=3.999. However, calculating the PPF using
chi2.ppf gives Z2(95%)=5.991. To ensure similar significance
intervals as in Grinsted et al. (2004), one has to use confidence
of 86.46%.
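    For reference, with two degrees of freedom the chi-square CDF is
    1 - exp(-x / 2), so chi2.ppf(p, 2) = -2 * ln(1 - p); this gives
    -2 * ln(0.05) ≈ 5.991 at 95% and -2 * ln(0.1354) ≈ 3.999 at 86.46%.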
"""
wavelet = _check_parameter_wavelet(wavelet)
# Makes sure input signal are numpy arrays.
y1 = np.asarray(y1)
y2 = np.asarray(y2)
# Calculates the standard deviation of both input signals.
std1 = y1.std()
std2 = y2.std()
# Normalizes both signals, if appropriate.
if normalize:
y1_normal = (y1 - y1.mean()) / std1
y2_normal = (y2 - y2.mean()) / std2
else:
y1_normal = y1
y2_normal = y2
# Calculates the CWT of the time-series making sure the same parameters
# are used in both calculations.
_kwargs = dict(dj=dj, s0=s0, J=J, wavelet=wavelet)
W1, sj, freq, coi, _, _ = cwt(y1_normal, dt, **_kwargs)
W2, sj, freq, coi, _, _ = cwt(y2_normal, dt, **_kwargs)
# Now the wavelet transform coherence
# W12ini = W1 * W2.conj()
# scales = np.ones([1, y1.size]) * sj[:, None]
# # -- Normalization by Scale and Smoothing
# W12 = wavelet.smooth(W12ini / scales, dt, dj, sj)
# Calculates the cross CWT of y1 and y2.
W12 = W1 * W2.conj()
# And the significance tests. Note that the confidence level is calculated
# using the percent point function (PPF) of the chi-squared cumulative
# distribution function (CDF) instead of using Z1(95%) = 2.182 and
# Z2(95%)=3.999 as suggested by Torrence & Compo (1998) and Grinsted et
# al. (2004). If the CWT has been normalized, then std1 and std2 should
# be reset to unity, otherwise the standard deviation of both series have
# to be calculated.
if normalize:
std1 = std2 = 1.
a1, _, _ = ar1(y1)
a2, _, _ = ar1(y2)
Pk1 = ar1_spectrum(freq * dt, a1)
Pk2 = ar1_spectrum(freq * dt, a2)
dof = wavelet.dofmin
if not no_default_signif:
PPF = chi2.ppf(significance_level, dof)
signif = (std1 * std2 * (Pk1 * Pk2) ** 0.5 * PPF / dof)
else:
signif = np.asarray([0])
    # The results:
return W12, coi, freq, signif
def wct(y1, y2, dt, dj=1/12, s0=-1, J=-1, sig=False,
significance_level=0.95, wavelet='morlet', normalize=False, **kwargs):
"""Wavelet coherence transform (WCT).
The WCT finds regions in time frequency space where the two time
series co-vary, but do not necessarily have high power.
Parameters
----------
y1, y2 : numpy.ndarray, list
Input signals.
dt : float
Sample spacing.
dj : float, optional
Spacing between discrete scales. Default value is 1/12.
Smaller values will result in better scale resolution, but
slower calculation and plot.
s0 : float, optional
Smallest scale of the wavelet. Default value is 2*dt.
    J : float, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default is J = (log2(N * dt / s0)) / dj.
    sig : bool, optional
        If True, computes the significance. Default is False.
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    normalize : bool, optional
        If set to true, normalizes CWT by the standard deviation of
        the signals.
Returns
-------
    WCT : array like
        Magnitude of the wavelet coherence.
    aWCT : array like
        Phase angle of the wavelet coherence.
    coi : array like
        Cone of influence, which is a vector of N points containing
        the maximum Fourier period of useful information at that
        particular time. Periods greater than those are subject to
        edge effects.
    freq : array like
        Vector of Fourier equivalent frequencies (in 1 / time units)
        that correspond to the wavelet scales.
    sig : array like
        Significance levels as a function of scale, if sig=True when
        called; otherwise zero.
See also
--------
cwt, xwt
"""
wavelet = _check_parameter_wavelet(wavelet)
nt = len(y1)
# Checking some input parameters
if s0 == -1:
# s0 = 2 * dt / wavelet.flambda()
s0 = 2 * dt
if J == -1:
# Number of scales
# J = int(np.round(np.log2(y1.size * dt / s0) / dj))
J = int((np.log(float(nt) * dt / s0) / np.log(2)) / dj)
# Makes sure input signals are numpy arrays.
y1 = np.asarray(y1)
y2 = np.asarray(y2)
# Calculates the standard deviation of both input signals.
std1 = y1.std()
std2 = y2.std()
# Normalizes both signals, if appropriate.
if normalize:
y1_normal = (y1 - y1.mean()) / std1
y2_normal = (y2 - y2.mean()) / std2
else:
y1_normal = y1
y2_normal = y2
# Calculates the CWT of the time-series making sure the same parameters
# are used in both calculations.
_kwargs = dict(dj=dj, s0=s0, J=J, wavelet=wavelet)
W1, sj, freq, coi, _, _ = cwt(y1_normal, dt, **_kwargs)
W2, sj, freq, coi, _, _ = cwt(y2_normal, dt, **_kwargs)
scales1 = np.ones([1, y1.size]) * sj[:, None]
scales2 = np.ones([1, y2.size]) * sj[:, None]
# Smooth the wavelet spectra before truncating -- Time Smoothing
S1 = wavelet.smooth(np.abs(W1) ** 2 / scales1, dt, dj, sj)
S2 = wavelet.smooth(np.abs(W2) ** 2 / scales2, dt, dj, sj)
# Now the wavelet transform coherence
W12 = W1 * W2.conj()
scales = np.ones([1, y1.size]) * sj[:, None]
# -- Normalization by Scale and Scale Smoothing
S12 = wavelet.smooth(W12 / scales, dt, dj, sj)
WCT = np.abs(S12) ** 2 / (S1 * S2)
aWCT = np.angle(W12)
# Calculates the significance using Monte Carlo simulations with 95%
# confidence as a function of scale.
if sig:
a1, b1, c1 = ar1(y1)
a2, b2, c2 = ar1(y2)
sig = wct_significance(a1, a2, dt=dt, dj=dj, s0=s0, J=J,
significance_level=significance_level,
wavelet=wavelet, **kwargs)
else:
sig = np.asarray([0])
return WCT, aWCT, coi, freq, sig
def wct_significance(al1, al2, dt, dj, s0, J, significance_level=0.95,
wavelet='morlet', mc_count=50, progress=True,
cache=True):
"""Wavelet coherence transform significance.
Calculates WCT significance using Monte Carlo simulations with
95% confidence.
Parameters
----------
    al1, al2 : float
        Lag-1 autoregressive coefficients of both time series.
dt : float
Sample spacing.
dj : float, optional
Spacing between discrete scales. Default value is 1/12.
Smaller values will result in better scale resolution, but
slower calculation and plot.
s0 : float, optional
Smallest scale of the wavelet. Default value is 2*dt.
    J : float, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default is J = (log2(N * dt / s0)) / dj.
significance_level : float, optional
Significance level to use. Default is 0.95.
wavelet : instance of a wavelet class, optional
Mother wavelet class. Default is Morlet wavelet.
    mc_count : integer, optional
        Number of Monte Carlo simulations. Default is 50.
progress : bool, optional
If `True` (default), shows progress bar on screen.
cache : bool, optional
If `True` (default) saves cache to file.
    Returns
    -------
    sig95 : numpy.ndarray
        Significance levels of the wavelet coherence as a function
        of scale.
"""
if cache:
# Load cache if previously calculated. It is assumed that wavelet
# analysis is performed using the wavelet's default parameters.
aa = np.round(np.arctanh(np.array([al1, al2]) * 4))
aa = np.abs(aa) + 0.5 * (aa < 0)
cache_file = 'wct_sig_{:0.5f}_{:0.5f}_{:0.5f}_{:0.5f}_{:d}_{}'\
.format(aa[0], aa[1], dj, s0 / dt, J, wavelet.name)
cache_dir = get_cache_dir()
try:
dat = np.loadtxt('{}/{}.gz'.format(cache_dir, cache_file),
unpack=True)
print('NOTE: WCT significance loaded from cache.\n')
return dat
except IOError:
pass
# Some output to the screen
print('Calculating wavelet coherence significance')
# Choose N so that largest scale has at least some part outside the COI
ms = s0 * (2 ** (J * dj)) / dt
N = int(np.ceil(ms * 6))
noise1 = rednoise(N, al1, 1)
nW1, sj, freq, coi, _, _ = cwt(noise1, dt=dt, dj=dj, s0=s0, J=J,
wavelet=wavelet)
period = np.ones([1, N]) / freq[:, None]
coi = np.ones([J + 1, 1]) * coi[None, :]
outsidecoi = (period <= coi)
scales = np.ones([1, N]) * sj[:, None]
sig95 = np.zeros(J + 1)
    maxscale = find(outsidecoi.any(axis=1))[-1]
    # Scales that lie entirely inside the cone of influence cannot be assessed
    sig95[~outsidecoi.any(axis=1)] = np.nan
nbins = 1000
wlc = np.ma.zeros([J + 1, nbins])
# Displays progress bar with tqdm
for _ in tqdm(range(mc_count), disable=not progress):
# Generates two red-noise signals with lag-1 autoregressive
# coefficients given by al1 and al2
noise1 = rednoise(N, al1, 1)
noise2 = rednoise(N, al2, 1)
# Calculate the cross wavelet transform of both red-noise signals
kwargs = dict(dt=dt, dj=dj, s0=s0, J=J, wavelet=wavelet)
nW1, sj, freq, coi, _, _ = cwt(noise1, **kwargs)
nW2, sj, freq, coi, _, _ = cwt(noise2, **kwargs)
nW12 = nW1 * nW2.conj()
        # Smooth the wavelet transforms and calculate the wavelet coherence
# between both signals.
S1 = wavelet.smooth(np.abs(nW1) ** 2 / scales, dt, dj, sj)
S2 = wavelet.smooth(np.abs(nW2) ** 2 / scales, dt, dj, sj)
S12 = wavelet.smooth(nW12 / scales, dt, dj, sj)
R2 = np.ma.array(np.abs(S12) ** 2 / (S1 * S2), mask=~outsidecoi)
# Walks through each scale outside the cone of influence and builds a
# coherence coefficient counter.
for s in range(maxscale):
cd = np.floor(R2[s, :] * nbins)
for j, t in enumerate(cd[~cd.mask]):
wlc[s, int(t)] += 1
# After many, many, many Monte Carlo simulations, determine the
# significance using the coherence coefficient counter percentile.
wlc.mask = (wlc.data == 0.)
R2y = (np.arange(nbins) + 0.5) / nbins
for s in range(maxscale):
sel = ~wlc[s, :].mask
P = wlc[s, sel].data.cumsum()
P = (P - 0.5) / P[-1]
sig95[s] = np.interp(significance_level, P, R2y[sel])
if cache:
        # Save the results to cache, to avoid too many computations in the future
np.savetxt('{}/{}.gz'.format(cache_dir, cache_file), sig95)
# And returns the results
return sig95
def _check_parameter_wavelet(wavelet):
    mothers = {'morlet': Morlet, 'paul': Paul, 'dog': DOG,
               'mexicanhat': MexicanHat}
    # If the input parameter is a string, return an instance of the
    # corresponding mother wavelet class.
    if isinstance(wavelet, str):
        return mothers[wavelet]()
    # Otherwise, return the input itself.
    return wavelet
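For orientation, the sketch below exercises the cwt and significance functions above on a synthetic signal. The import path is an assumption for illustration; in normal use these routines are reached through the main WaLSAtools interface.

import numpy as np
from WaLSA_wavelet import cwt, significance  # assumed import path for this sketch

dt = 10.0                # cadence in seconds
t = np.arange(400) * dt
signal = np.sin(2 * np.pi * 0.005 * t) + 0.5 * np.random.randn(t.size)  # 5 mHz line plus noise

# Continuous wavelet transform with the default Morlet mother wavelet
W, scales, freqs, coi, sfft, fftfreqs = cwt(signal, dt)
power = np.abs(W) ** 2

# 95% significance levels against a red-noise background
signif, fft_theor = significance(signal, dt, scales, sigma_test=0, significance_level=0.95)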
WaLSA_k_omega.py
This module provides functions for performing k-ω analysis and filtering in spatio-temporal datasets.
WaLSA_k_omega.py
# --------------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# --------------------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by Rob Rutten, David B. Jess, and Samuel D. T. Grant
# --------------------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from scipy.signal import convolve # type: ignore
# --------------------------------------------------------------------------------------------
def gaussian_function(sigma, width=None):
"""
Create a Gaussian kernel that closely matches the IDL implementation.
Parameters:
sigma (float or list of floats): Standard deviation(s) of the Gaussian.
width (int or list of ints, optional): Width of the Gaussian kernel.
Returns:
np.ndarray: The Gaussian kernel.
"""
# sigma = np.array(sigma, dtype=np.float64)
sigma = np.atleast_1d(np.array(sigma, dtype=np.float64)) # Ensure sigma is at least 1D
if np.any(sigma <= 0):
raise ValueError("Sigma must be greater than 0.")
# width = np.array(width)
width = np.atleast_1d(np.array(width)) # Ensure width is always at least 1D
width = np.maximum(width.astype(np.int64), 1) # Convert to integers and ensure > 0
if np.any(width <= 0):
raise ValueError("Width must be greater than 0.")
nSigma = np.size(sigma)
if nSigma > 8:
raise ValueError('Sigma can have no more than 8 elements')
nWidth = np.size(width)
if nWidth > nSigma:
raise ValueError('Incorrect width specification')
if (nWidth == 1) and (nSigma > 1):
width = np.full(nSigma, width[0])
# kernel = np.zeros(tuple(width.astype(int)), dtype=np.float64)
kernel = np.zeros(tuple(width[:nSigma].astype(int)), dtype=np.float64) # Match kernel size to nSigma
temp = np.zeros(8, dtype=np.int64)
temp[:nSigma] = width[:nSigma]
width = temp
if nSigma == 2:
b = a = 0
if nSigma >= 1:
a = np.linspace(0, width[0] - 1, width[0]) - width[0] // 2 + (0 if width[0] % 2 else 0.5)
if nSigma >= 2:
b = np.linspace(0, width[1] - 1, width[1]) - width[1] // 2 + (0 if width[1] % 2 else 0.5)
a1 = b1 = 0
        # Nested loop for kernel computation (currently implemented for nSigma up to 2)
for bb in range(width[1]):
b1 = (b[bb] ** 2) / (2 * sigma[1] ** 2) if nSigma >= 2 else 0
for aa in range(width[0]):
a1 = (a[aa] ** 2) / (2 * sigma[0] ** 2) if nSigma >= 1 else 0
kernel[aa, bb] = np.exp(
-np.clip((a1 + b1), -1e3, 1e3)
)
elif nSigma == 1:
a = 0
if nSigma >= 1:
a = np.linspace(0, width[0] - 1, width[0]) - width[0] // 2 + (0 if width[0] % 2 else 0.5)
a1 = 0
# Nested loop for kernel computation
for aa in range(width[0]):
a1 = (a[aa] ** 2) / (2 * sigma[0] ** 2) if nSigma >= 1 else 0
kernel[aa] = np.exp(
-np.clip(a1, -1e3, 1e3)
)
kernel = np.nan_to_num(kernel, nan=0.0, posinf=0.0, neginf=0.0)
if np.sum(kernel) == 0:
raise ValueError("Generated Gaussian kernel is invalid (all zeros).")
return kernel
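# For illustration: gaussian_function(sigma=[0.65, 0.65], width=3) returns a
# 3x3 kernel with entries exp(-(a**2 + b**2) / (2 * 0.65**2)) for offsets
# a, b in {-1, 0, 1}; this is the kernel used further below to smooth the
# k-omega power map.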
def walsa_radial_distances(shape):
"""
Compute the radial distance array for a given shape.
Parameters:
shape (tuple): Shape of the array, typically (ny, nx).
Returns:
numpy.ndarray: Array of radial distances.
"""
if not (isinstance(shape, tuple) and len(shape) == 2):
raise ValueError("Shape must be a tuple with two elements, e.g., (ny, nx).")
y, x = np.indices(shape)
cy, cx = (np.array(shape) - 1) / 2
return np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
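# For example, walsa_radial_distances((3, 3)) is 0 at the central pixel and
# sqrt(2) at the four corners.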
def avgstd(array):
"""
Calculate the average and standard deviation of an array.
"""
avrg = np.sum(array) / array.size
stdev = np.sqrt(np.sum((array - avrg) ** 2) / array.size)
return avrg, stdev
def linear(x, *p):
"""
Compute a linear model y = p[0] + x * p[1].
    (used in temporal detrending).
"""
if len(p) < 2:
raise ValueError("Parameters p[0] and p[1] must be provided.")
ymod = p[0] + x * p[1]
return ymod
def gradient(xy, *p):
"""
Gradient function for fitting spatial trends.
Parameters:
xy: Tuple of grid coordinates (x, y).
p: Coefficients [offset, slope_x, slope_y].
"""
x, y = xy # Unpack the tuple
if len(p) < 3:
raise ValueError("Parameters p[0], p[1], and p[2] must be provided.")
return p[0] + x * p[1] + y * p[2]
def apod3dcube(cube, apod):
"""
Apodizes a 3D cube in all three coordinates, with detrending.
Parameters:
        cube : Input 3D data cube with dimensions (nt, nx, ny).
apod (float): Apodization factor (0 means no apodization).
Returns:
Apodized 3D cube.
"""
# Get cube dimensions
nt, nx, ny = cube.shape
apocube = np.zeros_like(cube, dtype=np.float32)
# Define temporal apodization
apodt = np.ones(nt, dtype=np.float32)
if apod != 0:
apodrimt = nt * apod
apodrimt = int(apodrimt) # Ensure apodrimt is an integer
apodt[:apodrimt] = (np.sin(np.pi / 2. * np.arange(apodrimt) / apodrimt)) ** 2
        apodt = apodt * np.roll(np.flip(apodt), 1)  # apply symmetrical apodization
# Temporal detrending (mean-image trend, not per pixel)
ttrend = np.zeros(nt, dtype=np.float32)
tf = np.arange(nt) + 1.0
for it in range(nt):
img = cube[it, :, :]
# ttrend[it], _ = avgstd(img)
ttrend[it] = np.mean(img)
# Fit the trend with a linear model
fitp, _ = curve_fit(linear, tf, ttrend, p0=[1000.0, 0.0])
# fit = fitp[0] + tf * fitp[1]
fit = linear(tf, *fitp)
# Temporal apodization per (x, y) column
for it in range(nt):
img = cube[it, :, :]
apocube[it, :, :] = (img - fit[it]) * apodt[it]
# Define spatial apodization
apodx = np.ones(nx, dtype=np.float32)
apody = np.ones(ny, dtype=np.float32)
if apod != 0:
apodrimx = apod * nx
apodrimy = apod * ny
apodrimx = int(apodrimx) # Ensure apodrimx is an integer
apodrimy = int(apodrimy) # Ensure apodrimy is an integer
apodx[:apodrimx] = (np.sin(np.pi / 2. * np.arange(int(apodrimx)) / apodrimx)) ** 2
apody[:apodrimy] = (np.sin(np.pi / 2. * np.arange(int(apodrimy)) / apodrimy)) ** 2
apodx = apodx * np.roll(np.flip(apodx), 1)
apody = apody * np.roll(np.flip(apody), 1)
apodxy = np.outer(apodx, apody)
else:
apodxy = np.outer(apodx, apody)
# Spatial gradient removal and apodizing per image
    # Pixel-coordinate grids used to fit (and remove) a linear spatial gradient
    xf, yf = np.meshgrid(np.arange(nx, dtype=np.float32), np.arange(ny, dtype=np.float32), indexing='ij')
for it in range(nt):
img = apocube[it, :, :]
# avg, _ = avgstd(img)
avg = np.mean(img)
# Ensure xf, yf, and img are properly defined
assert xf.shape == yf.shape == img.shape, "xf, yf, and img must have matching shapes."
# Flatten the inputs for curve_fit
x_flat = xf.ravel()
y_flat = yf.ravel()
img_flat = img.ravel()
# Ensure lengths match
assert len(x_flat) == len(y_flat) == len(img_flat), "Flattened inputs must have the same length."
# Call curve_fit with initial parameters
fitp, _ = curve_fit(gradient, (x_flat, y_flat), img_flat, p0=[1000.0, 0.0, 0.0])
# Apply the fitted parameters
fit = gradient((xf, yf), *fitp)
apocube[it, :, :] = (img - fit) * apodxy + avg
return apocube
def ko_dist(sx, sy, double=False):
"""
Set up Pythagorean distance array from the origin.
"""
# Create distance grids for x and y
# (computing dx and dy using floating-point division)
dx = np.tile(np.arange(sx / 2 + 1) / (sx / 2), (int(sy / 2 + 1), 1)).T
dy = np.flip(np.tile(np.arange(sy / 2 + 1) / (sy / 2), (int(sx / 2 + 1), 1)), axis=0)
# Compute dxy
dxy = np.sqrt(dx**2 + dy**2) * (min(sx, sy) / 2 + 1)
# Initialize afstanden
afstanden = np.zeros((sx, sy), dtype=np.float64)
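    # ('afstanden' is Dutch for 'distances', a naming inherited from Rob
    # Rutten's original IDL routines)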
# Assign dxy to the first quadrant (upper-left)
afstanden[:sx // 2 + 1, :sy // 2 + 1] = dxy
# Second quadrant (upper-right) - 90° clockwise
afstanden[sx // 2:, :sy // 2 + 1] = np.flip(np.roll(dxy[:-1, :], shift=-1, axis=1), axis=1)
# Third quadrant (lower-left) - 270° clockwise
afstanden[:sx // 2 + 1, sy // 2:] = np.flip(np.roll(dxy[:, :-1], shift=-1, axis=0), axis=0)
# Fourth quadrant (lower-right) - 180° rotation
afstanden[sx // 2:, sy // 2:] = np.flip(dxy[:-1, :-1], axis=(0, 1))
# Convert to integers if 'double' is False
if not double:
afstanden = np.round(afstanden).astype(int)
return afstanden
def averpower(cube):
"""
Compute 2D (k_h, f) power array by circular averaging over k_x, k_y.
Parameters:
        cube : Input 3D data cube of dimensions (nt, nx, ny).
Returns:
avpow : 2D array of average power over distances, dimensions (maxdist+1, nt/2+1).
"""
# Get cube dimensions
nt, nx, ny = cube.shape
# Perform FFT in all three dimensions (first in time direction)
fftcube = np.fft.fft(np.fft.fft(np.fft.fft(cube, axis=0)[:nt//2+1, :, :], axis=1), axis=2)
# Set up distances
afstanden = ko_dist(nx, ny) # Integer-rounded Pythagoras array
maxdist = min(nx, ny) // 2 + 1 # Largest quarter circle
# Initialize average power array
avpow = np.zeros((maxdist + 1, nt // 2 + 1), dtype=np.float64)
# Compute average power over all k_h distances, building power(k_h, f)
for i in range(maxdist + 1):
where_indices = np.where(afstanden == i)
for j in range(nt // 2 + 1):
w1 = fftcube[j, :, :][where_indices]
            avpow[i, j] = np.sum(np.abs(w1) ** 2) / len(where_indices[0])
return avpow
def walsa_kopower_funct(datacube, **kwargs):
"""
Calculate k-omega diagram
(Fourier power at temporal frequency f against horizontal spatial wavenumber k_h)
    Originally written in IDL by Rob Rutten (RR) as an assembly of Alfred de Wijn's routines (2010)
    - Translated into Python by Shahin Jafarzadeh (2024)
Parameters:
datacube: Input data cube [t, x, y].
arcsecpx (float): Spatial sampling in arcsec/pixel.
cadence (float): Temporal sampling in seconds.
apod: fractional extent of apodization edges; default 0.1
kmax: maximum k_h axis as fraction of Nyquist value, default 0.2
fmax: maximum f axis as fraction of Nyquist value, default 0.5
minpower: minimum of power plot range, default maxpower-5
maxpower: maximum of power plot range, default alog10(max(power))
Returns:
k-omega power map.
"""
# Set default parameter values
defaults = {
'apod': 0.1,
'kmax': 1.0,
'fmax': 1.0,
'minpower': None,
'maxpower': None
}
# Update defaults with user-provided values
params = {**defaults, **kwargs}
# Apodize the cube
apocube = apod3dcube(datacube, params['apod'])
# Compute radially-averaged power
avpow = averpower(apocube)
return avpow
# -------------------------------------- Main Function ----------------------------------------
def WaLSA_k_omega(signal, time=None, **kwargs):
"""
NAME: WaLSA_k_omega
part of -- WaLSAtools --
* Main function to calculate and plot k-omega diagram.
ORIGINAL CODE: QUEEns Fourier Filtering (QUEEFF) code
WRITTEN, ANNOTATED, TESTED AND UPDATED in IDL BY:
(1) Dr. David B. Jess
(2) Dr. Samuel D. T. Grant
The original code along with its manual can be downloaded at: https://bit.ly/37mx9ic
WaLSA_k_omega (in IDL): A lightly modified version of the original code (i.e., a few additional keywords added) by Dr. Shahin Jafarzadeh
    - Translated into Python by Shahin Jafarzadeh (2024)
Parameters:
signal (array): Input datacube, normally in the form of [x, y, t] or [t, x, y]. Note that the input datacube must have identical x and y dimensions. If not, the datacube will be cropped accordingly.
time (array): Time array corresponding to the input datacube.
pixelsize (float): Spatial sampling of the input datacube. If not given, it is plotted in units of 'pixel'.
filtering (bool): If True, filtering is applied, and the filtered datacube (filtered_cube) is returned. Otherwise, None is returned. Default: False.
f1 (float): Optional lower (temporal) frequency to filter, in Hz.
f2 (float): Optional upper (temporal) frequency to filter, in Hz.
k1 (float): Optional lower (spatial) wavenumber to filter, in units of pixelsize^-1 (k = (2 * π) / wavelength).
k2 (float): Optional upper (spatial) wavenumber to filter, in units of pixelsize^-1.
spatial_torus (bool): If True, makes the annulus used for spatial filtering have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
temporal_torus (bool): If True, makes the temporal filter have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
no_spatial_filt (bool): If True, ensures no spatial filtering is performed on the dataset (i.e., only temporal filtering is applied).
no_temporal_filt (bool): If True, ensures no temporal filtering is performed on the dataset (i.e., only spatial filtering is applied).
silent (bool): If True, suppresses the k-ω diagram plot.
smooth (bool): If True, power is smoothed. Default: True.
mode (int): Output power mode: 0 = log10(power) (default), 1 = linear power, 2 = sqrt(power) = amplitude.
processing_maps (bool): If True, the function returns the processing maps (spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies). Otherwise, they are all returned as None. Default: False.
OUTPUTS:
power : 2D array of power (see mode for the scale).
frequencies : 1D array of frequencies (in mHz).
wavenumber : 1D array of wavenumber (in pixelsize^-1).
filtered_cube : 3D array of filtered datacube (if filtering is set).
processing_maps (if set to True)
IF YOU USE THIS CODE, THEN PLEASE CITE THE ORIGINAL PUBLICATION WHERE IT WAS USED:
Jess et al. 2017, ApJ, 842, 59 (http://adsabs.harvard.edu/abs/2017ApJ...842...59J)
"""
# Set default parameter values
defaults = {
'pixelsize': 1,
'format': 'txy',
'filtering': False,
'f1': None,
'f2': None,
'k1': None,
'k2': None,
'spatial_torus': True,
'temporal_torus': True,
'no_spatial_filt': False,
'no_temporal_filt': False,
'silent': False,
'xlog': False,
'ylog': False,
'xrange': None,
'yrange': None,
'nox2': False,
'noy2': False,
'smooth': True,
'mode': 0,
'xtitle': 'Wavenumber',
'xtitle_units': '(pixel⁻¹)',
'ytitle': 'Frequency',
        'ytitle_units': '(Hz)',
'x2ndaxistitle': 'Spatial size',
'y2ndaxistitle': 'Period',
'x2ndaxistitle_units': '(pixel)',
'y2ndaxistitle_units': '(s)',
'processing_maps': False
}
# Update defaults with user-provided values
params = {**defaults, **kwargs}
tdiff = np.diff(time)
cadence = np.median(tdiff)
pixelsize = params['pixelsize']
filtered_cube = None
spatial_fft_map = None
torus_map = None
spatial_fft_filtered_map = None
    temporal_fft = None
    temporal_filter = None
    temporal_frequencies = None
    spatial_frequencies = None
# Check and adjust format if necessary
    if params['format'] == 'xyt':
        signal = np.transpose(signal, (2, 0, 1))  # Convert 'xyt' to 'txy'
    elif params['format'] != 'txy':
        raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")
if not params['silent']:
print(f"Processing k-ω analysis for a 3D cube with format '{params['format']}' and shape {signal.shape}.")
# Input dimensions
nt, nx, ny = signal.shape
    if nx != ny:
        min_dim = min(nx, ny)
        signal = signal[:, :min_dim, :min_dim]  # Crop x and y to equal sizes
        nt, nx, ny = signal.shape
# Calculating the Nyquist frequencies
spatial_nyquist = (2 * np.pi) / (pixelsize * 2)
temporal_nyquist = 1 / (cadence * 2)
print("")
print(f"Input datacube size (t,x,y): {signal.shape}")
print("")
print("Spatially, the important values are:")
print(f" 2-pixel size = {pixelsize * 2:.2f} {params['x2ndaxistitle_units']}")
print(f" Nyquist wavenumber = {spatial_nyquist:.2f} {params['xtitle_units']}")
if params['no_spatial_filt']:
print("*** NO SPATIAL FILTERING WILL BE PERFORMED ***")
print("")
print("Temporally, the important values are:")
print(f" 2-element duration (Nyquist period) = {cadence * 2:.2f} {params['y2ndaxistitle_units']}")
print(f" Time series duration = {cadence * signal.shape[2]:.2f} {params['y2ndaxistitle_units']}")
temporal_nyquist = 1 / (cadence * 2)
print(f" Nyquist frequency = {temporal_nyquist:.2f} {params['yttitle_units']}")
if params['no_temporal_filt']:
print("***NO TEMPORAL FILTERING WILL BE PERFORMED***")
print("")
# Generate k-omega power map
print("Constructing a k-ω diagram of the input datacube..........")
print("")
# Make the k-omega diagram using the proven method of Rob Rutten
kopower = walsa_kopower_funct(signal)
# Scales
xsize_kopower = kopower.shape[0]
dxsize_kopower = spatial_nyquist / float(xsize_kopower - 1)
kopower_xscale = np.arange(xsize_kopower) * dxsize_kopower # in pixel⁻¹
ysize_kopower = kopower.shape[1]
dysize_kopower = temporal_nyquist / float(ysize_kopower - 1)
kopower_yscale = (np.arange(ysize_kopower) * dysize_kopower) # in Hz
# Generate Gaussian Kernel
Gaussian_kernel = gaussian_function(sigma=[0.65, 0.65], width=3)
Gaussian_kernel_norm = np.nansum(Gaussian_kernel) # Normalize kernel sum
# Copy kopower to kopower_plot
kopower_plot = kopower.copy()
# Convolve kopower (ignoring zero-th element, starting from index 1)
kopower_plot[:, 1:] = convolve(
kopower[:, 1:],
Gaussian_kernel,
mode='same'
) / Gaussian_kernel_norm
# Normalize to frequency resolution (in mHz)
freq = kopower_yscale[1:]
if freq[0] == 0:
freq0 = freq[1]
else:
freq0 = freq[0]
kopower_plot /= freq0
# Apply logarithmic or square root transformation based on mode
if params['mode'] == 0: # Logarithmic scaling
kopower_plot = np.log10(kopower_plot)
elif params['mode'] == 2: # Square root scaling
kopower_plot = np.sqrt(kopower_plot)
# Normalize the power - preferred for plotting
komegamap = np.clip(
kopower_plot[1:, 1:],
np.nanmin(kopower_plot[1:, 1:]),
np.nanmax(kopower_plot[1:, 1:])
)
kopower_zscale = kopower_plot[1:,1:]
# Rotate kopower counterclockwise by 90 degrees
komegamap = np.rot90(komegamap, k=1)
# Flip vertically (y-axis)
komegamap = np.flip(komegamap, axis=0)
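    # Net effect of the rotation plus flip: a transpose, so rows of komegamap
    # correspond to frequency and columns to wavenumber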
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Filtering implementation
if params['filtering']:
# Extract parameters from the dictionary
k1 = params['k1']
k2 = params['k2']
f1 = params['f1']
f2 = params['f2']
if params['no_spatial_filt']:
k1 = kopower_xscale[1] # Default lower wavenumber
k2 = np.nanmax(kopower_xscale) # Default upper wavenumber
# Ensure k1 and k2 are within valid bounds
if k1 is None or k1 <= 0.0:
k1 = kopower_xscale[1]
if k2 is None or k2 > np.nanmax(kopower_xscale):
k2 = np.nanmax(kopower_xscale)
if params['no_temporal_filt']:
f1 = kopower_yscale[1]
f2 = np.nanmax(kopower_yscale)
# Ensure f1 and f2 are within valid bounds
if f1 is None or f1 <= 0.0:
f1 = kopower_yscale[1]
if f2 is None or f2 > np.nanmax(kopower_yscale):
f2 = np.nanmax(kopower_yscale)
print("Start filtering (in k-ω space) ......")
print("")
print(f"The preserved wavenumbers are [{k1:.3f}, {k2:.3f}] {params['xtitle_units']}")
print(f"The preserved spatial sizes are [{(2 * np.pi) / k2:.3f}, {(2 * np.pi) / k1:.3f}] {params['x2ndaxistitle_units']}")
print("")
print(f"The preserved frequencies are [{f1:.3f}, {f2:.3f}] {params['yttitle_units']}")
print(f"The preserved periods are [{int(1 / (f2))}, {int(1 / (f1))}] {params['y2ndaxistitle_units']}")
print("")
# Perform the 3D Fourier transform
print("Making a 3D Fourier transform of the input datacube ..........")
threedft = np.fft.fftshift(np.fft.fftn(signal))
# Calculate the frequency axes for the 3D FFT
temp_x = np.arange((nx - 1) // 2) + 1
is_N_even = (nx % 2 == 0)
if is_N_even:
spatial_frequencies_orig = (
np.concatenate([[0.0], temp_x, [nx / 2], -nx / 2 + temp_x]) / (nx * pixelsize)) * (2.0 * np.pi)
        else:
            spatial_frequencies_orig = (
                np.concatenate([[0.0], temp_x, temp_x - (nx // 2 + 1)]) / (nx * pixelsize)) * (2.0 * np.pi)
temp_t = np.arange((nt - 1) // 2) + 1 # Use integer division for clarity
is_N_even = (nt % 2 == 0)
if is_N_even:
temporal_frequencies_orig = (
np.concatenate([[0.0], temp_t, [nt / 2], -nt / 2 + temp_t])) / (nt * cadence)
        else:
            temporal_frequencies_orig = (
                np.concatenate([[0.0], temp_t, temp_t - (nt // 2 + 1)])) / (nt * cadence)
# Compensate frequency axes for the use of FFT shift
indices = np.where(spatial_frequencies_orig >= 0)[0]
spatial_positive_frequencies = len(indices)
if len(spatial_frequencies_orig) % 2 == 0:
spatial_frequencies = np.roll(spatial_frequencies_orig, spatial_positive_frequencies - 2)
else:
spatial_frequencies = np.roll(spatial_frequencies_orig, spatial_positive_frequencies - 1)
tindices = np.where(temporal_frequencies_orig >= 0)[0]
temporal_positive_frequencies = len(tindices)
if len(temporal_frequencies_orig) % 2 == 0:
temporal_frequencies = np.roll(temporal_frequencies_orig, temporal_positive_frequencies - 2)
else:
temporal_frequencies = np.roll(temporal_frequencies_orig, temporal_positive_frequencies - 1)
# Ensure the threedft aligns with the new frequency axes
if len(temporal_frequencies_orig) % 2 == 0:
for x in range(nx):
for y in range(ny):
threedft[:, x, y] = np.roll(threedft[:, x, y], -1)
if len(spatial_frequencies_orig) % 2 == 0:
for z in range(nt):
threedft[z, :, :] = np.roll(threedft[z, :, :], shift=(-1, -1), axis=(0, 1))
# Convert frequencies and wavenumbers of interest into FFT datacube pixels
pixel_k1_positive = np.argmin(np.abs(spatial_frequencies_orig - k1))
pixel_k2_positive = np.argmin(np.abs(spatial_frequencies_orig - k2))
pixel_f1_positive = np.argmin(np.abs(temporal_frequencies - f1))
pixel_f2_positive = np.argmin(np.abs(temporal_frequencies - f2))
pixel_f1_negative = np.argmin(np.abs(temporal_frequencies + f1))
pixel_f2_negative = np.argmin(np.abs(temporal_frequencies + f2))
torus_depth = int((pixel_k2_positive - pixel_k1_positive) / 2) * 2
torus_center = int(((pixel_k2_positive - pixel_k1_positive) / 2) + pixel_k1_positive)
if params['spatial_torus'] and not params['no_spatial_filt']:
# Create a filter ring preserving equal wavenumbers for both kx and ky
# This forms a torus to preserve an integrated Gaussian shape across the width of the annulus
spatial_torus = np.zeros((torus_depth, nx, ny))
for i in range(torus_depth // 2 + 1):
spatial_ring = np.logical_xor(
(walsa_radial_distances((nx, ny)) <= (torus_center - i)),
(walsa_radial_distances((nx, ny)) <= (torus_center + i + 1))
)
spatial_ring = spatial_ring.astype(int) # Convert True -> 1 and False -> 0
spatial_ring[spatial_ring > 0] = 1.
spatial_ring[spatial_ring != 1] = 0.
spatial_torus[i, :, :] = spatial_ring
spatial_torus[torus_depth - i - 1, :, :] = spatial_ring
# Integrate through the torus to find the spatial filter
spatial_ring_filter = np.nansum(spatial_torus, axis=0) / float(torus_depth)
spatial_ring_filter = spatial_ring_filter / np.nanmax(spatial_ring_filter) # Ensure the peaks are at 1.0
if not params['spatial_torus'] and not params['no_spatial_filt']:
spatial_ring_filter = (
(walsa_radial_distances((nx, ny)) <= (torus_center - int(torus_depth / 2))).astype(int) -
(walsa_radial_distances((nx, ny)) <= (torus_center + int(torus_depth / 2) + 1)).astype(int)
)
spatial_ring_filter = spatial_ring_filter / np.nanmax(spatial_ring_filter) # Ensure the peaks are at 1.0
spatial_ring_filter[spatial_ring_filter != 1] = 0
if params['no_spatial_filt']:
spatial_ring_filter = np.ones((nx, ny))
if not params['no_temporal_filt'] and params['temporal_torus']:
# CREATE A GAUSSIAN TEMPORAL FILTER TO PREVENT ALIASING
temporal_filter = np.zeros(nt, dtype=float)
filter_width = pixel_f2_positive - pixel_f1_positive
            # Determine sigma from the filter width: wider temporal filters use
            # a larger Gaussian sigma (same thresholds as the original if-chain)
            width_bounds = [25, 30, 40, 45, 50, 55, 60, 65, 70, 80, 90, 100, 110, 130]
            sigma_values = [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
            sigma = sigma_values[np.searchsorted(width_bounds, filter_width, side='right')]
# Generate the Gaussian kernel
temporal_gaussian = gaussian_function(sigma=sigma, width=filter_width)
# Apply the Gaussian to the temporal filter
temporal_filter[pixel_f1_positive:pixel_f2_positive] = temporal_gaussian
temporal_filter[pixel_f2_negative:pixel_f1_negative] = temporal_gaussian
# Normalize the filter to ensure the peaks are at 1.0
temporal_filter /= np.nanmax(temporal_filter)
if not params['no_temporal_filt'] and not params['temporal_torus']:
temporal_filter = np.zeros(nt, dtype=float)
temporal_filter[pixel_f1_positive:pixel_f2_positive + 1] = 1.0
temporal_filter[pixel_f2_negative:pixel_f1_negative + 1] = 1.0
if params['no_temporal_filt']:
temporal_filter = np.ones(nt, dtype=float)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Create useful variables for plotting (if needed), demonstrating the filtering process
if params['processing_maps']:
# Define the spatial frequency step for plotting
spatial_dx = spatial_frequencies[1] - spatial_frequencies[0]
spatial_dy = spatial_frequencies[1] - spatial_frequencies[0]
# Torus map
torus_map = {
'data': spatial_ring_filter,
'dx': spatial_dx,
'dy': spatial_dy,
'xc': 0,
'yc': 0,
'time': '',
'units': 'pixels'
}
# Compute the total spatial FFT
spatial_fft = np.nansum(threedft, axis=0)
# Spatial FFT map
spatial_fft_map = {
'data': np.log10(spatial_fft),
'dx': spatial_dx,
'dy': spatial_dy,
'xc': 0,
'yc': 0,
'time': '',
'units': 'pixels'
}
# Spatial FFT filtered and its map
spatial_fft_filtered = spatial_fft * spatial_ring_filter
spatial_fft_filtered_map = {
'data': np.log10(np.maximum(spatial_fft_filtered, 1e-15)),
'dx': spatial_dx,
'dy': spatial_dy,
'xc': 0,
'yc': 0,
'time': '',
'units': 'pixels'
}
# Compute the total temporal FFT
temporal_fft = np.nansum(np.nansum(threedft, axis=2), axis=1)
else:
spatial_fft_map = None
torus_map = None
spatial_fft_filtered_map = None
temporal_fft = None
            # Keep temporal_filter: it is still needed to apply the temporal
            # filtering below.
            temporal_frequencies, spatial_frequencies = None, None
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Apply the Gaussian filters to the data to prevent aliasing
for i in range(nt):
threedft[i, :, :] *= spatial_ring_filter
for x in range(nx):
for y in range(ny):
threedft[:, x, y] *= temporal_filter
# ALSO NEED TO ENSURE THE threedft ALIGNS WITH THE OLD FREQUENCY AXES USED BY THE /center CALL
if len(temporal_frequencies_orig) % 2 == 0:
for x in range(nx):
for y in range(ny):
threedft[:, x, y] = np.roll(threedft[:, x, y], shift=1, axis=0)
        if len(spatial_frequencies_orig) % 2 == 0:
            for t in range(nt):
                threedft[t, :, :] = np.roll(threedft[t, :, :], shift=(1, 1), axis=(0, 1))
# Inverse FFT to get the filtered cube
# filtered_cube = np.real(np.fft.ifftn(threedft, norm='ortho'))
filtered_cube = np.real(np.fft.ifftn(np.fft.ifftshift(threedft)))
        print("Filtered datacube generated.")
    else:
        filtered_cube = None
return komegamap, kopower_xscale[1:], kopower_yscale[1:], filtered_cube, spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies
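For orientation, the short sketch below shows how a k-ω diagram might be computed for a synthetic datacube in 'txy' format. The import path is an assumption for illustration; in normal use the routine is reached through the main WaLSAtools interface.

import numpy as np
from WaLSA_k_omega import WaLSA_k_omega  # assumed import path for this sketch

nt, nx, ny = 200, 64, 64
time = np.arange(nt) * 5.0          # 5 s cadence
cube = np.random.randn(nt, nx, ny)  # synthetic (t, x, y) datacube

# Without filtering, only the k-omega power map and its axes are of interest;
# the filtering-related outputs are returned as None
power, wavenumber, frequencies, *rest = WaLSA_k_omega(cube, time=time, pixelsize=0.1)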
WaLSA_pod.py
This module implements Proper Orthogonal Decomposition (POD), as well as Spectral POD (SPOD), for analysing multi-dimensional data and extracting dominant spatial patterns.
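Before the full listing, here is a minimal usage sketch on a synthetic datacube (the import path is an assumption for illustration; the result keys follow the descriptions given by print_pod_results in the listing below):

import numpy as np
from WaLSA_pod import WaLSA_pod  # assumed import path for this sketch

nt, nx, ny = 300, 32, 32
time = np.arange(nt) * 0.5          # 0.5 s cadence
cube = np.random.randn(nt, nx, ny)  # synthetic (t, x, y) datacube

results = WaLSA_pod(cube, time, num_modes=10, num_modes_reconstruct=5)
modes = results['spatial_mode']           # shape (Nmodes, Ny, Nx)
coeffs = results['temporal_coefficient']  # shape (Nmodes, Nt)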
WaLSA_pod.py
# --------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# --------------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by Jonathan E. Higham & Luiz A. C. A. Schiavo
# --------------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from scipy.signal import welch, find_peaks # type: ignore
from scipy.linalg import svd # type: ignore
from scipy.optimize import curve_fit # type: ignore
import math
def print_pod_results(results):
"""
Print a summary of the results from WaLSA_pod, including parameter descriptions, types, and shapes.
Parameters:
-----------
results : dict
The dictionary containing all POD results and relevant outputs.
"""
descriptions = {
'input_data': 'Original input data, mean subtracted (Shape: (Nt, Ny, Nx))',
'spatial_mode': 'Reshaped spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))',
'temporal_coefficient': 'Temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))',
'eigenvalue': 'Eigenvalues corresponding to singular values squared (Shape: (Nmodes))',
'eigenvalue_contribution': 'Eigenvalue contribution of each mode (Shape: (Nmodes))',
'cumulative_eigenvalues': 'Cumulative percentage of eigenvalues for the first "num_cumulative_modes" modes (Shape: (num_cumulative_modes))',
        'combined_welch_psd': 'Combined Welch power spectral density for the temporal coefficients of the first "num_modes" modes (Shape: (Nf))',
'frequencies': 'Frequencies identified in the Welch spectrum (Shape: (Nf))',
'combined_welch_significance': 'Significance threshold of the combined Welch spectrum (Shape: (Nf,))',
'reconstructed': 'Reconstructed frame at the specified timestep (or for the entire time series) using the top "num_modes" modes (Shape: (Ny, Nx))',
'sorted_frequencies': 'Frequencies identified in the Welch combined power spectrum (Shape: (Nfrequencies))',
'frequency_filtered_modes': 'Frequency-filtered spatial POD modes for the first "num_top_frequencies" frequencies (Shape: (Nt, Ny, Nx, num_top_frequencies))',
'frequency_filtered_modes_frequencies': 'Frequencies corresponding to the frequency-filtered modes (Shape: (num_top_frequencies))',
'SPOD_spatial_modes': 'SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))',
'SPOD_temporal_coefficients': 'SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))',
'p': 'Left singular vectors (spatial modes) from SVD (Shape: (Nx, Nmodes))',
's': 'Singular values from SVD (Shape: (Nmodes))',
'a': 'Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))'
}
print("\n---- POD/SPOD Results Summary ----\n")
for key, value in results.items():
desc = descriptions.get(key, 'No description available')
shape = np.shape(value) if value is not None else 'None'
dtype = type(value).__name__
print(f"{key} ({dtype}, Shape: {shape}): {desc}")
print("\n----------------------------------")
def spod(data, **kwargs):
"""
Perform Spectral Proper Orthogonal Decomposition (SPOD) analysis on input data.
Steps:
1. Load the data.
2. Compute a time average.
3. Build the correlation matrix to compute the POD using the snapshot method.
4. Compute eigenvalue decomposition for the correlation matrix for the fluctuation field.
The eigenvalues and eigenvectors can be computed by any eigenvalue function or by an SVD procedure, since the correlation matrix is symmetric positive-semidefinite.
5. After obtaining the eigenvalues and eigenvectors, compute the temporal and spatial modes according to the snapshot method.
Parameters:
-----------
data : np.ndarray
3D data array with shape (time, x, y) or similar.
**kwargs : dict, optional
Additional keyword arguments to configure the analysis.
Returns:
--------
SPOD_spatial_modes : np.ndarray
SPOD_temporal_coefficients : np.ndarray
"""
# Set default parameter values
defaults = {
'silent': False,
'num_modes': None,
'filter_size': None,
'periodic_matrix': True
}
# Update defaults with user-provided values
params = {**defaults, **kwargs}
if params['num_modes'] is None:
params['num_modes'] = data.shape[0]
if params['filter_size'] is None:
params['filter_size'] = params['num_modes']
if not params['silent']:
print("Starting SPOD analysis ....")
nsnapshots, ny, nx = data.shape
# Center the data by subtracting the mean
time_average = data.mean(axis=0)
fluctuation_field = data * 0 # Initialize fluctuation field with zeros
for n in range(nsnapshots):
fluctuation_field[n,:,:] = data[n,:,:] - time_average # Subtract time-average from each snapshot
# Build the correlation matrix (snapshot method)
correlation_matrix = np.zeros((nsnapshots, nsnapshots))
for i in range(nsnapshots):
for j in range(i, nsnapshots):
correlation_matrix[i, j] = (fluctuation_field[i, :, :] * fluctuation_field[j, :, :]).sum() / (nsnapshots * nx * ny)
correlation_matrix[j, i] = correlation_matrix[i, j]
# SPOD correlation matrix with periodic boundary conditions
nmatrix = np.zeros((3 * nsnapshots, 3 * nsnapshots))
nmatrix[nsnapshots:2 * nsnapshots, nsnapshots:2 * nsnapshots] = correlation_matrix[0:nsnapshots, 0:nsnapshots]
if params['periodic_matrix']:
for i in range(3):
for j in range(3):
xs = i * nsnapshots # x-offset for periodic positioning
ys = j * nsnapshots # y-offset for periodic positioning
nmatrix[0 + xs:nsnapshots + xs, 0 + ys:nsnapshots + ys] = correlation_matrix[0:nsnapshots, 0:nsnapshots]
# Apply Gaussian filter
gk = np.zeros((2 * params['filter_size'] + 1)) # Create 1D Gaussian kernel
esp = 8.0 # Exponential parameter controlling the spread of the filter
sumgk = 0.0 # Sum of the Gaussian kernel for normalization
for n in range(2 * params['filter_size'] + 1):
k = -params['filter_size'] + n # Offset from the center of the kernel
gk[n] = math.exp(-esp * k**2.0 / (params['filter_size']**2.0)) # Gaussian filter formula
sumgk += gk[n] # Sum of kernel values for normalization
# Filter the extended correlation matrix
for j in range(nsnapshots, 2 * nsnapshots):
for i in range(nsnapshots, 2 * nsnapshots):
aux = 0.0 # Initialize variable for the weighted sum
for n in range(2 * params['filter_size'] + 1):
k = -params['filter_size'] + n # Offset from the current index
aux += nmatrix[i + k, j + k] * gk[n] # Apply Gaussian weighting
nmatrix[i, j] = aux / sumgk # Normalize by the sum of the Gaussian kernel
# Extract the SPOD correlation matrix from the central part of the filtered matrix
spectral_matrix = nmatrix[nsnapshots:2 * nsnapshots, nsnapshots:2 * nsnapshots]
# Perform SVD to compute SPOD
UU, SS, VV = svd(spectral_matrix, full_matrices=True, compute_uv=True)
# Temporal coefficients
SPOD_temporal_coefficients = np.zeros((nsnapshots, nsnapshots))
for k in range(nsnapshots):
SPOD_temporal_coefficients[:, k] = np.sqrt(SS[k] * nsnapshots) * UU[:, k]
# Extract spatial SPOD modes
SPOD_spatial_modes = np.zeros((params['num_modes'], ny, nx))
for m in range(params['num_modes']):
for t in range(nsnapshots):
SPOD_spatial_modes[m, :, :] += fluctuation_field[t, :, :] * SPOD_temporal_coefficients[t, m] / (SS[m] * nsnapshots)
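    # Each spatial mode is the projection of the fluctuation field onto the
    # corresponding temporal coefficient, as prescribed by the snapshot method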
if not params['silent']:
print(f"SPOD analysis completed.")
return SPOD_spatial_modes, SPOD_temporal_coefficients
# -------------------------------------- Main Function ----------------------------------------
def WaLSA_pod(signal, time, **kwargs):
"""
Perform Proper Orthogonal Decomposition (POD) analysis on input data.
Parameters:
signal (array): 3D data cube with shape (time, x, y) or similar.
time (array): 1D array representing the time points for each time step in the data.
num_modes (int, optional): Number of top modes to compute. Default is None (all modes).
num_top_frequencies (int, optional): Number of top frequencies to consider. Default is None (all frequencies).
top_frequencies (list, optional): List of top frequencies to consider. Default is None.
num_cumulative_modes (int, optional): Number of cumulative modes to consider. Default is None (all modes).
welch_nperseg (int, optional): Number of samples per segment for Welch's method. Default is 150.
welch_noverlap (int, optional): Number of overlapping samples for Welch's method. Default is 25.
welch_nfft (int, optional): Number of points for the FFT. Default is 2^14.
welch_fs (int, optional): Sampling frequency for the data. Default is 2.
nperm (int, optional): Number of permutations for significance testing. Default is 1000.
siglevel (float, optional): Significance level for the Welch spectrum. Default is 0.95.
timestep_to_reconstruct (int, optional): Timestep of the datacube to reconstruct using the top modes. Default is 0.
num_modes_reconstruct (int, optional): Number of modes to use for reconstruction. Default is None (all modes).
reconstruct_all (bool, optional): If True, reconstruct the entire time series using the top modes. Default is False.
spod (bool, optional): If True, perform Spectral Proper Orthogonal Decomposition (SPOD) analysis. Default is False.
spod_filter_size (int, optional): Filter size for SPOD analysis. Default is None.
spod_num_modes (int, optional): Number of SPOD modes to compute. Default is None.
print_results (bool, optional): If True, print a summary of results. Default is True.
**kwargs : Additional keyword arguments to configure the analysis.
Returns:
results : dict
A dictionary containing all computed POD results and relevant outputs.
        See 'descriptions' in the 'print_pod_results' function at the top of this page.
"""
# Set default parameter values
defaults = {
'silent': False,
'num_modes': None,
'num_top_frequencies': None,
'top_frequencies': None,
'num_cumulative_modes': None,
'welch_nperseg': 150,
'welch_noverlap': 25,
'welch_nfft': 2**14,
'welch_fs': 2,
'nperm': 1000,
'siglevel': 0.95,
'timestep_to_reconstruct': 0,
'reconstruct_all': False,
'num_modes_reconstruct': None,
'spod': False,
'spod_filter_size': None,
'spod_num_modes': None,
'print_results': True # Print summary of results by default
}
# Update defaults with user-provided values
params = {**defaults, **kwargs}
data = signal
if params['num_modes'] is None:
params['num_modes'] = data.shape[0]
if params['num_top_frequencies'] is None:
params['num_top_frequencies'] = min(params['num_modes'], 10)
if params['num_cumulative_modes'] is None:
params['num_cumulative_modes'] = min(params['num_modes'], 10)
if params['num_modes_reconstruct'] is None:
params['num_modes_reconstruct'] = min(params['num_modes'], 10)
if not params['silent']:
print("Starting POD analysis ....")
print(f"Processing a 3D cube with shape {data.shape}.")
# The first step is to perform the SVD (i.e., the POD). Before doing so, the data must be
# reshaped into an N x T matrix, where N is the column-vectorized set of each spatial image
# and T is the temporal domain. The mean is also subtracted, so that only the variance of
# the data is analyzed and mode 1 is not contaminated by the mean image.
# Reshape the 3D data into a 2D array where each row is a flattened snapshot
inp = np.reshape(data, (data.shape[0], data.shape[1] * data.shape[2])).astype(np.float32)
inp = inp.T # Transpose the matrix to have spatial vectors as columns
# Center the data by subtracting the mean
mean_per_row = np.nanmean(inp, axis=1, keepdims=True)
mean_replicated = np.tile(mean_per_row, (1, data.shape[0]))
inp_centered = inp - mean_replicated
# Input data, mean subtracted
input_dat = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for im in range(data.shape[0]):
input_dat[im, :, :] = np.reshape(inp_centered[:, im], (data.shape[1], data.shape[2]))
# Perform SVD to compute POD
p, s, a = svd(inp_centered, full_matrices=False)
sorg = s.copy() # Store original singular values
eigenvalue = s**2 # Convert singular values to eigenvalues
# Reshape spatial modes to have the same shape as the input data
num_modes = p.shape[1]
spatial_mode = np.zeros((num_modes, data.shape[1], data.shape[2]))
for m in range(num_modes):
spatial_mode[m, :, :] = np.reshape(p[:, m], (data.shape[1], data.shape[2]))
# Extract temporal coefficients
temporal_coefficient = a[:num_modes, :]
# Calculate eigenvalue contributions
eigenvalue_contribution = eigenvalue / np.sum(eigenvalue)
# Calculate cumulative eigenvalues
cumulative_eigenvalues = []
for m in range(params['num_cumulative_modes']):
contm = 100 * eigenvalue[0:m + 1] / np.sum(eigenvalue) # cumulative contribution of modes 0..m
cumulative_eigenvalues.append(np.sum(contm))
if params['reconstruct_all']:
# Reconstruct the entire time series using the top 'num_modes_reconstruct' modes
reconstructed = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
for tindex in range(data.shape[0]):
reconim = np.zeros((data.shape[1], data.shape[2]))
for i in range(params['num_modes_reconstruct']):
reconim = reconim + np.reshape(p[:, i], (data.shape[1], data.shape[2])) * a[i, tindex] * sorg[i]
reconstructed[tindex, :, :] = reconim
else:
# Reconstruct the specified timestep using the top 'num_modes_reconstruct' modes
reconstructed = np.zeros((data.shape[1], data.shape[2]))
for i in range(params['num_modes_reconstruct']):
reconstructed = reconstructed + np.reshape(p[:, i], (data.shape[1], data.shape[2])) * a[i, params['timestep_to_reconstruct']] * sorg[i]
#---------------------------------------------------------------------------------
# Combined Welch power spectrum and its significance
#---------------------------------------------------------------------------------
# Compute Welch power spectrum for each mode and combine them for 'num_modes' modes
combined_welch_psd = []
for m in range(params['num_modes']):
frequencies, px = welch(a[m, :] - np.mean(a[m, :]),
nperseg=params['welch_nperseg'],
noverlap=params['welch_noverlap'],
nfft=params['welch_nfft'],
fs=params['welch_fs'])
if m == 0:
combined_welch_psd = np.zeros((len(frequencies),))
combined_welch_psd += eigenvalue_contribution[m] * (px / np.sum(px))
# Generate resampled peaks to compute significance threshold
resampled_peaks = np.zeros((params['nperm'], len(frequencies)))
for i in range(params['nperm']):
resampled_data = np.random.randn(*a.shape)
resampled_peak = np.zeros((len(frequencies),))
for m in range(params['num_modes']):
_, px = welch(resampled_data[m, :] - np.mean(resampled_data[m, :]),
nperseg=params['welch_nperseg'],
noverlap=params['welch_noverlap'],
nfft=params['welch_nfft'],
fs=params['welch_fs'])
resampled_peak += eigenvalue_contribution[m] * (px / np.sum(px))
resampled_peak /= (np.max(resampled_peak) + 1e-30) # a small epsilon added to avoid division by zero
resampled_peaks[i, :] = resampled_peak
# Calculate significance threshold
combined_welch_significance = np.percentile(resampled_peaks, params['siglevel'] * 100, axis=0) # e.g., the 95th percentile of the permuted spectra for siglevel=0.95
#---------------------------------------------------------------------------------
# Find peaks in the combined spectrum and sort them in descending order
normalized_peak = combined_welch_psd / (np.max(combined_welch_psd) + 1e-30)
peak_indices, _ = find_peaks(normalized_peak)
sorted_indices = np.argsort(normalized_peak[peak_indices])[::-1]
sorted_frequencies = frequencies[peak_indices][sorted_indices]
# Generate a table of the top 'num_top_frequencies' frequencies and their normalized powers
table_data = [
[float(np.round(freq, 2)), float(np.round(pwr, 2))]
for freq, pwr in zip(sorted_frequencies[:params['num_top_frequencies']],
normalized_peak[peak_indices][sorted_indices][:params['num_top_frequencies']])
]
# The frequencies that the POD is able to capture have now been identified (in the top 'num_top_frequencies' modes).
# These frequencies can be fitted to the temporal coefficients of the top 'num_top_frequencies' modes,
# allowing for a representation of the data described solely by these "pure" frequencies.
# This approach enables the reconstruction of the original data using the identified dominant frequencies,
# resulting in frequency-filtered spatial POD modes.
if params['top_frequencies'] is None:
top_frequencies = sorted_frequencies[:params['num_top_frequencies']]
else:
top_frequencies = params['top_frequencies']
params['num_top_frequencies'] = len(params['top_frequencies'])
# Allocate the frequency-filtered reconstructions (one slice per fitted frequency)
clean_data = np.zeros((inp.shape[0], inp.shape[1], params['num_top_frequencies']))
for i in range(params['num_top_frequencies']):
def model_fun(t, amplitude, phase):
"""
Generate a sinusoidal model function.
Parameters:
t (array-like): The time variable.
amplitude (float): The amplitude of the sinusoidal function.
phase (float): The phase shift of the sinusoidal function.
Returns:
array-like: The computed sinusoidal values at each time point `t`.
"""
return amplitude * np.sin(2 * np.pi * top_frequencies[i] * t + phase)
for j in range(params['num_modes_reconstruct']):
csignal = a[j,:]
# Initial Guesses for Parameters: [Amplitude, Phase]
initial_guess = [1, 0]
# Nonlinear Fit
fit_params, _ = curve_fit(model_fun, time, csignal, p0=initial_guess)
# Clean Signal from Fit
if j == 0:
clean_signal = np.zeros((params['num_modes_reconstruct'], len(csignal)))
clean_signal[j, :] = model_fun(time, *fit_params)
# forming a set of clean data (reconstructed data at the fitted frequencies; frequency filtered reconstructed data)
clean_data[:, :, i] = p[:, 0:params['num_modes_reconstruct']] @ np.diag(sorg[0:params['num_modes_reconstruct']]) @ clean_signal
frequency_filtered_modes = np.zeros((data.shape[0], data.shape[1], data.shape[2], params['num_top_frequencies']))
for jj in range(params['num_top_frequencies']):
for frame_index in range(data.shape[0]):
frequency_filtered_modes[frame_index, :, :, jj] = np.reshape(clean_data[:, frame_index, jj], (data.shape[1], data.shape[2]))
if params['spod']:
SPOD_spatial_modes, SPOD_temporal_coefficients = spod(data, num_modes=params['spod_num_modes'], filter_size=params['spod_filter_size'])
else:
SPOD_spatial_modes = None
SPOD_temporal_coefficients = None
results = {
'input_data': input_dat, # Original input data, mean subtracted (Shape: (Nt, Ny, Nx))
'spatial_mode': spatial_mode, # POD spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))
'temporal_coefficient': temporal_coefficient, # POD temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))
'eigenvalue': eigenvalue, # Eigenvalues corresponding to singular values squared (Shape: (Nmodes))
'eigenvalue_contribution': eigenvalue_contribution, # Eigenvalue contribution of each mode (Shape: (Nmodes))
'cumulative_eigenvalues': cumulative_eigenvalues, # Cumulative percentage of eigenvalues for the first 'num_cumulative_modes' modes (Shape: (Ncumulative_modes))
'combined_welch_psd': combined_welch_psd, # Combined Welch power spectral density for the temporal coefficients (Shape: (Nf))
'frequencies': frequencies, # Frequencies identified in the Welch spectrum (Shape: (Nf))
'combined_welch_significance': combined_welch_significance, # Significance threshold of the combined Welch spectrum (Shape: (Nf))
'reconstructed': reconstructed, # Reconstruction of the specified timestep using the top modes (Shape: (Ny, Nx)), or of the full series if reconstruct_all=True (Shape: (Nt, Ny, Nx))
'sorted_frequencies': sorted_frequencies, # Peak frequencies of the combined Welch spectrum, sorted by descending normalized power (Shape: (Npeaks,))
'frequency_filtered_modes': frequency_filtered_modes, # Frequency-filtered spatial POD modes (Shape: (Nt, Ny, Nx, Ntop_frequencies))
'frequency_filtered_modes_frequencies': top_frequencies, # Frequencies corresponding to the frequency-filtered modes (Shape: (Ntop_frequencies))
'SPOD_spatial_modes': SPOD_spatial_modes, # SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))
'SPOD_temporal_coefficients': SPOD_temporal_coefficients, # SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))
'p': p, # Left singular vectors (spatial modes) from SVD (Shape: (Ny*Nx, Nmodes))
's': sorg, # Singular values from SVD (Shape: (Nmodes,))
'a': a # Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))
}
if not params['silent']:
print(f"POD analysis completed.")
print(f"Top {params['num_top_frequencies']} frequencies and normalized power values:\n{table_data}")
print(f"Total variance contribution of the first {params['num_modes']} modes: {np.sum(100 * eigenvalue[:params['num_modes']] / np.sum(eigenvalue)):.2f}%")
if params['print_results']:
print_pod_results(results)
return results
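For orientation, here is a minimal usage sketch of WaLSA_pod on a synthetic datacube. The import path, array shapes, and parameter values are illustrative assumptions for this example; in practice the routine is reached through the main WaLSAtools interface.
import numpy as np
from WaLSAtools import WaLSA_pod # assumed import path
# Synthetic datacube: 200 frames of 20 x 20 pixels at 0.5 s cadence (so fs = 2 Hz)
time = np.arange(200) * 0.5
signal = np.random.randn(200, 20, 20)
signal += np.sin(2 * np.pi * 0.1 * time)[:, None, None] # inject a coherent 0.1 Hz oscillation
results = WaLSA_pod(signal, time, num_modes=10, num_top_frequencies=5)
print(results['eigenvalue_contribution'][:5]) # variance fraction captured by the first five modes
With a signal like this, the injected 0.1 Hz oscillation would be expected to dominate the first mode and appear among the top frequencies of the combined Welch spectrum.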
WaLSA_cross_spectra.py
This module implements cross-correlation analysis techniques, yielding cross-spectra, coherence, and phase relationships for investigating correlations between two time series.
WaLSA_cross_spectra.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from scipy.signal import coherence, csd # type: ignore
from .WaLSA_speclizer import WaLSA_speclizer # type: ignore
from .WaLSA_wavelet import xwt, wct # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_wavelet_confidence import WaLSA_wavelet_confidence # type: ignore
# -------------------------------------- Main Function ----------------------------------------
def WaLSA_cross_spectra(signal=None, time=None, method=None, **kwargs):
"""
Compute the cross-spectra between two time series.
Parameters:
data1 (array): The first time series (1D).
data2 (array): The second time series (1D).
time (array): The time array of the signals.
method (str): The method to use for the analysis. Options are 'welch' and 'wavelet' ('fft' falls back to 'welch').
"""
if method == 'welch':
return getcross_spectrum_Welch(signal, time=time, **kwargs)
elif method == 'wavelet':
return getcross_spectrum_Wavelet(signal, time=time, **kwargs)
elif method == 'fft':
print("Note: FFT method selected. Cross-spectra calculations will use the Welch method instead, "
"which segments the signal into multiple parts to reduce noise sensitivity. "
"You can control frequency resolution vs. noise reduction using the 'nperseg' parameter.")
return getcross_spectrum_Welch(signal, time=time, **kwargs)
else:
raise ValueError(f"Unknown method: {method}")
# -------------------------------------- Welch ----------------------------------------
def getcross_spectrum_Welch(signal, time, **kwargs):
"""
Calculate cross-spectral relationships of two time series whose amplitudes and powers
are computed using the WaLSA_speclizer routine.
The cross-spectrum is complex valued, thus its magnitude is returned as the
co-spectrum. The phase lags between the two time series are estimated
from the imaginary and real arguments of the complex cross-spectrum.
The coherence is calculated from the normalized square of the amplitude
of the complex cross-spectrum.
Parameters:
data1 (array): The first 1D time series signal.
data2 (array): The second 1D time series signal.
time (array): The time array corresponding to the signals.
nperseg (int, optional): Length of each segment for analysis. Default: 256.
noverlap (int, optional): Number of points to overlap between segments. Default: 128.
window (str, optional): Type of window function used in the Welch method. Default: 'hann'.
siglevel (float, optional): Significance level for confidence intervals. Default: 0.95.
nperm (int, optional): Number of permutations for significance testing. Default: 1000.
silent (bool, optional): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh
"""
# Define default values for the optional parameters
defaults = {
'data1': None, # First data array
'data2': None, # Second data array
'nperseg': 256, # Number of points per segment
'apod': 0.1, # Apodization function
'nodetrendapod': None, # No detrending or apodization applied
'pxdetrend': 2, # Detrend parameter
'meandetrend': None, # Detrend parameter
'polyfit': None, # Detrend parameter
'meantemporal': None, # Detrend parameter
'recon': None # Detrend parameter
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
data1 = params['data1']
data2 = params['data2']
tdiff = np.diff(time)
cadence = np.median(tdiff)
silent = kwargs.pop('silent', True) # pop once so both calls receive the same setting
# Power spectrum for data1
power_data1, frequencies, _ = WaLSA_speclizer(signal=data1, time=time, method='welch',
amplitude=True, nosignificance=True, silent=silent, **kwargs)
# Power spectrum for data2
power_data2, _, _ = WaLSA_speclizer(signal=data2, time=time, method='welch',
amplitude=True, nosignificance=True, silent=silent, **kwargs)
# Calculate cross-spectrum
_, crosspower = csd(data1, data2, fs=1.0/cadence, window='hann', nperseg=params['nperseg'],)
cospectrum = np.abs(crosspower)
# Calculate phase lag
phase_angle = np.angle(crosspower, deg=True)
# Calculate coherence
freq_coh, coh = coherence(data1, data2, 1.0/cadence, nperseg=params['nperseg'])
return frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh
# -------------------------------------- Wavelet ----------------------------------------
def getcross_spectrum_Wavelet(signal, time, **kwargs):
"""
Calculate the cross-power spectrum, coherence spectrum, and phase angles
between two signals using Wavelet Transform.
Parameters:
data1 (array): The first 1D time series signal.
data2 (array): The second 1D time series signal.
time (array): The time array corresponding to the signals.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
mother (str): The mother wavelet function to use. Default: 'morlet'.
GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
lag1 (float): Lag-1 autocorrelation. Default: 0.0.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
silent (bool): If True, suppress print statements. Default: False.
Returns:
cross_power : np.ndarray
Cross power spectrum between `data1` and `data2`.
cross_periods : np.ndarray
Periods corresponding to the cross-power spectrum.
cross_sig : np.ndarray
Significance of the cross-power spectrum.
cross_coi : np.ndarray
Cone of influence for the cross-power spectrum.
coherence : np.ndarray
Coherence spectrum between `data1` and `data2`.
coh_periods : np.ndarray
Periods corresponding to the coherence spectrum.
coh_sig : np.ndarray
Significance of the coherence spectrum.
coh_coi : np.ndarray
Cone of influence for the coherence spectrum.
phase_angle : np.ndarray
2D array containing x- and y-components of the phase direction
arrows for each frequency and time point.
Notes
-----
- Cross-power, coherence, and phase angles are calculated using
**cross-wavelet transform (XWT)** and **wavelet coherence transform (WCT)**.
- Arrows for phase direction are computed such that:
- Arrows pointing downwards indicate anti-phase.
- Arrows pointing upwards indicate in-phase.
- Arrows pointing right indicate `data1` leading `data2`.
- Arrows pointing left indicate `data2` leading `data1`.
Examples
--------
>>> cross_power, cross_periods, cross_sig, cross_coi, coherence, coh_periods, coh_sig, corr_coi, phase_angle, dt = \
>>> getcross_spectrum_Wavelet(signal, cadence, data1=signal1, data2=signal2)
"""
# Define default values for optional parameters
defaults = {
'data1': None, # First data array
'data2': None, # Second data array
'mother': 'morlet', # Type of mother wavelet
'siglevel': 0.95, # Significance level for cross-spectral analysis
'dj': 0.025, # Spacing between scales
's0': -1, # Initial scale
'J': -1, # Number of scales
'cache': False, # Cache results to avoid recomputation
'apod': 0.1,
'silent': False,
'pxdetrend': 2,
'meandetrend': False,
'polyfit': None,
'meantemporal': False,
'recon': False,
'resample_original': False,
'nperm': 1000 # Number of permutations for significance testing
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
data1 = params['data1']
data2 = params['data2']
tdiff = np.diff(time)
cadence = np.median(tdiff)
data1orig = data1.copy()
data2orig = data2.copy()
data1 = WaLSA_detrend_apod(
data1,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
data2 = WaLSA_detrend_apod(
data2,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
# Calculate signal properties
nt = len(params['data1']) # Number of time points
# Determine the initial scale s0 if not provided
if params['s0'] == -1:
params['s0'] = 2 * cadence
# Determine the number of scales J if not provided
if params['J'] == -1:
params['J'] = int((np.log(float(nt) * cadence / params['s0']) / np.log(2)) / params['dj'])
# ----------- CROSS-WAVELET TRANSFORM (XWT) -----------
W12, cross_coi, freq, signif = xwt(
data1, data2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'],
no_default_signif=True, wavelet=params['mother'], normalize=False
)
print('Wavelet cross-power spectrum calculated.')
# Calculate the cross-power spectrum
cross_power = np.abs(W12) ** 2
cross_periods = 1 / freq # Periods corresponding to the frequency axis
#----------------------------------------------------------------------
# Calculate significance levels using Monte Carlo randomization method
#----------------------------------------------------------------------
nxx, nyy = cross_power.shape
cross_perm = np.zeros((nxx, nyy, params['nperm']))
print('\nCalculating wavelet cross-power significance:')
total_iterations = params['nperm'] # Total number of permutations
iteration_count = 0 # To keep track of progress
for ip in range(params['nperm']):
iteration_count += 1
progress = (iteration_count / total_iterations) * 100
print(f"\rProgress: {progress:.2f}%", end="")
y_perm1 = np.random.permutation(data1orig) # Permuting the original signal
apocube1 = WaLSA_detrend_apod(
y_perm1,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
y_perm2 = np.random.permutation(data2orig) # Permuting the original signal
apocube2 = WaLSA_detrend_apod(
y_perm2,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
W12s, _, _, _ = xwt(
apocube1, apocube2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'],
no_default_signif=True, wavelet=params['mother'], normalize=False
)
cross_power_sig = np.abs(W12s) ** 2
cross_perm[:, :, ip] = cross_power_sig
signifn = WaLSA_wavelet_confidence(cross_perm, siglevel=params['siglevel'])
cross_sig = cross_power / signifn # Ratio > 1 means cross-power is significant
# ----------- WAVELET COHERENCE TRANSFORM (WCT) -----------
WCT, aWCT, coh_coi, freq, sig = wct(
data1, data2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'],
sig=False, wavelet=params['mother'], normalize=False, cache=params['cache']
)
print('Wavelet coherence calculated.')
# Calculate the coherence spectrum
coh_periods = 1 / freq # Periods corresponding to the frequency axis
coherence = WCT
#----------------------------------------------------------------------
# Calculate significance levels using Monte Carlo randomization method
#----------------------------------------------------------------------
nxx, nyy = coherence.shape
coh_perm = np.zeros((nxx, nyy, params['nperm']))
print('\nCalculating wavelet coherence significance:')
total_iterations = params['nperm'] # Total number of permutations
iteration_count = 0 # To keep track of progress
for ip in range(params['nperm']):
iteration_count += 1
progress = (iteration_count / total_iterations) * 100
print(f"\rProgress: {progress:.2f}%", end="")
y_perm1 = np.random.permutation(data1orig) # Permuting the original signal
apocube1 = WaLSA_detrend_apod(
y_perm1,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
y_perm2 = np.random.permutation(data2orig) # Permuting the original signal
apocube2 = WaLSA_detrend_apod(
y_perm2,
apod=params['apod'],
meandetrend=params['meandetrend'],
pxdetrend=params['pxdetrend'],
polyfit=params['polyfit'],
meantemporal=params['meantemporal'],
recon=params['recon'],
cadence=cadence,
resample_original=params['resample_original'],
silent=True
)
WCTs, _, _, _, _ = wct(
apocube1, apocube2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'],
sig=False, wavelet=params['mother'], normalize=False, cache=params['cache']
)
coh_perm[:, :, ip] = WCTs
sig_coh = WaLSA_wavelet_confidence(coh_perm, siglevel=params['siglevel'])
coh_sig = coherence / sig_coh # Ratio > 1 means coherence is significant
# --------------- PHASE ANGLES ---------------
phase_angle = aWCT
return (
cross_power, cross_periods, cross_sig, cross_coi,
coherence, coh_periods, coh_sig, coh_coi,
phase_angle
)
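A minimal usage sketch for the Welch-based cross-spectrum follows. The import path and parameter values are assumptions for illustration, and the two input signals are synthetic.
import numpy as np
from WaLSAtools import WaLSA_cross_spectra # assumed import path
time = np.arange(400) * 0.5 # 0.5 s cadence
s1 = np.sin(2 * np.pi * 0.1 * time) + 0.2 * np.random.randn(400)
s2 = np.sin(2 * np.pi * 0.1 * time + np.pi / 4) + 0.2 * np.random.randn(400) # phase-shifted copy
frequencies, cospectrum, phase_angle, p1, p2, freq_coh, coh = WaLSA_cross_spectra(
signal=s1, time=time, method='welch', data1=s1, data2=s2, nperseg=128)
For these inputs one would expect a co-spectrum and coherence peak near 0.1 Hz, with a phase angle of roughly 45 degrees (in magnitude) at that frequency, reflecting the pi/4 shift between the two signals.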
WaLSA_detrend_apod.py
This module provides functions for detrending and apodizing time series data to mitigate trends and edge effects.
WaLSA_detrend_apod.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore
# -------------------------------------- Wavelet ----------------------------------------
def getWavelet(signal, time, **kwargs):
"""
Perform wavelet analysis using the pycwt package.
Parameters:
signal (array): The input signal (1D).
time (array): The time array corresponding to the signal.
siglevel (float): Significance level for the confidence intervals. Default: 0.95.
nperm (int): Number of permutations for significance testing. Default: 1000.
mother (str): The mother wavelet function to use. Default: 'morlet'.
GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
lag1 (float): Lag-1 autocorrelation. Default: 0.0.
apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
silent (bool): If True, suppress print statements. Default: False.
**kwargs: Additional parameters for the analysis method.
Returns:
power: The wavelet power spectrum.
periods: Corresponding periods.
coi: The cone of influence.
scales: The wavelet scales.
W: The complex wavelet transform.
"""
# Define default values for the optional parameters similar to IDL
defaults = {
'siglevel': 0.95,
'mother': 'morlet', # Morlet wavelet as the mother function
'dj': 1/32., # Scale spacing
's0': -1, # Initial scale
'J': -1, # Number of scales
'lag1': 0.0, # Lag-1 autocorrelation
'silent': False,
'nperm': 1000, # Number of permutations for significance calculation
}
# Update defaults with any user-provided keyword arguments
params = {**defaults, **kwargs}
tdiff = np.diff(time)
cadence = np.median(tdiff)
n = len(signal)
dt = cadence
# Normalize the signal by its standard deviation before the wavelet transform
std_signal = signal.std()
norm_signal = signal / std_signal
# Determine the initial scale s0 if not provided
if params['s0'] == -1:
params['s0'] = 2 * dt
# Determine the number of scales J if not provided
if params['J'] == -1:
params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])
# Perform wavelet transform
W, scales, frequencies, coi, _, _ = cwt(
norm_signal,
dt,
dj=params['dj'],
s0=params['s0'],
J=params['J'],
wavelet=params['mother']
)
power = np.abs(W) ** 2 # Wavelet power spectrum
periods = 1 / frequencies # Convert frequencies to periods
return power, periods, coi, scales, W
# -----------------------------------------------------------------------------------------------------
# Linear detrending function for curve fitting
def linear(x, a, b):
return a + b * x
# Custom Tukey window implementation
def custom_tukey(nt, apod=0.1):
apodrim = int(apod * nt)
apodt = np.ones(nt) # Initialize with ones
if apodrim <= 0:
return apodt # No tapering requested (or series too short): rectangular window
# Apply sine-squared taper to the first 'apodrim' points
taper = (np.sin(np.pi / 2. * np.arange(apodrim) / apodrim)) ** 2
# Apply taper symmetrically at both ends
apodt[:apodrim] = taper
apodt[-apodrim:] = taper[::-1] # Reverse taper for the last points
return apodt
# Main detrending and apodization function
# apod=0: The Tukey window becomes a rectangular window (no tapering).
# apod=1: The Tukey window becomes a Hann window (fully tapered with a cosine function).
# Note: dj is the wavelet scale spacing used in the Fourier reconstruction (kept consistent with getWavelet's default of 1/32)
def WaLSA_detrend_apod(cube, apod=0.1, meandetrend=False, pxdetrend=2, polyfit=None, meantemporal=False,
recon=False, cadence=None, resample_original=False, min_resample=None,
max_resample=None, silent=False, dj=1/32., lo_cutoff=None, hi_cutoff=None, upper=False):
nt = len(cube) # Assume input is a 1D signal
cube = cube - np.mean(cube) # Remove the mean of the input signal
apocube = np.copy(cube) # Create a copy of the input signal
t = np.arange(nt) # Time array
# Apply Tukey window (apodization)
if apod > 0:
tukey_window = custom_tukey(nt, apod)
apocube = apocube * tukey_window # Apodize the signal
# Mean detrend (optional): subtract the linear trend with time of the mean signal
# (for a 1D input, the image mean at each time step reduces to the signal itself)
if meandetrend:
avg_signal = apocube
time_arr = np.arange(nt)
mean_fit_params, _ = curve_fit(linear, time_arr, avg_signal)
mean_trend = linear(time_arr, *mean_fit_params)
apocube -= mean_trend
# Wavelet-based Fourier reconstruction (optional)
if recon and cadence:
apocube = WaLSA_wave_recon(apocube, cadence, dj=dj, lo_cutoff=lo_cutoff, hi_cutoff=hi_cutoff, upper=upper)
# Pixel-based detrending (temporal detrend)
if pxdetrend > 0:
mean_val = np.mean(apocube)
if meantemporal:
# Simple temporal detrend by subtracting the mean
apocube -= mean_val
else:
# Advanced detrend (linear or polynomial fit)
if polyfit is not None:
poly_coeffs = np.polyfit(t, apocube, polyfit)
trend = np.polyval(poly_coeffs, t)
else:
popt, _ = curve_fit(linear, t, apocube, p0=[mean_val, 0])
trend = linear(t, *popt)
apocube -= trend
# Resampling to preserve amplitudes (optional)
if resample_original:
if min_resample is None:
min_resample = np.min(apocube)
if max_resample is None:
max_resample = np.max(apocube)
apocube = np.interp(apocube, (np.min(apocube), np.max(apocube)), (min_resample, max_resample))
if not silent:
print("Detrending and apodization complete.")
return apocube
# Wavelet-based reconstruction function (optional)
def WaLSA_wave_recon(ts, delt, dj=1/32., lo_cutoff=None, hi_cutoff=None, upper=False):
"""
Reconstructs the wavelet-filtered time series based on given frequency cutoffs.
"""
# Define duration based on the time series length
dur = (len(ts) - 1) * delt
# Assign default values if lo_cutoff or hi_cutoff is None
if lo_cutoff is None:
lo_cutoff = 0.0 # Default to 0 as in IDL
if hi_cutoff is None:
hi_cutoff = dur / (3.0 * np.sqrt(2))
mother = 'morlet'
num_points = len(ts)
time_array = np.linspace(0, (num_points - 1) * delt, num_points)
_, period, coi, scales, wave = getWavelet(signal=ts, time=time_array, method='wavelet', siglevel=0.99, apod=0.1, mother=mother)
# Ensure good_per_idx and bad_per_idx are properly assigned
if upper:
good_per_idx = np.where(period > hi_cutoff)[0][0]
bad_per_idx = len(period)
else:
good_per_idx = np.where(period > lo_cutoff)[0][0]
bad_per_idx = np.where(period > hi_cutoff)[0][0]
# Zero out the power at periods above the CoI at each time step
# (i.e., exclude points subject to edge effects; only coefficients below the CoI are kept)
iampl = np.zeros((len(ts), len(period)), dtype=float)
for i in range(len(ts)):
pcol = np.real(wave[:, i]) # Extract real part of wavelet transform for this time index
ii = np.where(period < coi[i])[0] # Find indices where period is less than COI at time i
if ii.size > 0:
iampl[i, ii] = pcol[ii] # Assign values where condition is met
# Initialize reconstructed signal array
recon_sum = np.zeros(len(ts))
# Summation over valid period indices
for i in range(good_per_idx, bad_per_idx):
recon_sum += iampl[:, i] / np.sqrt(scales[i])
# Apply normalization factor
recon_all = dj * np.sqrt(delt) * recon_sum / (0.776 * (np.pi ** -0.25)) # Cdelta = 0.776 for the Morlet wavelet (Torrence & Compo 1998)
return recon_all
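As a quick illustration, the sketch below detrends and apodizes a noisy signal carrying a linear trend. The import path and parameter values are arbitrary assumptions for this example.
import numpy as np
from WaLSAtools import WaLSA_detrend_apod # assumed import path
time = np.arange(300) * 0.5
signal = 0.01 * time + np.sin(2 * np.pi * 0.05 * time) + 0.1 * np.random.randn(300)
# Linear per-pixel detrend plus a 10% Tukey taper at each edge
clean = WaLSA_detrend_apod(signal, apod=0.1, pxdetrend=2, cadence=0.5)
The returned series has the linear trend removed and its edges tapered toward zero, which reduces spectral leakage in the subsequent Fourier or wavelet analysis.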
WaLSA_confidence.py
This module implements statistical significance testing for the spectral analysis results using various methods.
WaLSA_confidence.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
def WaLSA_confidence(ps_perm, siglevel=0.05, nf=None):
"""
Find the confidence levels (significance levels) for the given power spectrum permutations.
Parameters:
-----------
ps_perm : np.ndarray
2D array where each row represents a power spectrum, and each column represents a permutation.
siglevel : float
Significance level (default 0.05 for 95% confidence).
nf : int
Number of frequencies (length of the power spectra). If None, inferred from the shape of ps_perm.
Returns:
--------
signif : np.ndarray
Significance levels for each frequency.
"""
if nf is None:
nf = ps_perm.shape[0] # Number of frequencies inferred from the shape of ps_perm
signif = np.zeros((nf))
# Loop through each frequency
for iv in range(nf):
# Extract the permutation values for this frequency
tmp = np.sort(ps_perm[iv, :]) # Sort the power spectrum permutation values
# Calculate the number of permutations
ntmp = len(tmp)
# Find the significance threshold
nsig = int(round(siglevel * ntmp)) # Number of permutations to cut off for the significance level
# Set the confidence level for this frequency
signif[iv] = tmp[-nsig] # Select the (ntmp - nsig)-th element (equivalent to IDL's ROUND and indexing)
return signif
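A short sketch of the intended use, with randomly generated permutations standing in for real resampled power spectra (the import path and array shape are assumptions):
import numpy as np
from WaLSAtools import WaLSA_confidence # assumed import path
# 100 frequencies x 1000 permutations of resampled power
ps_perm = np.random.chisquare(df=2, size=(100, 1000))
signif = WaLSA_confidence(ps_perm, siglevel=0.05) # per-frequency 95% confidence threshold
Observed power exceeding signif at a given frequency would then be deemed significant at the 95% level.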
WaLSA_wavelet_confidence.py
This module implements statistical significance testing for the wavelet analysis results.
WaLSA_wavelet_confidence.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
def WaLSA_wavelet_confidence(ps_perm, siglevel=None):
"""
Find the confidence levels (significance levels) for the given wavelet power spectrum permutations.
Parameters:
-----------
ps_perm : np.ndarray
3D array of shape (scales/frequencies, times, permutations) containing the permuted wavelet power spectra.
siglevel : float
e.g., 0.95 for 95% confidence level (significance at 5%).
Returns:
--------
signif : np.ndarray
Significance levels for each frequency.
"""
nnff, nntt, nnperm = ps_perm.shape
signif = np.zeros((nnff, nntt))
for iv in range(nnff):
for it in range(nntt):
signif[iv, it] = np.percentile(ps_perm[iv, it, :], 100 * siglevel) # e.g., the 95th percentile for a 95% confidence level
return signif
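Analogously for the wavelet case, a minimal sketch with a random permutation cube (the shapes and import path are illustrative assumptions):
import numpy as np
from WaLSAtools import WaLSA_wavelet_confidence # assumed import path
# 64 scales x 200 time steps x 500 permutations of resampled wavelet power
ps_perm = np.random.chisquare(df=2, size=(64, 200, 500))
signif = WaLSA_wavelet_confidence(ps_perm, siglevel=0.95) # 95th percentile per (scale, time) point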
WaLSA_io.py
This module provides functions for input/output operations, such as saving figures as PDFs (in both RGB and CMYK formats) and enhancing image contrast.
WaLSA_io.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import subprocess
import shutil
import os
import stat
import numpy as np # type: ignore
from matplotlib.backends.backend_pdf import PdfPages # type: ignore
def WaLSA_save_pdf(fig, pdf_path, color_mode='RGB', dpi=300, bbox_inches=None, pad_inches=0):
"""
Save a PDF from a Matplotlib figure with an option to convert to CMYK if Ghostscript is available.
Parameters:
- fig: Matplotlib figure object to save.
- pdf_path: Path to save the PDF file.
- color_mode: 'RGB' (default) to save in RGB, or 'CMYK' to convert to CMYK.
- dpi: Resolution in dots per inch. Default is 300.
- bbox_inches: Set to 'tight' to remove extra white space around the figure. Default is None.
- pad_inches: Padding around the figure. Default is 0.
"""
# If bbox_inches is not set, save directly via PdfPages; otherwise fig.savefig is used below to apply bbox_inches, pad_inches, and dpi
if bbox_inches is None:
# Save the initial PDF in RGB format
with PdfPages(pdf_path) as pdf:
pdf.savefig(fig, transparent=True)
if color_mode.upper() == 'CMYK':
# Check if Ghostscript is installed
if shutil.which("gs") is not None:
# Set permissions on the initial RGB PDF to ensure access for Ghostscript
os.chmod(pdf_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH)
# Define a temporary path for the CMYK file
temp_cmyk_path = pdf_path.replace('.pdf', '_temp_cmyk.pdf')
# Run Ghostscript to create the CMYK PDF as a temporary file
subprocess.run([
"gs", "-o", temp_cmyk_path,
"-sDEVICE=pdfwrite",
"-dProcessColorModel=/DeviceCMYK",
"-dColorConversionStrategy=/CMYK",
"-dNOPAUSE", "-dBATCH", "-dSAFER",
"-dPDFSETTINGS=/prepress", pdf_path
], check=True)
# Replace the original RGB PDF with the CMYK version
os.remove(pdf_path) # Delete the RGB version
os.rename(temp_cmyk_path, pdf_path) # Rename CMYK as the original file
print(f"PDF saved in CMYK format as '{pdf_path}'")
else:
print("Warning: Ghostscript is not installed, so the PDF remains in RGB format.")
print("To enable CMYK saving, please install Ghostscript:")
print("- On macOS, use: brew install ghostscript")
print("- On Ubuntu, use: sudo apt install ghostscript")
print("- On Windows, download from: https://www.ghostscript.com/download.html")
else:
print(f"PDF saved in RGB format as '{pdf_path}'")
else:
# Temporary path for the RGB version of the PDF
temp_pdf_path = pdf_path.replace('.pdf', '_temp.pdf')
fig.savefig(
temp_pdf_path,
format='pdf',
dpi=dpi,
bbox_inches=bbox_inches,
pad_inches=pad_inches,
transparent=True
)
if color_mode.upper() == 'CMYK':
# Check if Ghostscript is installed
if shutil.which("gs") is not None:
# Set permissions on the initial RGB PDF to ensure access for Ghostscript
os.chmod(temp_pdf_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH)
# Run Ghostscript to create the CMYK PDF as a temporary file
subprocess.run([
"gs", "-o", pdf_path,
"-sDEVICE=pdfwrite",
"-dProcessColorModel=/DeviceCMYK",
"-dColorConversionStrategy=/CMYK",
"-dNOPAUSE", "-dBATCH", "-dSAFER",
"-dPDFSETTINGS=/prepress", temp_pdf_path
], check=True)
# Remove the temporary PDF file
os.remove(temp_pdf_path)
print(f"PDF saved in CMYK format as '{pdf_path}'")
else:
print("Warning: Ghostscript is not installed, so the PDF remains in RGB format.")
print("To enable CMYK saving, please install Ghostscript:")
print("- On macOS, use: brew install ghostscript")
print("- On Ubuntu, use: sudo apt install ghostscript")
print("- On Windows, download from: https://www.ghostscript.com/download.html")
else:
# Rename the temporary file as the final file if no CMYK conversion is needed
os.rename(temp_pdf_path, pdf_path)
print(f"PDF saved in RGB format as '{pdf_path}'")
def WaLSA_histo_opt(image, cutoff=1e-3, top_only=False, bot_only=False):
"""
Clip image values based on the cutoff percentile to enhance contrast.
Inspired by IDL's "iris_histo_opt" function (Copyright: P.Suetterlin, V.Hansteen, and M. Carlsson)
Parameters:
- image: 2D array (image data).
- cutoff: Fraction of values to clip at the top and bottom (default is 0.001).
- top_only: If True, clip only the highest values.
- bot_only: If True, clip only the lowest values.
Returns:
- Clipped image for better contrast.
"""
# Ignore NaNs in the image
finite_values = image[np.isfinite(image)]
# Calculate lower and upper bounds based on cutoff percentiles
lower_bound = np.percentile(finite_values, cutoff * 100)
upper_bound = np.percentile(finite_values, (1 - cutoff) * 100)
# Clip image according to bounds and options
if top_only:
return np.clip(image, None, upper_bound)
elif bot_only:
return np.clip(image, lower_bound, None)
else:
return np.clip(image, lower_bound, upper_bound)
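For example, a minimal sketch saving a contrast-clipped image as an RGB PDF (the file name, figure, and import path are placeholders assumed for this example):
import numpy as np
import matplotlib.pyplot as plt
from WaLSAtools import WaLSA_save_pdf, WaLSA_histo_opt # assumed import path
image = np.random.randn(100, 100)
fig, ax = plt.subplots()
ax.imshow(WaLSA_histo_opt(image), cmap='gray') # clip extreme values for better contrast
WaLSA_save_pdf(fig, 'example.pdf', color_mode='RGB', dpi=300)
Passing color_mode='CMYK' instead would trigger the Ghostscript conversion path, provided Ghostscript is installed.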
WaLSA_interactive.py
This module implements the interactive interface of WaLSAtools, guiding users through the analysis process.
WaLSA_interactive.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output, HTML # type: ignore
import os # type: ignore
# Function to detect if it's running in a notebook or terminal
def is_notebook():
try:
# Check if IPython is in the environment and if we're using a Jupyter notebook
from IPython import get_ipython # type: ignore
shell = get_ipython().__class__.__name__
if shell == 'ZMQInteractiveShell':
return True # Jupyter notebook or qtconsole
elif shell == 'TerminalInteractiveShell':
return False # IPython terminal
else:
return False # Other types
except (NameError, ImportError):
# NameError: get_ipython is not defined
# ImportError: IPython is not installed
return False # Standard Python interpreter or shell
# Function to print logo and credits for both environments
def print_logo_and_credits():
logo_terminal = r"""
__ __ _ _____
\ \ / / | | / ____| /\
\ \ /\ / / ▄▄▄▄▄ | | | (___ / \
\ \/ \/ / ▀▀▀▀██ | | \___ \ / /\ \
\ /\ / ▄██▀▀██ | |____ ____) | / ____ \
\/ \/ ▀██▄▄██ |______| |_____/ /_/ \_\
"""
credits_terminal = """
© WaLSA Team (www.WaLSA.team)
-----------------------------------------------------------------------
WaLSAtools v1.0.0 - Wave analysis tools
Documentation: www.WaLSA.tools
GitHub repository: www.github.com/WaLSAteam/WaLSAtools
-----------------------------------------------------------------------
If you use WaLSAtools in your research, please cite:
Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
Free access to a view-only version: https://WaLSA.tools/nrmp
Supplementary Information: https://WaLSA.tools/nrmp-si
-----------------------------------------------------------------------
Choose a category, data type, and analysis method from the list below,
to get hints on the calling sequence and parameters:
"""
credits_notebook = """
<div style="margin-left: 30px; margin-top: 20px; font-size: 1.1em; line-height: 0.8;">
<p>© WaLSA Team (<a href="https://www.WaLSA.team" target="_blank">www.WaLSA.team</a>)</p>
<hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
<p><strong>WaLSAtools</strong> v1.0.0 - Wave analysis tools</p>
<p>Documentation: <a href="https://www.WaLSA.tools" target="_blank">www.WaLSA.tools</a></p>
<p>GitHub repository: <a href="https://www.github.com/WaLSAteam/WaLSAtools" target="_blank">www.github.com/WaLSAteam/WaLSAtools</a></p>
<hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
<p>If you use <strong>WaLSAtools</strong> in your research, please cite:</p>
<p>Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, <em>Nature Reviews Methods Primers</em>, 5, 21</p>
<p>Free access to a view-only version: <a href="https://WaLSA.tools/nrmp" target="_blank">www.WaLSA.tools</a></p>
<p>Supplementary Information: <a href="https://WaLSA.tools/nrmp-si" target="_blank">www.WaLSA.tools</a></p>
<hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 15px;">
<p>Choose a category, data type, and analysis method from the list below,</p>
<p>to get hints on the calling sequence and parameters:</p>
</div>
"""
if is_notebook():
try:
# For scripts
current_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
# For Jupyter notebooks
current_dir = os.getcwd()
img_path = os.path.join(current_dir, '..', 'assets', 'WaLSAtools_black.png')
# display(HTML(f'<img src="{img_path}" style="margin-left: 40px; margin-top: 20px; width:300px; height: auto;">')) # not shown in Jupyter notebook, only in VS Code
import base64
# Convert the image to Base64
with open(img_path, "rb") as img_file:
encoded_img = base64.b64encode(img_file.read()).decode('utf-8')
# Embed the Base64 image in the HTML
html_code_logo = f"""
<div style="margin-left: 30px; margin-top: 20px;">
<img src="data:image/png;base64,{encoded_img}" style="width: 300px; height: auto;">
</div>
"""
display(HTML(html_code_logo))
display(HTML(credits_notebook))
else:
print(logo_terminal)
print(credits_terminal)
from .parameter_definitions import display_parameters_text, single_series_parameters, cross_correlation_parameters
# Terminal-based interactive function
import textwrap
import shutil
def walsatools_terminal():
"""Main interactive function for terminal version of WaLSAtools."""
print_logo_and_credits()
# Get terminal width
terminal_width = shutil.get_terminal_size().columns
# Calculate 90% of the width (and ensure it's an integer)
line_width = int(terminal_width * 0.90)
# Step 1: Select Category
while True:
print("\n Category:")
print(" (a) Single time series analysis")
print(" (b) Cross-correlation between two time series")
category = input(" --- Select a category (a/b): ").strip().lower()
if category not in ['a', 'b']:
print(" Invalid selection. Please enter either 'a' or 'b'.\n")
continue
# Step 2: Data Type
if category == 'a':
while True:
print("\n Data Type:")
print(" (1) 1D signal")
print(" (2) 3D datacube")
method = input(" --- Select a data type (1/2): ").strip()
if method not in ['1', '2']:
print(" Invalid selection. Please enter either '1' or '2'.\n")
continue
# Step 3: Analysis Method
if method == '1': # 1D Signal
while True:
print("\n Analysis Method:")
print(" (1) FFT")
print(" (2) Wavelet")
print(" (3) Lomb-Scargle")
print(" (4) Welch")
print(" (5) EMD")
analysis_type = input(" --- Select an analysis method (1-5): ").strip()
method_map_1d = {
'1': 'fft',
'2': 'wavelet',
'3': 'lombscargle',
'4': 'welch',
'5': 'emd'
}
selected_method = method_map_1d.get(analysis_type, 'unknown')
if selected_method == 'unknown':
print(" Invalid selection. Please select a valid analysis method (1-5).\n")
continue
# Generate and display the calling sequence for 1D signal
print("\n Calling sequence:\n")
return_values = single_series_parameters[selected_method]['return_values']
command = f">>> {return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
wrapper = textwrap.TextWrapper(width=line_width - 4, initial_indent=' ', subsequent_indent=' ')
wrapped_command = wrapper.fill(command)
print(wrapped_command)
# Display parameter hints
display_parameters_text(selected_method, category='a')
return # Exit the function after successful output
elif method == '2': # 3D Datacube
while True:
print("\n Analysis Method:")
print(" (1) k-omega")
print(" (2) POD")
print(" (3) Dominant Freq / Mean Power Spectrum")
analysis_type = input(" --- Select an analysis method (1-3): ").strip()
if analysis_type == '3': # Sub-method required
while True:
print("\n Analysis method for Dominant Freq / Mean Power Spectrum:")
print(" (1) FFT")
print(" (2) Wavelet")
print(" (3) Lomb-Scargle")
print(" (4) Welch")
sub_method = input(" --- Select a method (1-4): ").strip()
sub_method_map = {'1': 'fft', '2': 'wavelet', '3': 'lombscargle', '4': 'welch'}
selected_method = sub_method_map.get(sub_method, 'unknown')
if selected_method == 'unknown':
print(" Invalid selection. Please select a valid sub-method (1-4).\n")
continue
# Generate and display the calling sequence for sub-method
print("\n Calling sequence:\n")
return_values = 'dominant_frequency, mean_power, frequency, power_map'
command = f">>> {return_values} = WaLSAtools(data=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"
wrapper = textwrap.TextWrapper(width=line_width - 4, initial_indent=' ', subsequent_indent=' ')
wrapped_command = wrapper.fill(command)
print(wrapped_command)
# Display parameter hints
display_parameters_text(selected_method, category='a')
return # Exit the function after successful output
method_map_3d = {
'1': 'k-omega',
'2': 'pod'
}
selected_method = method_map_3d.get(analysis_type, 'unknown')
if selected_method == 'unknown':
print(" Invalid selection. Please select a valid analysis method (1-3).\n")
continue
# Generate and display the calling sequence for k-omega/POD
print("\n Calling sequence:\n")
return_values = single_series_parameters[selected_method]['return_values']
command = f">>> {return_values} = WaLSAtools(data1=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
wrapper = textwrap.TextWrapper(width=line_width - 4, initial_indent=' ', subsequent_indent=' ')
wrapped_command = wrapper.fill(command)
print(wrapped_command)
# Display parameter hints
display_parameters_text(selected_method, category='a')
return # Exit the function after successful output
elif category == 'b': # Cross-correlation
while True:
print("\n Data Type:")
print(" (1) 1D signal")
method = input(" --- Select a data type (1): ").strip()
if method != '1':
print(" Invalid selection. Please enter '1'.\n")
continue
while True:
print("\n Analysis Method:")
print(" (1) FFT")
print(" (2) Wavelet")
print(" (3) Welch")
analysis_type = input(" --- Select an analysis method (1-3): ").strip()
cross_correlation_map = {
'1': 'fft',
'2': 'wavelet',
'3': 'welch'
}
selected_method = cross_correlation_map.get(analysis_type, 'unknown')
if selected_method == 'unknown':
print(" Invalid selection. Please select a valid analysis method (1-2).\n")
continue
# Generate and display the calling sequence for cross-correlation
print("\n Calling sequence:\n")
return_values = cross_correlation_parameters[selected_method]['return_values']
command = f">>> {return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
wrapper = textwrap.TextWrapper(width=line_width - 4, initial_indent=' ', subsequent_indent=' ')
wrapped_command = wrapper.fill(command)
print(wrapped_command)
# Display parameter hints
display_parameters_text(selected_method, category='b')
return # Exit the function after successful output
# Jupyter-based interactive function
import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output, HTML # type: ignore
from .parameter_definitions import display_parameters_html, single_series_parameters, cross_correlation_parameters # type: ignore
# Global flag to prevent multiple observers
is_observer_attached = False
def walsatools_jupyter():
"""Main interactive function for Jupyter Notebook version of WaLSAtools."""
global is_observer_attached, category, method, analysis_type, sub_method # Declare global variables for reuse
# Detach any existing observers and reset the flag
try:
detach_observers()
except NameError:
pass # `detach_observers` hasn't been defined yet
is_observer_attached = False # Reset observer flag
# Clear any previous output
clear_output(wait=True)
print_logo_and_credits()
# Recreate widgets to reset state
category = widgets.Dropdown(
options=['Select Category', 'Single time series analysis', 'Cross-correlation between two time series'],
value='Select Category',
description='Category:'
)
method = widgets.Dropdown(
options=['Select Data Type'],
value='Select Data Type',
description='Data Type:'
)
analysis_type = widgets.Dropdown(
options=['Select Method'],
value='Select Method',
description='Method:'
)
sub_method = widgets.Dropdown(
options=['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch'],
value='Select Sub-method',
description='Sub-method:',
layout=widgets.Layout(display='none') # Initially hidden
)
# Persistent output widget
output = widgets.Output()
def clear_output_if_unselected(change=None):
"""Clear the output widget if any dropdown menu is unselected."""
with output:
if (
category.value == 'Select Category'
or method.value == 'Select Data Type'
or analysis_type.value == 'Select Method'
or (
analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
and sub_method.layout.display == 'block'
and sub_method.value == 'Select Sub-method'
)
):
clear_output(wait=True)
warn = '<div style="font-size: 1.1em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Please select appropriate options from all dropdown menus.</div>'
display(HTML(warn))
def update_method_options(change=None):
"""Update available Method and Sub-method options."""
clear_output_if_unselected() # Ensure the output clears if any dropdown is unselected.
sub_method.layout.display = 'none'
sub_method.value = 'Select Sub-method'
sub_method.options = ['Select Sub-method']
if category.value == 'Single time series analysis':
method.options = ['Select Data Type', '1D signal', '3D datacube']
if method.value == '1D signal':
analysis_type.options = ['Select Method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch', 'EMD']
elif method.value == '3D datacube':
analysis_type.options = ['Select Method', 'k-omega', 'POD', 'Dominant Freq / Mean Power Spectrum']
else:
analysis_type.options = ['Select Method']
elif category.value == 'Cross-correlation between two time series':
method.options = ['Select Data Type', '1D signal']
if method.value == '1D signal':
analysis_type.options = ['Select Method', 'FFT', 'Wavelet', 'Welch']
else:
analysis_type.options = ['Select Method']
else:
method.options = ['Select Data Type']
analysis_type.options = ['Select Method']
def update_sub_method_visibility(change=None):
"""Show or hide the Sub-method dropdown based on conditions."""
clear_output_if_unselected() # Ensure the output clears if any dropdown is unselected.
if (
category.value == 'Single time series analysis'
and method.value == '3D datacube'
and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
):
sub_method.options = ['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch']
sub_method.layout.display = 'block'
else:
sub_method.options = ['Select Sub-method']
sub_method.value = 'Select Sub-method'
sub_method.layout.display = 'none'
def update_calling_sequence(change=None):
"""Update the function calling sequence based on user's selection."""
clear_output_if_unselected() # Ensure the output clears if any dropdown is unselected.
if (
category.value == 'Select Category'
or method.value == 'Select Data Type'
or analysis_type.value == 'Select Method'
or (
analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
and sub_method.layout.display == 'block'
and sub_method.value == 'Select Sub-method'
)
):
return # Do nothing until all required fields are properly selected
with output:
clear_output(wait=True)
# Handle Dominant Frequency / Mean Power Spectrum with Sub-method
if (
category.value == 'Single time series analysis'
and method.value == '3D datacube'
and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
):
if sub_method.value == 'Select Sub-method':
print("Please select a Sub-method.")
return
sub_method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch'}
selected_method = sub_method_map.get(sub_method.value, 'unknown')
return_values = 'dominant_frequency, mean_power, frequency, power_map'
command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"
# Handle k-omega and POD
elif (
category.value == 'Single time series analysis'
and method.value == '3D datacube'
and analysis_type.value in ['k-omega', 'POD']
):
method_map = {'k-omega': 'k-omega', 'POD': 'pod'}
selected_method = method_map.get(analysis_type.value, 'unknown')
parameter_definitions = single_series_parameters
return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
# Handle Cross-correlation
elif category.value == 'Cross-correlation between two time series':
cross_correlation_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Welch': 'welch'}
selected_method = cross_correlation_map.get(analysis_type.value, 'unknown')
parameter_definitions = cross_correlation_parameters
return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
command = f"{return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
# Handle Single 1D signal analysis
elif category.value == 'Single time series analysis' and method.value == '1D signal':
method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch', 'EMD': 'emd'}
selected_method = method_map.get(analysis_type.value, 'unknown')
parameter_definitions = single_series_parameters
return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
else:
print("Invalid configuration.")
return
# Generate and display the command in HTML
html_code = f"""
<div style="font-size: 1.2em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Calling sequence:</div>
<div style="display: flex; margin-left: 30px; margin-bottom: 3ch;">
<span style="color: #222; min-width: 4ch;">>>> </span>
<pre style="
white-space: pre-wrap;
word-wrap: break-word;
color: #01016D;
margin: 0;
">{command}</pre>
</div>
"""
display(HTML(html_code))
display_parameters_html(selected_method, category=category.value)
def attach_observers():
"""Attach observers to the dropdown widgets and ensure no duplicates."""
global is_observer_attached
if not is_observer_attached:
detach_observers()
category.observe(clear_output_if_unselected, names='value')
method.observe(clear_output_if_unselected, names='value')
analysis_type.observe(clear_output_if_unselected, names='value')
category.observe(update_method_options, names='value')
method.observe(update_method_options, names='value')
analysis_type.observe(update_sub_method_visibility, names='value')
sub_method.observe(update_calling_sequence, names='value')
analysis_type.observe(update_calling_sequence, names='value')
is_observer_attached = True
def detach_observers():
"""Detach all observers to prevent multiple triggers."""
try:
category.unobserve(update_method_options, names='value')
method.unobserve(update_method_options, names='value')
analysis_type.unobserve(update_sub_method_visibility, names='value')
sub_method.unobserve(update_calling_sequence, names='value')
analysis_type.unobserve(update_calling_sequence, names='value')
except ValueError:
pass
attach_observers()
display(category, method, analysis_type, sub_method, output)
# Unified interactive function for both terminal and Jupyter environments
def interactive():
if is_notebook():
walsatools_jupyter()
else:
walsatools_terminal()
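# ------------------------------ Usage sketch ------------------------------
# An illustrative addition (not part of the original module): invoking the
# unified entry point defined above. In a Jupyter notebook this renders the
# dropdown widgets; in a plain terminal it starts the text-based menus.
if __name__ == '__main__':
    interactive()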
WaLSA_plot_k_omega.py
This module provides functions for plotting k-ω diagrams and filtered datacubes. An illustrative usage sketch follows the listing below.
WaLSA_plot_k_omega.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by David B. Jess and Samuel D. T. Grant
# -----------------------------------------------------------------------------------------------------
import numpy as np # type: ignore
from matplotlib.colors import ListedColormap # type: ignore
from WaLSAtools import WaLSA_histo_opt # type: ignore
from scipy.interpolate import griddata # type: ignore
import matplotlib.pyplot as plt # type: ignore
from matplotlib.ticker import FixedLocator, FixedFormatter # type: ignore
import matplotlib.patches as patches # type: ignore
import matplotlib.patheffects as path_effects # type: ignore
from matplotlib.colors import Normalize # type: ignore
def WaLSA_plot_k_omega(
kopower, kopower_xscale, kopower_yscale,
xtitle='Wavenumber (pixel⁻¹)', ytitle='Frequency (mHz)',
xlog=False, ylog=False, xrange=None, yrange=None,
xtick_interval=0.05, ytick_interval=200,
xtick_minor_interval=0.01, ytick_minor_interval=50,
colorbar_label='Log₁₀(Power)', cmap='viridis', cbartab=0.12,
figsize=(9, 4), smooth=True, dypix=1200, dxpix=1600, ax=None,
k1=None, k2=None, f1=None, f2=None, colorbar_location='right'
):
"""
Plots the k-omega diagram with corrected orientation and axis alignment.
Automatically clips the power array and scales based on x and y ranges.
Example usage:
WaLSA_plot_k_omega(
kopower=power,
kopower_xscale=wavenumber,
kopower_yscale=frequencies*1000.,
xlog=False, ylog=False,
xrange=(0, 0.3), figsize=(8, 4), cbartab=0.18,
xtitle='Wavenumber (pixel⁻¹)', ytitle='Frequency (mHz)',
colorbar_label='Log₁₀(Oscillation Power)',
f1=0.470*1000, f2=0.530*1000,
k1=0.047, k2=0.25
)
"""
if xrange is None or len(xrange) != 2:
xrange = [np.min(kopower_xscale), np.max(kopower_xscale)]
if yrange is None or len(yrange) != 2:
yrange = [np.min(kopower_yscale), np.max(kopower_yscale)]
# Handle xrange and yrange adjustments
if xrange is not None and len(xrange) == 2 and xrange[0] == 0:
xrange = [np.min(kopower_xscale), xrange[1]]
if yrange is not None and len(yrange) == 2 and yrange[0] == 0:
yrange = [np.min(kopower_yscale), yrange[1]]
# Clip kopower, kopower_xscale, and kopower_yscale based on xrange and yrange
if xrange:
x_min, x_max = xrange
x_indices = np.where((kopower_xscale >= x_min) & (kopower_xscale <= x_max))[0]
# Check if the lower bound is not included
if kopower_xscale[x_indices[0]] > x_min and x_indices[0] > 0:
x_indices = np.insert(x_indices, 0, x_indices[0] - 1)
# Check if the upper bound is not included
if kopower_xscale[x_indices[-1]] < x_max and x_indices[-1] + 1 < len(kopower_xscale):
x_indices = np.append(x_indices, x_indices[-1] + 1)
kopower = kopower[:, x_indices]
kopower_xscale = kopower_xscale[x_indices]
if yrange:
y_min, y_max = yrange
y_indices = np.where((kopower_yscale >= y_min) & (kopower_yscale <= y_max))[0]
# Check if the lower bound is not included
if kopower_yscale[y_indices[0]] > y_min and y_indices[0] > 0:
y_indices = np.insert(y_indices, 0, y_indices[0] - 1)
# Check if the upper bound is not included
if kopower_yscale[y_indices[-1]] < y_max and y_indices[-1] + 1 < len(kopower_yscale):
y_indices = np.append(y_indices, y_indices[-1] + 1)
kopower = kopower[y_indices, :]
kopower_yscale = kopower_yscale[y_indices]
# Pixel coordinates in data space
if xlog:
nxpix = np.logspace(np.log10(xrange[0]), np.log10(xrange[1]), dxpix)
else:
nxpix = np.linspace(xrange[0], xrange[1], dxpix)
if ylog:
nypix = np.logspace(np.log10(yrange[0]), np.log10(yrange[1]), dypix)
else:
nypix = np.linspace(yrange[0], yrange[1], dypix)
# Interpolation over data
interpolator = griddata(
points=(np.repeat(kopower_yscale, len(kopower_xscale)), np.tile(kopower_xscale, len(kopower_yscale))),
values=kopower.ravel(),
xi=(nypix[:, None], nxpix[None, :]),
method='linear'
)
newimage = np.nan_to_num(interpolator)
# Clip the interpolated image to the exact xrange and yrange
x_indices = (nxpix >= xrange[0]) & (nxpix <= xrange[1])
y_indices = (nypix >= yrange[0]) & (nypix <= yrange[1])
newimage = newimage[np.ix_(y_indices, x_indices)]
nxpix = nxpix[x_indices]
nypix = nypix[y_indices]
# Extent for plotting
extent = [nxpix[0], nxpix[-1], nypix[0], nypix[-1]]
# Set up the plot
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
# Load custom colormap (note: this IDL-style color table is used for the
# image regardless of the `cmap` keyword)
rgb_values = np.loadtxt('Color_Tables/idl_colormap_13.txt') / 255.0
idl_colormap_13 = ListedColormap(rgb_values)
# Compute minimum and maximum values of the data
vmin = np.nanmin(kopower)
vmax = np.nanmax(kopower)
img = ax.imshow(
WaLSA_histo_opt(newimage), extent=extent, origin='lower', aspect='auto', cmap=idl_colormap_13, norm=Normalize(vmin=vmin, vmax=vmax)
)
# Configure axis labels and scales
ax.set_xlabel(xtitle)
ax.set_ylabel(ytitle)
if xlog:
ax.set_xscale('log')
if ylog:
ax.set_yscale('log')
# Configure ticks
ax.xaxis.set_major_locator(plt.MultipleLocator(xtick_interval))
ax.xaxis.set_minor_locator(plt.MultipleLocator(xtick_minor_interval))
ax.yaxis.set_major_locator(plt.MultipleLocator(ytick_interval))
ax.yaxis.set_minor_locator(plt.MultipleLocator(ytick_minor_interval))
# Add secondary x and y axes for period and spatial size
ax.tick_params(axis='both', which='both', direction='out', top=True, right=True)
ax.tick_params(axis='both', which='major', length=7, width=1.1) # Major ticks
ax.tick_params(axis='both', which='minor', length=4, width=1.1) # Minor ticks
major_xticks = ax.get_xticks()
major_yticks = ax.get_yticks()
# Define a function to calculate the period (1000 / frequency)
def frequency_to_period(frequency):
return 1000. / frequency if frequency > 0 else np.inf
# Generate labels for the secondary y-axis based on the primary y-axis ticks
period_labels = [f"{frequency_to_period(tick):.1f}" if tick > 0 else "" for tick in major_yticks]
# Set custom labels for the secondary y-axis
ax_right = ax.secondary_yaxis('right')
ax_right.set_yticks(major_yticks)
ax_right.set_yticklabels(period_labels)
ax_right.set_ylabel('Period (s)', labelpad=12)
# Define a function to calculate spatial size (2π / wavenumber)
def wavenumber_to_spatial_size(wavenumber):
return 2 * np.pi / wavenumber if wavenumber > 0 else np.inf
# Generate labels for the secondary x-axis based on the primary x-axis ticks
spatial_size_labels = [f"{wavenumber_to_spatial_size(tick):.1f}" if tick > 0 else "" for tick in major_xticks]
# Set custom labels for the secondary x-axis
ax_top = ax.secondary_xaxis('top')
ax_top.set_xticks(major_xticks)
ax_top.set_xticklabels(spatial_size_labels)
ax_top.set_xlabel('Spatial Size (pixel)', labelpad=12)
if k1 is not None and k2 is not None and f1 is not None and f2 is not None:
width = k2 - k1
height = f2 - f1
rectangle = patches.Rectangle(
(k1, f1), width, height, zorder=10,
linewidth=1.5, edgecolor='white', facecolor='none', linestyle='--'
)
rectangle.set_path_effects([
path_effects.Stroke(linewidth=2.5, foreground='black'),
path_effects.Normal()
])
# Mark the filtered area
ax.add_patch(rectangle)
# Add colorbar
tick_values = [
vmin,
vmin + (vmax - vmin) * 0.33,
vmin + (vmax - vmin) * 0.67,
vmax
]
# Format the tick labels
tick_labels = [f"{v:.1f}" for v in tick_values]
if colorbar_location == 'top':
cbar = plt.colorbar(img, ax=ax, orientation='horizontal', pad=cbartab, location='top', aspect=30)
cbar.set_label(colorbar_label)
# Override the ticks and tick labels
cbar.set_ticks(tick_values) # Set custom tick locations
cbar.set_ticklabels(tick_labels) # Set custom tick labels
cbar.ax.xaxis.set_major_locator(FixedLocator(tick_values)) # Fix tick locations
cbar.ax.xaxis.set_major_formatter(FixedFormatter(tick_labels)) # Fix tick labels
# Suppress auto ticks completely
cbar.ax.xaxis.set_minor_locator(FixedLocator([])) # Ensure no minor ticks appear
else:
cbar = plt.colorbar(img, ax=ax, orientation='vertical', pad=cbartab, aspect=22)
cbar.set_label(colorbar_label)
# Override the ticks and tick labels
cbar.set_ticks(tick_values) # Set custom tick locations
cbar.set_ticklabels(tick_labels) # Set custom tick labels
cbar.ax.yaxis.set_major_locator(FixedLocator(tick_values)) # Fix tick locations
cbar.ax.yaxis.set_major_formatter(FixedFormatter(tick_labels)) # Fix tick labels
# Suppress auto ticks completely
cbar.ax.yaxis.set_minor_locator(FixedLocator([])) # Ensure no minor ticks appear
return ax
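# ------------------------------ Usage sketch ------------------------------
# An illustrative addition (not part of the original module): plotting a
# synthetic k-omega array. Shapes follow the signature above, i.e. kopower is
# (n_frequencies, n_wavenumbers) with matching 1D axis arrays, and the values
# are random placeholders, not physical results. Note that the hard-coded
# color table 'Color_Tables/idl_colormap_13.txt' must exist in the working
# directory for the call to succeed.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    wavenumber = np.linspace(0.005, 0.3, 60)             # pixel^-1
    frequencies = np.linspace(0.005, 1.0, 120)           # Hz
    demo_power = np.log10(rng.random((120, 60)) + 1.0)   # placeholder power
    WaLSA_plot_k_omega(
        kopower=demo_power,
        kopower_xscale=wavenumber,
        kopower_yscale=frequencies * 1000.,              # mHz, as in the docstring
        xrange=(0, 0.3), figsize=(8, 4), cbartab=0.18,
        colorbar_label='Log₁₀(Oscillation Power)'
    )
    plt.show()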
WaLSA_plot_wavelet_spectrum.py
This module provides functions for plotting wavelet power spectra and related visualizations. An illustrative usage sketch follows the listing below.
WaLSA_plot_wavelet_spectrum.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
def WaLSA_plot_wavelet_spectrum(t, power, periods, sig_slevel, coi, dt, normalize_power=False,
title='Wavelet Power Spectrum', ylabel='Period [s]', xlabel='Time [s]',
colorbar_label='Power (%)', ax=None, colormap='custom', removespace=False):
"""Plots the wavelet power spectrum of a given signal.
Parameters:
- t (array-like): Time values for the signal.
- power (2D array-like): Wavelet power spectrum.
- periods (array-like): Period values corresponding to the power.
- sig_slevel (2D array-like): Significance levels of the power spectrum.
- coi (array-like): Cone of influence, showing where edge effects might be present.
- dt (float): Time interval between successive time values.
- normalize_power (bool): If True, normalize power to 0-100%. Default is False.
- title (str): Title of the plot. Default is 'Wavelet Power Spectrum'.
- ylabel (str): Y-axis label. Default is 'Period [s]'.
- xlabel (str): X-axis label. Default is 'Time [s]'.
- colorbar_label (str): Label for the color bar. Default is 'Power (%)'.
- ax (matplotlib axis, optional): Axis to plot on. Default is None, which creates a new figure and axis.
- colormap (str or colormap, optional): Colormap to be used. Default is 'custom'.
- removespace (bool, optional): If True, limits the maximum y-range to the peak of the cone of influence.
"""
created_fig = ax is None  # remember whether this function creates the figure
if created_fig:
    fig, ax = plt.subplots(figsize=(12, 6))
# Custom colormap with white at the lower end
if colormap == 'custom':
cmap = LinearSegmentedColormap.from_list('custom_white', ['white', 'blue', 'green', 'yellow', 'red'], N=256)
else:
cmap = plt.get_cmap(colormap)
# Normalize power if requested
if normalize_power:
power = 100 * power / np.nanmax(power)
# Define a larger number of levels to create a continuous color bar
levels = np.linspace(np.nanmin(power), np.nanmax(power), 100)
# Plot the wavelet power spectrum
CS = ax.contourf(t, periods, power, levels=levels, cmap=cmap, extend='neither')  # 'neither' keeps the color bar ends straight
# 95% significance contour
ax.contour(t, periods, sig_slevel, levels=[1], colors='k', linewidths=[1.0])
if removespace:
max_period = np.max(coi)
else:
max_period = np.max(periods)
# Cone-of-influence
ax.plot(t, coi, '-k', lw=2)
ax.fill(np.concatenate([t, t[-1:] + dt, t[-1:] + dt, t[:1] - dt, t[:1] - dt]),
np.concatenate([coi, [1e-9], [max_period], [max_period], [1e-9]]),
'k', alpha=0.3, hatch='x')
# Log scale for periods
ax.set_ylim([np.min(periods), max_period])
ax.set_yscale('log', base=10)
ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
ax.ticklabel_format(axis='y', style='plain')
ax.invert_yaxis()
# Set axis limits and labels
ax.set_xlim([t.min(), t.max()])
ax.set_ylabel(ylabel, fontsize=14)
ax.set_xlabel(xlabel, fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_title(title, fontsize=16)
# Add a secondary y-axis for frequency in Hz
ax_freq = ax.twinx()
# Set limits for the frequency axis based on the `max_period` used for the period axis
min_frequency = 1 / max_period
max_frequency = 1 / np.min(periods)
ax_freq.set_yscale('log', base=10)
ax_freq.set_ylim([max_frequency, min_frequency]) # Adjust frequency range properly
ax_freq.yaxis.set_major_formatter(ticker.ScalarFormatter())
ax_freq.ticklabel_format(axis='y', style='plain')
ax_freq.invert_yaxis()
ax_freq.set_ylabel('Frequency (Hz)', fontsize=14)
ax_freq.tick_params(axis='both', which='major', labelsize=12)
# Add color bar on top with minimal distance
divider = make_axes_locatable(ax)
cax = divider.append_axes('top', size='5%', pad=0.01) # Use 'pad=0.01' for a small fixed distance
cbar = plt.colorbar(CS, cax=cax, orientation='horizontal')
# Move color bar label to the top of the bar
cbar.set_label(colorbar_label, fontsize=12, labelpad=5)
cbar.ax.tick_params(labelsize=10, direction='out', top=True, labeltop=True, bottom=False, labelbottom=False)
cbar.ax.xaxis.set_label_position('top')
# Adjust layout and render only if a new figure was created above
if created_fig:
    plt.tight_layout()
    plt.show()
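# ------------------------------ Usage sketch ------------------------------
# An illustrative addition (not part of the original module): plotting a
# synthetic wavelet power spectrum. The arrays are random placeholders shaped
# to match the documented parameters; in real use they would come from a
# wavelet analysis (e.g. method='wavelet' in WaLSAtools).
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    t = np.linspace(0.0, 1000.0, 500)                        # time [s]
    dt = t[1] - t[0]
    periods = np.logspace(np.log10(2 * dt), np.log10(500.0), 60)
    demo_power = rng.random((periods.size, t.size))          # (n_periods, n_times)
    sig_slevel = demo_power / np.percentile(demo_power, 95)  # dummy ratio; >1 marks "significant"
    coi = np.minimum(t, t[-1] - t) + 2 * dt                  # toy cone of influence [s]
    WaLSA_plot_wavelet_spectrum(t, demo_power, periods, sig_slevel, coi, dt,
                                normalize_power=True)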