Under the Hood

We strongly recommend following the documented procedure when using WaLSAtools, a user-friendly tool that provides all the information you need for your analysis. However, for experts who wish to familiarize themselves with the techniques under the hood, inspect the implementations, or modify and improve them, some of the main codes are also provided below. Please note that all codes and their dependencies are available in the GitHub repository.

Analysis Modules

WaLSAtools is built upon a collection of analysis modules, each designed for a specific aspect of wave analysis. These modules are combined and accessed through the main WaLSAtools interface, providing a streamlined and user-friendly experience.
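For orientation, the sketch below shows how a single 1D analysis is typically dispatched through the main interface. The top-level call signature and return order are assumptions here, mirrored from the WaLSA_speclizer wrapper reproduced below; consult the interactive WaLSAtools help for the authoritative form.

# Minimal sketch; the WaLSAtools call signature and return order are assumptions
# mirrored from the WaLSA_speclizer wrapper listed below.
import numpy as np
from WaLSAtools import WaLSAtools

time = np.arange(400) * 0.5                       # uniform 0.5 s cadence
signal = np.sin(2 * np.pi * 0.1 * time)           # 0.1 Hz test oscillation

power, frequencies, significance, *rest = WaLSAtools(
    signal=signal, time=time, method='fft', siglevel=0.95)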

Here's a brief overview of the core analysis modules:

WaLSA_speclizer.py

This module provides a collection of spectral analysis techniques, including FFT, Lomb-Scargle, Wavelet, Welch, and EMD/HHT.
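A minimal usage sketch of this module, calling the wrapper directly with the signatures listed below; the import path and synthetic signal are illustrative only.

# Sketch only: adjust the import path to where the module lives in your installation.
import numpy as np
from WaLSA_speclizer import WaLSA_speclizer

time = np.arange(400) * 0.5                       # uniform 0.5 s cadence
signal = np.sin(2 * np.pi * 0.1 * time) + 0.2 * np.random.randn(400)

# FFT power spectrum with permutation-based confidence level (see getpowerFFT below)
power, freq, signif, _ = WaLSA_speclizer(signal=signal, time=time, method='fft',
                                         siglevel=0.95, nperm=200)

# Wavelet power, periods, significance levels, cone of influence, and global spectra
wpow, periods, slevel, coi, gws, gconf, rgws = WaLSA_speclizer(
    signal=signal, time=time, method='wavelet', RGWS=True)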

WaLSA_speclizer.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from astropy.timeseries import LombScargle # type: ignore
from tqdm import tqdm # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_confidence import WaLSA_confidence # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore
from scipy.signal import welch # type: ignore

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_speclizer(signal=None, time=None, method=None, 
                    dominantfreq=False, averagedpower=False, **kwargs):
    """
    Main function to prepare data and call the specific spectral analysis method.

    Parameters:
        signal (array): The input signal (1D or 3D).
        time (array): The time array of the signal.
        method (str): The method to use for spectral analysis ('fft', 'lombscargle', etc.)
        **kwargs: Additional parameters for data preparation and analysis methods.

    Returns:
        Power spectrum, frequencies, and significance (if applicable).
    """
    if not dominantfreq and not averagedpower:
        if method == 'fft':
            return getpowerFFT(signal, time=time, **kwargs)
        elif method == 'lombscargle':
            return getpowerLS(signal, time=time, **kwargs)
        elif method == 'wavelet':
            return getpowerWavelet(signal, time=time, **kwargs)
        elif method == 'welch':
            return welch_psd(signal, time=time, **kwargs)
        elif method == 'emd':
            return getEMD_HHT(signal, time=time, **kwargs)
        else:
            raise ValueError(f"Unknown method: {method}")
    else:
        return get_dominant_averaged(signal, time=time, method=method, **kwargs)

# -------------------------------------- FFT ----------------------------------------
def getpowerFFT(signal, time, **kwargs):
    """
    Perform FFT analysis on the given signal.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        nosignificance (bool): If True, skip significance calculation. Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        amplitude (bool): If True, return the amplitudes of the Fourier transform. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        Power spectrum, frequencies, significance, amplitudes
    """
    # Define default values for the optional parameters
    defaults = {
        'siglevel': 0.95,
        'significance': None,
        'nperm': 1000,
        'nosignificance': False,
        'apod': 0.1,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'nodetrendapod': False,
        'amplitude': False,
        'silent': False
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    params['siglevel'] = 1 - params['siglevel'] # different convention

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Perform detrending and apodization
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(
            signal, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=params['silent']
        )
    else:
        apocube = signal

    nt = len(apocube)  # Length of the time series (1D)

    # Positive FFT frequencies (excluding the zero frequency)
    frequencies = 1. / (cadence * 2) * np.arange(nt // 2 + 1) / (nt // 2)
    frequencies = frequencies[1:]
    signal = apocube
    spec = np.fft.fft(signal)
    power = 2 * np.abs(spec[1:len(frequencies) + 1]) ** 2
    powermap = power / frequencies[0]  # normalized by the lowest frequency

    if params['amplitude']:
        amplitudes = spec[1:len(frequencies) + 1]  # complex Fourier amplitudes
    else:
        amplitudes = None

    # Calculate significance if requested
    if not params['nosignificance']:
        ps_perm = np.zeros((len(frequencies), params['nperm']))
        for ip in range(params['nperm']):
            perm_signal = np.random.permutation(signal)  # Permute the original signal
            perm_apocube = WaLSA_detrend_apod(
                perm_signal,
                apod=params['apod'],
                meandetrend=params['meandetrend'],
                pxdetrend=params['pxdetrend'],
                polyfit=params['polyfit'],
                meantemporal=params['meantemporal'],
                recon=params['recon'],
                cadence=cadence,
                resample_original=params['resample_original'],
                silent=True
            )
            perm_spec = np.fft.fft(perm_apocube)  # FFT of the detrended/apodized permutation
            perm_power = 2 * np.abs(perm_spec[1:len(frequencies) + 1]) ** 2
            ps_perm[:, ip] = perm_power

        significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=len(frequencies))
        significance = significance / frequencies[0]
    else:
        significance = None

    if not params['silent']:
        print("FFT processed.")

    return powermap, frequencies, significance, amplitudes

# -------------------------------------- Lomb-Scargle ----------------------------------------
def getpowerLS(signal, time, **kwargs):
    """
    Perform Lomb-Scargle analysis on the given signal.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        dy (array): Errors or observational uncertainties associated with the time series.
        fit_mean (bool): If True, include a constant offset as part of the model at each frequency. This improves accuracy, especially for incomplete phase coverage.
        center_data (bool): If True, pre-center the data by subtracting the weighted mean of the input data. This is especially important if fit_mean=False.
        nterms (int): Number of terms to use in the Fourier fit. Default: 1.
        normalization (str): The normalization method for the periodogram. Options: 'standard', 'model', 'log', 'psd'. Default: 'standard'.
        nosignificance (bool): If True, skip significance calculation. Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        Power spectrum, frequencies, and significance (if applicable).
    """
    # Define default values for the optional parameters
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,
        'nosignificance': False,
        'apod': 0.1,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'silent': False,
        'nodetrendapod': False,
        'dy': None, # Error or sequence of observational errors associated with times t
        'fit_mean': False, # If True include a constant offset as part of the model at each frequency. This can lead to more accurate results, especially in the case of incomplete phase coverage
        'center_data': True, # If True pre-center the data by subtracting the weighted mean of the input data. This is especially important if fit_mean = False
        'nterms': 1, # Number of terms to use in the Fourier fit
        'normalization': 'standard' # The normalization to use for the periodogram. Options are 'standard', 'model', 'log', 'psd'
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    params['siglevel'] = 1 - params['siglevel'] # different convention

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Perform detrending and apodization
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(
            signal, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=params['silent']
        )
    else:
        apocube = signal

    frequencies, power = LombScargle(time, apocube, dy=params['dy'], fit_mean=params['fit_mean'], 
                                     center_data=params['center_data'], nterms=params['nterms'], 
                                     normalization=params['normalization']).autopower()

    # Calculate significance if needed
    if not params['nosignificance']:
        ps_perm = np.zeros((len(frequencies), params['nperm']))
        for ip in range(params['nperm']):
            perm_signal = np.random.permutation(signal)  # Permuting the original signal
            apocubep = WaLSA_detrend_apod(
                perm_signal, 
                apod=params['apod'], 
                meandetrend=params['meandetrend'], 
                pxdetrend=params['pxdetrend'], 
                polyfit=params['polyfit'], 
                meantemporal=params['meantemporal'], 
                recon=params['recon'], 
                cadence=cadence, 
                resample_original=params['resample_original'], 
                silent=True
            )
            frequencies, perm_power = LombScargle(time, apocubep, dy=params['dy'], fit_mean=params['fit_mean'],
                                                  center_data=params['center_data'], nterms=params['nterms'],
                                                  normalization=params['normalization']).autopower()
            ps_perm[:, ip] = perm_power
        significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=len(frequencies))
    else:
        significance = None

    if not params['silent']:
        print("Lomb-Scargle processed.")

    return power, frequencies, significance

# -------------------------------------- Wavelet ----------------------------------------
def getpowerWavelet(signal, time, **kwargs):
    """
    Perform wavelet analysis using the pycwt package.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
        RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
        dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
        s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
        J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
        lag1 (float): Lag-1 autocorrelation. Default: 0.0.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        power: The wavelet power spectrum.
        periods: Corresponding periods.
        sig_slevel: The significance levels.
        coi: The cone of influence.
        Optionally, if global_power=True:
        global_power: Global wavelet power spectrum.
        global_conf: Confidence levels for the global wavelet spectrum.
        Optionally, if RGWS=True:
        rgws_periods: Periods for the refined global wavelet spectrum.
        rgws_power: Refined global wavelet power spectrum.
    """
    # Define default values for the optional parameters similar to IDL
    defaults = {
        'siglevel': 0.95,
        'mother': 'morlet',  # Morlet wavelet as the mother function
        'dj': 0.025,  # Scale spacing
        's0': -1,  # Initial scale
        'J': -1,  # Number of scales
        'lag1': 0.0,  # Lag-1 autocorrelation
        'apod': 0.1,  # Tukey window apodization
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'GWS': False,  # If True, calculate global wavelet spectrum
        'RGWS': False,  # If True, calculate refined global wavelet spectrum (excluding COI)
        'nperm': 1000,   # Number of permutations for significance calculation
        'nodetrendapod': False
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Perform detrending and apodization
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(
            signal, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=params['silent']
        )
    else:
        apocube = signal

    n = len(apocube)
    dt = cadence

    # Standardize the signal before the wavelet transform
    std_signal = apocube.std()
    norm_signal = apocube / std_signal

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * dt

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])

    # Perform wavelet transform
    W, scales, frequencies, coi, _, _ = cwt(
        norm_signal,
        dt,
        dj=params['dj'],
        s0=params['s0'],
        J=params['J'],
        wavelet=params['mother']
    )

    power = np.abs(W) ** 2  # Wavelet power spectrum
    periods = 1 / frequencies  # Convert frequencies to periods

    # Calculate the significance levels
    signif, _ = significance(
        1.0,  # Normalized variance
        dt,
        scales,
        0,  # Ignored for now, used for background noise (e.g., red noise)
        params['lag1'],
        significance_level=params['siglevel'],
        wavelet=params['mother']
    )

    # Ratio of power to the significance threshold: values > 1 are significant
    sig_matrix = np.ones([len(scales), n]) * signif[:, None]
    sig_slevel = power / sig_matrix

    # Calculate global power if requested
    if params['GWS']:
        global_power = np.mean(power, axis=1)  # Average along the time axis

        dof = n - scales   # the -scale corrects for padding at edges
        global_conf, _ = significance(
            norm_signal,                   # time series data
            dt,              # time step between values
            scales,          # scale vector
            1,               # sigtest = 1 for "time-average" test
            0.0,
            significance_level=params['siglevel'],
            dof=dof,
            wavelet=params['mother']
        )
    else:
        global_power = None
        global_conf = None

    # Calculate refined global wavelet spectrum (RGWS) if requested
    if params['RGWS']:
        isig = sig_slevel < 1.0
        power[isig] = np.nan
        ipower = np.full_like(power, np.nan)
        for i in range(len(coi)):
            pcol = power[:, i]
            valid_idx = periods < coi[i]
            ipower[valid_idx, i] = pcol[valid_idx]
        rgws_power = np.nansum(ipower, axis=1)
    else:
        rgws_power = None

    if not params['silent']:
        print("Wavelet (" + params['mother'] + ") processed.")

    return power, periods, sig_slevel, coi, global_power, global_conf, rgws_power

# -------------------------------------- Welch ----------------------------------------
def welch_psd(signal, time, **kwargs):
    """
    Calculate Welch Power Spectral Density (PSD) and significance levels.

    Parameters:
        signal (array): The 1D time series signal.
        time (array): The time array corresponding to the signal.
        nperseg (int, optional): Length of each segment for analysis. Default: 256.
        noverlap (int, optional): Number of points to overlap between segments. Default: 128.
        window (str, optional): Type of window function used in the Welch method. Default: 'hann'.
        siglevel (float, optional): Significance level for confidence intervals. Default: 0.95.
        nperm (int, optional): Number of permutations for significance testing. Default: 1000.
        silent (bool, optional): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        frequencies: Frequencies at which the PSD is estimated.
        psd: Power Spectral Density values.
        significance: Significance levels for the PSD.
    """
    # Define default values for the optional parameters similar to FFT
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,  # Number of permutations for significance calculation
        'window': 'hann',  # Window type for Welch method
        'nperseg': 256,  # Length of each segment
        'silent': False,
        'noverlap': 128  # Number of points to overlap between segments
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Calculate Welch PSD
    frequencies, psd = welch(signal, fs=1.0 / cadence, window=params['window'], nperseg=params['nperseg'], noverlap=params['noverlap'])

    # Calculate significance levels using permutation
    ps_perm = np.zeros((len(frequencies), params['nperm']))
    for ip in range(params['nperm']):
        perm_signal = np.random.permutation(signal)  # Permuting the original signal
        _, perm_psd = welch(perm_signal, fs=1.0 / cadence, window=params['window'], nperseg=params['nperseg'], noverlap=params['noverlap'])
        ps_perm[:, ip] = perm_psd

    # Calculate significance level for Welch PSD
    significance = np.percentile(ps_perm, params['siglevel'] * 100, axis=1)

    if not params['silent']:
        print("Welch processed.")

    return psd, frequencies, significance

# -------------------------------------- EMD / EEMD ----------------------------------------
# (numpy and scipy.signal.welch are already imported at the top of this module)
from .PyEMD import EMD, EEMD # type: ignore
from scipy.stats import norm # type: ignore
from scipy.signal import hilbert, find_peaks # type: ignore
from scipy.fft import fft, fftfreq # type: ignore

# Function to apply EMD and return Intrinsic Mode Functions (IMFs)
def apply_emd(signal, time):
    emd = EMD()
    emd.FIXE_H = 5  # fixed number of sifting iterations once the IMF condition holds
    imfs = emd.emd(signal, time)
    return imfs

# Function to apply EEMD and return Intrinsic Mode Functions (IMFs)
def apply_eemd(signal, time, noise_std=0.2, num_realizations=1000):
    eemd = EEMD()
    eemd.FIXE_H = 5
    eemd.noise_seed(12345)
    eemd.noise_width = noise_std
    eemd.ensemble_size = num_realizations
    imfs = eemd.eemd(signal, time)
    return imfs

# Function to generate white noise IMFs for significance testing
def generate_white_noise_imfs(signal_length, time, num_realizations, use_eemd=None, noise_std=0.2):
    white_noise_imfs = []
    if use_eemd:
        eemd = EEMD()
        eemd.FIXE_H = 5
        eemd.noise_seed(12345)
        eemd.noise_width = noise_std
        eemd.ensemble_size = num_realizations
        for _ in range(num_realizations):
            white_noise = np.random.normal(size=signal_length)
            imfs = eemd.eemd(white_noise, time)
            white_noise_imfs.append(imfs)
    else:
        emd = EMD()
        emd.FIXE_H = 5
        for _ in range(num_realizations):
            white_noise = np.random.normal(size=signal_length)
            imfs = emd.emd(white_noise, time)
            white_noise_imfs.append(imfs)
    return white_noise_imfs

# Function to calculate the energy of each IMF
def calculate_imf_energy(imfs):
    return [np.sum(imf**2) for imf in imfs]

# Function to determine the significance of the IMFs
def test_imf_significance(imfs, white_noise_imfs):
    imf_energies = calculate_imf_energy(imfs)
    num_imfs = len(imfs)

    white_noise_energies = [calculate_imf_energy(imf_set) for imf_set in white_noise_imfs]

    significance_levels = []
    for i in range(num_imfs):
        # Collect the i-th IMF energy from each white noise realization
        white_noise_energy_dist = [energies[i] for energies in white_noise_energies if i < len(energies)]
        mean_energy = np.mean(white_noise_energy_dist)
        std_energy = np.std(white_noise_energy_dist)

        z_score = (imf_energies[i] - mean_energy) / std_energy
        significance_level = 1 - norm.cdf(z_score)
        significance_levels.append(significance_level)

    return significance_levels

# Function to compute instantaneous frequency of IMFs using Hilbert transform
def compute_instantaneous_frequency(imfs, time):
    instantaneous_frequencies = []
    for imf in imfs:
        analytic_signal = hilbert(imf)
        instantaneous_phase = np.unwrap(np.angle(analytic_signal))
        instantaneous_frequency = np.diff(instantaneous_phase) / (2.0 * np.pi * np.diff(time))
        instantaneous_frequency = np.abs(instantaneous_frequency)  # Ensure frequencies are positive
        instantaneous_frequencies.append(instantaneous_frequency)
    return instantaneous_frequencies

# Function to compute HHT power spectrum
def compute_hht_power_spectrum(imfs, instantaneous_frequencies, freq_bins):
    power_spectrum = np.zeros_like(freq_bins)

    for i in range(len(imfs)):
        amplitude = np.abs(hilbert(imfs[i]))
        power = amplitude**2
        for j in range(len(instantaneous_frequencies[i])):
            freq_idx = np.searchsorted(freq_bins, instantaneous_frequencies[i][j])
            if freq_idx < len(freq_bins):
                power_spectrum[freq_idx] += power[j]

    return power_spectrum

# Function to compute significance level for HHT power spectrum
def compute_significance_level(white_noise_imfs, freq_bins, time):
    all_power_spectra = []
    for imfs in white_noise_imfs:
        instantaneous_frequencies = compute_instantaneous_frequency(imfs, time)
        power_spectrum = compute_hht_power_spectrum(imfs, instantaneous_frequencies, freq_bins)
        all_power_spectra.append(power_spectrum)

    # Compute the 95th percentile power spectrum as the significance level
    significance_level = np.percentile(all_power_spectra, 95, axis=0)

    return significance_level

## Custom rounding function
def custom_round(freq):
    if freq < 1:
        return round(freq, 1)
    else:
        return round(freq)

def smooth_power_spectrum(power_spectrum, window_size=5):
    return np.convolve(power_spectrum, np.ones(window_size)/window_size, mode='same')

# Function to identify and print significant peaks
def identify_significant_peaks(freq_bins, power_spectrum, significance_level):
    peaks, _ = find_peaks(power_spectrum, height=significance_level)
    significant_frequencies = freq_bins[peaks]
    rounded_frequencies = [custom_round(freq) for freq in significant_frequencies]
    print("Significant Frequencies (Hz):", rounded_frequencies)
    return significant_frequencies

# Function to generate randomized signals
def generate_randomized_signals(signal, num_realizations):
    randomized_signals = []
    for _ in range(num_realizations):
        randomized_signal = np.random.permutation(signal)
        randomized_signals.append(randomized_signal)
    return randomized_signals

# Function to calculate the FFT power spectrum for each IMF
def calculate_fft_psd_spectra(imfs, time):
    psd_spectra = []
    for imf in imfs:
        N = len(imf)
        T = time[1] - time[0]  # Assuming uniform sampling
        yf = fft(imf)
        xf = fftfreq(N, T)[:N // 2]
        psd = (2.0 / N) * (np.abs(yf[:N // 2]) ** 2) / (N * T)
        psd_spectra.append((xf, psd))
    return psd_spectra

# Function to calculate the FFT power spectrum for randomized signals
def calculate_fft_psd_spectra_randomized(imfs, time, num_realizations):
    all_psd_spectra = []
    for imf in imfs:
        randomized_signals = generate_randomized_signals(imf, num_realizations)
        for signal in randomized_signals:
            N = len(signal)
            T = time[1] - time[0]  # Assuming uniform sampling
            yf = fft(signal)
            psd = (2.0 / N) * (np.abs(yf[:N // 2]) ** 2) / (N * T)
            all_psd_spectra.append(psd)
    return np.array(all_psd_spectra)

# Function to calculate the 95th percentile confidence level
def calculate_confidence_level(imfs, time, num_realizations):
    confidence_levels = []
    for imf in imfs:
        all_psd_spectra = calculate_fft_psd_spectra_randomized([imf], time, num_realizations)
        confidence_level = np.percentile(all_psd_spectra, 95, axis=0)
        confidence_levels.append(confidence_level)
    return confidence_levels

# Function to calculate the Welch PSD for each IMF
def calculate_welch_psd_spectra(imfs, fs):
    psd_spectra = []
    for imf in imfs:
        f, psd = welch(imf, fs=fs)
        psd_spectra.append((f, psd))
    return psd_spectra

# Function to calculate the Welch PSD for randomized signals
def calculate_welch_psd_spectra_randomized(imfs, fs, num_realizations):
    all_psd_spectra = []
    for imf in imfs:
        randomized_signals = generate_randomized_signals(imf, num_realizations)
        for signal in randomized_signals:
            f, psd = welch(signal, fs=fs)
            all_psd_spectra.append(psd)
    return np.array(all_psd_spectra)

# Function to calculate the 95th percentile confidence level
def calculate_confidence_level_welch(imfs, fs, num_realizations):
    confidence_levels = []
    for imf in imfs:
        all_psd_spectra = calculate_welch_psd_spectra_randomized([imf], fs, num_realizations)
        confidence_level = np.percentile(all_psd_spectra, 95, axis=0)
        confidence_levels.append(confidence_level)
    return confidence_levels

# ************** Main EMD/EEMD routine **************
def getEMD_HHT(signal, time, **kwargs):
    """
    Calculate EMD/EEMD and HHT

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array of the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        EEMD (bool): If True, use Ensemble Empirical Mode Decomposition (EEMD) instead of Empirical Mode Decomposition (EMD). Default: False.
        Welch_psd (bool): If True, calculate Welch PSD spectra instead of FFT PSD spectra (for the psd_spectra and psd_confidence_levels). Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
        resample_original (bool): If True, and if recon is set to True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        HHT_power_spectrum: Marginal (time-integrated) Hilbert-Huang power spectrum.
        HHT_significance_level: Significance levels for the HHT power spectrum.
        HHT_freq_bins: Frequency bins of the HHT power spectrum.
        psd_spectra: FFT (or Welch) PSD spectra of the IMFs.
        psd_confidence_levels: 95th-percentile confidence levels for the PSD spectra.
        imfs: The Intrinsic Mode Functions (IMFs).
        IMF_significance_levels: Significance levels of the IMFs.
        instantaneous_frequencies: Instantaneous frequencies of the IMFs.
    """
    # Define default values for the optional parameters similar to FFT
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,  # Number of permutations for significance calculation
        'apod': 0.1,  # Tukey window apodization
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'EEMD': False,  # If True, calculate Ensemble Empirical Mode Decomposition (EEMD)
        'nodetrendapod': False,
        'Welch_psd': False,  # If True, calculate Welch PSD spectra
        'significant_imfs': False  # If True, return only significant IMFs (and for associated calculations)
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Perform detrending and apodization
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(
            signal,
            apod=params['apod'],
            meandetrend=params['meandetrend'],
            pxdetrend=params['pxdetrend'],
            polyfit=params['polyfit'],
            meantemporal=params['meantemporal'],
            recon=params['recon'],
            cadence=cadence,
            resample_original=params['resample_original'],
            silent=params['silent']
        )
    else:
        apocube = signal
    n = len(apocube)
    dt = cadence
    fs = 1 / dt  # Sampling frequency

    # Decompose the detrended/apodized signal into IMFs
    if params['EEMD']:
        imfs = apply_eemd(apocube, time)
        use_eemd = True
    else:
        imfs = apply_emd(apocube, time)
        use_eemd = False

    white_noise_imfs = generate_white_noise_imfs(len(signal), time, params['nperm'], use_eemd=use_eemd)

    IMF_significance_levels = test_imf_significance(imfs, white_noise_imfs)

    # Determine significant IMFs (p-value below 1 - siglevel, e.g. p < 0.05 for siglevel = 0.95)
    significant_imfs_val = [imf for imf, sig_level in zip(imfs, IMF_significance_levels)
                            if sig_level < 1 - params['siglevel']]
    if params['significant_imfs']:
        imfs = significant_imfs_val

    # Compute instantaneous frequencies of IMFs
    instantaneous_frequencies = compute_instantaneous_frequency(imfs, time)

    if params['Welch_psd']:
        # Calculate Welch PSD spectra of the (significant) IMFs
        psd_spectra = calculate_welch_psd_spectra(imfs, fs)
        # Calculate 95th percentile confidence levels for the Welch PSD spectra
        psd_confidence_levels = calculate_confidence_level_welch(imfs, fs, params['nperm'])
    else:
        # Calculate FFT PSD spectra of the (significant) IMFs
        psd_spectra = calculate_fft_psd_spectra(imfs, time)
        # Calculate 95th percentile confidence levels for the FFT PSD spectra
        psd_confidence_levels = calculate_confidence_level(imfs, time, params['nperm'])

    # Define frequency bins for HHT power spectrum
    max_freq = max([freq.max() for freq in instantaneous_frequencies])
    HHT_freq_bins = np.linspace(0, max_freq, 100)

    # Compute HHT power spectrum of (significant) IMFs
    HHT_power_spectrum = compute_hht_power_spectrum(imfs, instantaneous_frequencies, HHT_freq_bins)
    # Compute significance level for HHT power spectrum
    HHT_significance_level = compute_significance_level(white_noise_imfs, HHT_freq_bins, time)    

    if not params['silent']:
        if params['EEMD']:
            print("EEMD processed.")
        else:
            print("EMD processed.")

    return HHT_power_spectrum, HHT_significance_level, HHT_freq_bins, psd_spectra, psd_confidence_levels, imfs, IMF_significance_levels, instantaneous_frequencies


# ----------------------- Dominant Frequency & Averaged Power -------------------------------
def get_dominant_averaged(cube, time, **kwargs):
    """
    Analyze a 3D data cube to compute the dominant frequency and averaged power.

    Parameters:
        cube (3D array): Input data cube, expected in either 'txy' or 'xyt' format.
        time (array): Time array of the data cube.
        method (str): Analysis method ('fft' or 'wavelet').
        format (str): Format of the data cube ('txy' or 'xyt'). Default is 'txy'.
        **kwargs: Additional parameters specific to the analysis method.

    Returns:
        dominantfreq (float): Dominant frequency of the data cube.
        averagedpower (float): Averaged power of the data cube.
    """
    defaults = {
        'method': 'fft',
        'format': 'txy',
        'silent': False,
        'GWS': False,  # If True, calculate global wavelet spectrum
        'RGWS': True,  # If True, calculate refined global wavelet spectrum
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    # Check and adjust format if necessary
    if params['format'] == 'xyt':
        cube = np.transpose(cube, (2, 0, 1))  # Convert 'xyt' to 'txy'
    elif params['format'] != 'txy':
        raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")

    # Initialize arrays for storing results across spatial coordinates
    nt, nx, ny = cube.shape
    dominantfreq = np.zeros((nx, ny))

    if params['method'] == 'fft':
        method_name = 'FFT'
    elif params['method'] == 'lombscargle':
        method_name = 'Lomb-Scargle'
    elif params['method'] == 'wavelet':
        method_name = 'Wavelet'
    elif params['method'] == 'welch':
        method_name = 'Welch'
    else:
        raise ValueError("Unsupported method. Choose 'fft', 'lombscargle', 'wavelet', or 'welch'.")

    if not params['silent']:
        print(f"Processing {method_name} for a 3D cube with format '{params['format']}' and shape {cube.shape}.")
        print(f"Calculating dominant frequencies and/or averaged power spectrum ({method_name}) ....")

    # Iterate over spatial coordinates and apply the chosen analysis method
    if params['method'] == 'fft':
        for x in tqdm(range(nx), desc="Processing x", leave=True):
            for y in range(ny):
                signal = cube[:, x, y]
                fft_power, fft_freqs, _, _ = getpowerFFT(signal, time, silent=True, nosignificance=True, **kwargs)
                if x == 0 and y == 0:
                    powermap = np.zeros((nx, ny, len(fft_freqs)))
                powermap[x, y, :] = fft_power
                # Determine the dominant frequency for this pixel
                dominantfreq[x, y] = fft_freqs[np.argmax(fft_power)]

        # Calculate the averaged power over all pixels
        averagedpower = np.mean(powermap, axis=(0, 1))
        frequencies = fft_freqs
        print("\nAnalysis completed.")

    elif params['method'] == 'lombscargle':
        for x in tqdm(range(nx), desc="Processing x", leave=True):
            for y in range(ny):
                signal = cube[:, x, y]
                ls_power, ls_freqs, _ = getpowerLS(signal, time, silent=True, **kwargs)
                if x == 0 and y == 0:
                    powermap = np.zeros((nx, ny, len(ls_freqs)))
                powermap[x, y, :] = ls_power
                # Determine the dominant frequency for this pixel
                dominantfreq[x, y] = ls_freqs[np.argmax(ls_power)]

        # Calculate the averaged power over all pixels
        averagedpower = np.mean(powermap, axis=(0, 1))
        frequencies = ls_freqs
        print("\nAnalysis completed.")

    elif params['method'] == 'wavelet':
        for x in tqdm(range(nx), desc="Processing x", leave=True):
            for y in range(ny):
                signal = cube[:, x, y]
                if params['GWS']:
                    # Forward the flag explicitly; getpowerWavelet defaults both GWS and RGWS to False
                    _, wavelet_periods, _, _, wavelet_power, _, _ = getpowerWavelet(
                        signal, time, silent=True, **{**kwargs, 'GWS': True})
                elif params['RGWS']:
                    _, wavelet_periods, _, _, _, _, wavelet_power = getpowerWavelet(
                        signal, time, silent=True, **{**kwargs, 'RGWS': True})
                if x == 0 and y == 0:
                    powermap = np.zeros((nx, ny, len(wavelet_periods)))
                wavelet_freq = 1. / wavelet_periods
                powermap[x, y, :] = wavelet_power
                # Determine the dominant frequency for this pixel
                dominantfreq[x, y] = wavelet_freq[np.argmax(wavelet_power)]

        # Calculate the averaged power over all pixels
        averagedpower = np.mean(powermap, axis=(0, 1))
        frequencies = wavelet_freq
        print("\nAnalysis completed.")

    elif params['method'] == 'welch':
        for x in tqdm(range(nx), desc="Processing x", leave=True):
            for y in range(ny):
                signal = cube[:, x, y]
                welch_power, welch_freqs, _ = welch_psd(signal, time, silent=True, **kwargs)
                if x == 0 and y == 0:
                    powermap = np.zeros((nx, ny, len(welch_freqs)))
                powermap[x, y, :] = welch_power
                # Determine the dominant frequency for this pixel
                dominantfreq[x, y] = welch_freqs[np.argmax(welch_power)]

        # Calculate the averaged power over all pixels
        averagedpower = np.mean(powermap, axis=(0, 1))
        frequencies = welch_freqs
        print("\nAnalysis completed.")

    else:
        raise ValueError("Unsupported method. Choose 'fft', 'lombscargle', 'wavelet', or 'welch'.")

    return dominantfreq, averagedpower, frequencies, powermap
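
The dominant-frequency and averaged-power path above is reached through the same wrapper when dominantfreq and/or averagedpower are set. A short sketch, assuming a cube in 'txy' order (import path illustrative, as before):

# Sketch: dominant frequency and averaged power of a (nt, nx, ny) cube via FFT
import numpy as np
from WaLSA_speclizer import WaLSA_speclizer  # illustrative import path

time = np.arange(200) * 0.5                       # 0.5 s cadence
cube = np.random.randn(200, 10, 10)               # placeholder data in 'txy' order

dominantfreq, averagedpower, frequencies, powermap = WaLSA_speclizer(
    signal=cube, time=time, method='fft',
    dominantfreq=True, averagedpower=True, format='txy')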

WaLSA_wavelet.py

This module implements the wavelet transform and related functionality, including the cwt and significance routines used by WaLSA_speclizer above. Its full listing is omitted here; see the GitHub repository.
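
A brief sketch of a direct call, mirroring the usage inside getpowerWavelet above; the import path is an assumption.

# Sketch of a direct cwt/significance call, as used by getpowerWavelet above
import numpy as np
from WaLSA_wavelet import cwt, significance  # illustrative import path

dt = 0.5                                          # cadence in seconds
t = np.arange(400) * dt
sig = np.sin(2 * np.pi * 0.1 * t)
sig = (sig - sig.mean()) / sig.std()              # standardize, as in getpowerWavelet

dj, s0 = 0.025, 2 * dt
J = int((np.log(len(sig) * dt / s0) / np.log(2)) / dj)
W, scales, freqs, coi, _, _ = cwt(sig, dt, dj=dj, s0=s0, J=J, wavelet='morlet')
power = np.abs(W) ** 2                            # wavelet power spectrum
periods = 1.0 / freqs

signif, _ = significance(1.0, dt, scales, 0, 0.0,
                         significance_level=0.95, wavelet='morlet')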

WaLSA_k_omega.py

This module provides functions for performing k-ω analysis and filtering in spatio-temporal datasets.
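
A short usage sketch, following the parameter and output descriptions in the docstring below; the cadence, pixel size, and filtering band are illustrative.

# Sketch based on the WaLSA_k_omega docstring below; any outputs beyond the
# first four are the optional processing maps (None unless processing_maps=True).
import numpy as np
from WaLSA_k_omega import WaLSA_k_omega           # illustrative import path

cube = np.random.randn(100, 64, 64)               # placeholder datacube, 'txy' order
time = np.arange(100) * 2.0                       # 2 s cadence

power, frequencies, wavenumber, filtered_cube, *proc_maps = WaLSA_k_omega(
    signal=cube, time=time, pixelsize=0.1,
    filtering=True, f1=0.005, f2=0.05)            # band-pass roughly 5-50 mHz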

WaLSA_k_omega.py
# --------------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# --------------------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by Rob Rutten, David B. Jess, and Samuel D. T. Grant
# --------------------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from scipy.signal import convolve # type: ignore
# --------------------------------------------------------------------------------------------

def gaussian_function(sigma, width=None):
    """
    Create a Gaussian kernel that closely matches the IDL implementation.

    Parameters:
        sigma (float or list of floats): Standard deviation(s) of the Gaussian.
        width (int or list of ints, optional): Width of the Gaussian kernel.

    Returns:
        np.ndarray: The Gaussian kernel.
    """

    sigma = np.atleast_1d(np.array(sigma, dtype=np.float64))  # Ensure sigma is at least 1D
    if np.any(sigma <= 0):
        raise ValueError("Sigma must be greater than 0.")

    width = np.atleast_1d(np.array(width))  # Ensure width is always at least 1D
    width = np.maximum(width.astype(np.int64), 1)  # Convert to integers and clip to >= 1

    nSigma = np.size(sigma)
    if nSigma > 8:
        raise ValueError('Sigma can have no more than 8 elements')
    nWidth = np.size(width)
    if nWidth > nSigma:
        raise ValueError('Incorrect width specification')
    if (nWidth == 1) and (nSigma > 1):
        width = np.full(nSigma, width[0])

    kernel = np.zeros(tuple(width[:nSigma].astype(int)), dtype=np.float64)  # Match kernel size to nSigma

    temp = np.zeros(8, dtype=np.int64)
    temp[:nSigma] = width[:nSigma]
    width = temp

    if nSigma == 2:
        a = np.linspace(0, width[0] - 1, width[0]) - width[0] // 2 + (0 if width[0] % 2 else 0.5)
        b = np.linspace(0, width[1] - 1, width[1]) - width[1] // 2 + (0 if width[1] % 2 else 0.5)

        # Nested loop for kernel computation (to be extended for larger nSigma)
        for bb in range(width[1]):
            b1 = (b[bb] ** 2) / (2 * sigma[1] ** 2)
            for aa in range(width[0]):
                a1 = (a[aa] ** 2) / (2 * sigma[0] ** 2)
                kernel[aa, bb] = np.exp(-np.clip(a1 + b1, -1e3, 1e3))
    elif nSigma == 1:
        a = np.linspace(0, width[0] - 1, width[0]) - width[0] // 2 + (0 if width[0] % 2 else 0.5)

        for aa in range(width[0]):
            a1 = (a[aa] ** 2) / (2 * sigma[0] ** 2)
            kernel[aa] = np.exp(-np.clip(a1, -1e3, 1e3))

    kernel = np.nan_to_num(kernel, nan=0.0, posinf=0.0, neginf=0.0)

    if np.sum(kernel) == 0:
        raise ValueError("Generated Gaussian kernel is invalid (all zeros).")

    return kernel


def walsa_radial_distances(shape):
    """
    Compute the radial distance array for a given shape.

    Parameters:
        shape (tuple): Shape of the array, typically (ny, nx).

    Returns:
        numpy.ndarray: Array of radial distances.
    """
    if not (isinstance(shape, tuple) and len(shape) == 2):
        raise ValueError("Shape must be a tuple with two elements, e.g., (ny, nx).")

    y, x = np.indices(shape)
    cy, cx = (np.array(shape) - 1) / 2
    return np.sqrt((x - cx) ** 2 + (y - cy) ** 2)

def avgstd(array):
    """
    Calculate the average and standard deviation of an array.
    """
    avrg = np.sum(array) / array.size
    stdev = np.sqrt(np.sum((array - avrg) ** 2) / array.size)
    return avrg, stdev

def linear(x, *p):
    """
    Compute a linear model y = p[0] + x * p[1].
    (used in temporal detrending).
    """
    if len(p) < 2:
        raise ValueError("Parameters p[0] and p[1] must be provided.")
    ymod = p[0] + x * p[1]
    return ymod


def gradient(xy, *p):
    """
    Gradient function for fitting spatial trends.
    Parameters:
        xy: Tuple of grid coordinates (x, y).
        p: Coefficients [offset, slope_x, slope_y].
    """
    x, y = xy  # Unpack the tuple
    if len(p) < 3:
        raise ValueError("Parameters p[0], p[1], and p[2] must be provided.")
    return p[0] + x * p[1] + y * p[2]


def apod3dcube(cube, apod):
    """
    Apodizes a 3D cube in all three coordinates, with detrending.

    Parameters:
        cube : Input 3D data cube with dimensions (nt, nx, ny).
        apod (float): Apodization factor (0 means no apodization).

    Returns:
        Apodized 3D cube.
    """
    # Get cube dimensions
    nt, nx, ny = cube.shape
    apocube = np.zeros_like(cube, dtype=np.float32)

    # Define temporal apodization
    apodt = np.ones(nt, dtype=np.float32)
    if apod != 0:
        apodrimt = int(nt * apod)
        apodt[:apodrimt] = (np.sin(np.pi / 2. * np.arange(apodrimt) / apodrimt)) ** 2
        apodt = apodt * np.roll(np.flip(apodt), 1)  # apply symmetrical apodization

    # Temporal detrending (mean-image trend, not per pixel)
    ttrend = np.zeros(nt, dtype=np.float32)
    tf = np.arange(nt) + 1.0
    for it in range(nt):
        ttrend[it] = np.mean(cube[it, :, :])

    # Fit the trend with a linear model
    fitp, _ = curve_fit(linear, tf, ttrend, p0=[1000.0, 0.0])
    fit = linear(tf, *fitp)

    # Temporal apodization per (x, y) column
    for it in range(nt):
        img = cube[it, :, :]
        apocube[it, :, :] = (img - fit[it]) * apodt[it]

    # Define spatial apodization
    apodx = np.ones(nx, dtype=np.float32)
    apody = np.ones(ny, dtype=np.float32)
    if apod != 0:
        apodrimx = int(apod * nx)
        apodrimy = int(apod * ny)
        apodx[:apodrimx] = (np.sin(np.pi / 2. * np.arange(apodrimx) / apodrimx)) ** 2
        apody[:apodrimy] = (np.sin(np.pi / 2. * np.arange(apodrimy) / apodrimy)) ** 2
        apodx = apodx * np.roll(np.flip(apodx), 1)
        apody = apody * np.roll(np.flip(apody), 1)
    apodxy = np.outer(apodx, apody)

    # Spatial gradient removal and apodizing per image
    xf, yf = np.meshgrid(np.arange(nx, dtype=np.float32),
                         np.arange(ny, dtype=np.float32), indexing='ij')
    for it in range(nt):
        img = apocube[it, :, :]
        avg = np.mean(img)

        # Flatten the inputs for curve_fit
        x_flat = xf.ravel()
        y_flat = yf.ravel()
        img_flat = img.ravel()

        # Call curve_fit with initial parameters
        fitp, _ = curve_fit(gradient, (x_flat, y_flat), img_flat, p0=[1000.0, 0.0, 0.0])

        # Apply the fitted parameters
        fit = gradient((xf, yf), *fitp)
        apocube[it, :, :] = (img - fit) * apodxy + avg

    return apocube


def ko_dist(sx, sy, double=False):
    """
    Set up Pythagorean distance array from the origin.
    """
    # Create distance grids for x and y
    # (computing dx and dy using floating-point division)
    dx = np.tile(np.arange(sx / 2 + 1) / (sx / 2), (int(sy / 2 + 1), 1)).T
    dy = np.flip(np.tile(np.arange(sy / 2 + 1) / (sy / 2), (int(sx / 2 + 1), 1)), axis=0)
    # Compute dxy
    dxy = np.sqrt(dx**2 + dy**2) * (min(sx, sy) / 2 + 1)
    # Initialize afstanden
    afstanden = np.zeros((sx, sy), dtype=np.float64)
    # Assign dxy to the first quadrant (upper-left)
    afstanden[:sx // 2 + 1, :sy // 2 + 1] = dxy
    # Second quadrant (upper-right) - 90° clockwise
    afstanden[sx // 2:, :sy // 2 + 1] = np.flip(np.roll(dxy[:-1, :], shift=-1, axis=1), axis=1)
    # Third quadrant (lower-left) - 270° clockwise
    afstanden[:sx // 2 + 1, sy // 2:] = np.flip(np.roll(dxy[:, :-1], shift=-1, axis=0), axis=0)
    # Fourth quadrant (lower-right) - 180° rotation
    afstanden[sx // 2:, sy // 2:] = np.flip(dxy[:-1, :-1], axis=(0, 1))     

    # Convert to integers if 'double' is False
    if not double:
        afstanden = np.round(afstanden).astype(int)

    return afstanden


def averpower(cube):
    """
    Compute 2D (k_h, f) power array by circular averaging over k_x, k_y.

    Parameters:
        cube : Input 3D data cube of dimensions (nt, nx, ny).

    Returns:
        avpow : 2D array of average power over distances, dimensions (maxdist+1, nt/2+1).
    """
    # Get cube dimensions
    nt, nx, ny = cube.shape

    # Perform FFT in all three dimensions (first in time direction)
    fftcube = np.fft.fft(np.fft.fft(np.fft.fft(cube, axis=0)[:nt//2+1, :, :], axis=1), axis=2)

    # Set up distances
    afstanden = ko_dist(nx, ny)  # Integer-rounded Pythagoras array

    maxdist = min(nx, ny) // 2 + 1  # Largest quarter circle

    # Initialize average power array
    avpow = np.zeros((maxdist + 1, nt // 2 + 1), dtype=np.float64)

    # Compute average power over all k_h distances, building power(k_h, f)
    for i in range(maxdist + 1):
        where_indices = np.where(afstanden == i)
        npts = where_indices[0].size  # number of (k_x, k_y) cells at this distance
        for j in range(nt // 2 + 1):
            w1 = fftcube[j, :, :][where_indices]
            avpow[i, j] = np.sum(np.abs(w1) ** 2) / npts if npts > 0 else 0.0

    return avpow


def walsa_kopower_funct(datacube, **kwargs):
    """
    Calculate k-omega diagram
    (Fourier power at temporal frequency f against horizontal spatial wavenumber k_h)

    Originally written in IDL by Rob Rutten (RR) as an assembly of Alfred de Wijn's routines (2010).
    - Translated into Python by Shahin Jafarzadeh (2024)

    Parameters:
        datacube: Input data cube [t, x, y].
        arcsecpx (float): Spatial sampling in arcsec/pixel.
        cadence (float): Temporal sampling in seconds.
        apod: Fractional extent of apodization edges. Default: 0.1.
        kmax: Maximum k_h axis as fraction of Nyquist value. Default: 1.0.
        fmax: Maximum f axis as fraction of Nyquist value. Default: 1.0.
        minpower: Minimum of power plot range. Default: maxpower - 5.
        maxpower: Maximum of power plot range. Default: log10(max(power)).

    Returns:
        k-omega power map.
    """

    # Set default parameter values
    defaults = {
        'apod': 0.1,
        'kmax': 1.0,
        'fmax': 1.0,
        'minpower': None,
        'maxpower': None
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    # Apodize the cube
    apocube = apod3dcube(datacube, params['apod'])

    # Compute radially-averaged power
    avpow = averpower(apocube)

    return avpow

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_k_omega(signal, time=None, **kwargs):
    """
    NAME: WaLSA_k_omega
        part of -- WaLSAtools --

    * Main function to calculate and plot k-omega diagram.

    ORIGINAL CODE: QUEEns Fourier Filtering (QUEEFF) code
    WRITTEN, ANNOTATED, TESTED AND UPDATED in IDL BY:
    (1) Dr. David B. Jess
    (2) Dr. Samuel D. T. Grant
    The original code along with its manual can be downloaded at: https://bit.ly/37mx9ic

    WaLSA_k_omega (in IDL): A lightly modified version of the original code (i.e., a few additional keywords added) by Dr. Shahin Jafarzadeh
    - Translated into Python by Shahin Jafarzadeh (2024)

    Parameters:
        signal (array): Input datacube, normally in the form of [x, y, t] or [t, x, y]. Note that the input datacube must have identical x and y dimensions. If not, the datacube will be cropped accordingly.
        time (array): Time array corresponding to the input datacube.
        pixelsize (float): Spatial sampling of the input datacube. If not given, it is plotted in units of 'pixel'.
        filtering (bool): If True, filtering is applied, and the filtered datacube (filtered_cube) is returned. Otherwise, None is returned. Default: False.
        f1 (float): Optional lower (temporal) frequency to filter, in Hz.
        f2 (float): Optional upper (temporal) frequency to filter, in Hz.
        k1 (float): Optional lower (spatial) wavenumber to filter, in units of pixelsize^-1 (k = (2 * π) / wavelength).
        k2 (float): Optional upper (spatial) wavenumber to filter, in units of pixelsize^-1.
        spatial_torus (bool): If True, makes the annulus used for spatial filtering have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
        temporal_torus (bool): If True, makes the temporal filter have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
        no_spatial_filt (bool): If True, ensures no spatial filtering is performed on the dataset (i.e., only temporal filtering is applied).
        no_temporal_filt (bool): If True, ensures no temporal filtering is performed on the dataset (i.e., only spatial filtering is applied).
        silent (bool): If True, suppresses the k-ω diagram plot.
        smooth (bool): If True, power is smoothed. Default: True.
        mode (int): Output power mode: 0 = log10(power) (default), 1 = linear power, 2 = sqrt(power) = amplitude.
        processing_maps (bool): If True, the function returns the processing maps (spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies). Otherwise, they are all returned as None. Default: False.

    OUTPUTS:
        power : 2D array of power (see mode for the scale).
        frequencies : 1D array of frequencies (in Hz).
        wavenumber : 1D array of wavenumber (in pixelsize^-1).
        filtered_cube : 3D array of filtered datacube (if filtering is set).
        processing_maps (if set to True)

    IF YOU USE THIS CODE, THEN PLEASE CITE THE ORIGINAL PUBLICATION WHERE IT WAS USED:
    Jess et al. 2017, ApJ, 842, 59 (http://adsabs.harvard.edu/abs/2017ApJ...842...59J)
    """

    # Set default parameter values
    defaults = {
        'pixelsize': 1,
        'format': 'txy',
        'filtering': False,
        'f1': None,
        'f2': None,
        'k1': None,
        'k2': None,
        'spatial_torus': True,
        'temporal_torus': True,
        'no_spatial_filt': False,
        'no_temporal_filt': False,
        'silent': False,
        'xlog': False,
        'ylog': False,
        'xrange': None,
        'yrange': None,
        'nox2': False,
        'noy2': False,
        'smooth': True,
        'mode': 0,
        'xtitle': 'Wavenumber',
        'xtitle_units': '(pixel⁻¹)',
        'ytitle': 'Frequency',
        'ytitle_units': '(Hz)',
        'x2ndaxistitle': 'Spatial size',
        'y2ndaxistitle': 'Period',
        'x2ndaxistitle_units': '(pixel)',
        'y2ndaxistitle_units': '(s)',
        'processing_maps': False
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    pixelsize = params['pixelsize']
    filtered_cube = None
    spatial_fft_map = None
    torus_map = None
    spatial_fft_filtered_map = None
    temporal_fft = None
    temporal_frequencies = None
    spatial_frequencies = None

    # Check and adjust format if necessary
    if params['format'] == 'xyt':
        signal = np.transpose(signal, (2, 0, 1))  # Convert 'xyt' to 'txy'
    elif params['format'] != 'txy':
        raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")
    if not params['silent']:
        print(f"Processing k-ω analysis for a 3D cube with format '{params['format']}' and shape {signal.shape}.")

    # Input dimensions
    nt, nx, ny = signal.shape
    if nx != ny:
        min_dim = min(nx, ny)
        signal = signal[:, :min_dim, :min_dim]  # Crop x and y to identical dimensions
    nt, nx, ny = signal.shape

    # Calculating the Nyquist frequencies
    spatial_nyquist = (2 * np.pi) / (pixelsize * 2)
    temporal_nyquist = 1 / (cadence * 2)

    print("")
    print(f"Input datacube size (t,x,y): {signal.shape}")
    print("")
    print("Spatially, the important values are:")
    print(f"    2-pixel size = {pixelsize * 2:.2f} {params['x2ndaxistitle_units']}")
    print(f"    Nyquist wavenumber = {spatial_nyquist:.2f} {params['xtitle_units']}")
    if params['no_spatial_filt']:
        print("*** NO SPATIAL FILTERING WILL BE PERFORMED ***")
    print("")
    print("Temporally, the important values are:")
    print(f"    2-element duration (Nyquist period) = {cadence * 2:.2f} {params['y2ndaxistitle_units']}")
    print(f"    Time series duration = {cadence * signal.shape[2]:.2f} {params['y2ndaxistitle_units']}")
    temporal_nyquist = 1 / (cadence * 2)
    print(f"    Nyquist frequency = {temporal_nyquist:.2f} {params['yttitle_units']}")
    if params['no_temporal_filt']:
        print("***NO TEMPORAL FILTERING WILL BE PERFORMED***")
    print("")

    # Generate k-omega power map
    print("Constructing a k-ω diagram of the input datacube..........")
    print("")
    # Make the k-omega diagram using the proven method of Rob Rutten
    kopower = walsa_kopower_funct(signal)

    # Scales
    xsize_kopower = kopower.shape[0]
    dxsize_kopower = spatial_nyquist / float(xsize_kopower - 1)
    kopower_xscale = np.arange(xsize_kopower) * dxsize_kopower  # in pixel⁻¹

    ysize_kopower = kopower.shape[1]
    dysize_kopower = temporal_nyquist / float(ysize_kopower - 1)
    kopower_yscale = (np.arange(ysize_kopower) * dysize_kopower) # in Hz

    # Generate Gaussian Kernel
    Gaussian_kernel = gaussian_function(sigma=[0.65, 0.65], width=3)
    Gaussian_kernel_norm = np.nansum(Gaussian_kernel)  # Normalize kernel sum

    # Copy kopower to kopower_plot
    kopower_plot = kopower.copy()

    # Convolve kopower (ignoring zero-th element, starting from index 1)
    kopower_plot[:, 1:] = convolve(
        kopower[:, 1:], 
        Gaussian_kernel, 
        mode='same'
    ) / Gaussian_kernel_norm

    # Normalize to the frequency resolution
    freq = kopower_yscale[1:]
    if freq[0] == 0:
        freq0 = freq[1]
    else:
        freq0 = freq[0]
    kopower_plot /= freq0

    # Apply logarithmic or square root transformation based on mode
    if params['mode'] == 0:  # Logarithmic scaling
        kopower_plot = np.log10(kopower_plot)
    elif params['mode'] == 2:  # Square root scaling
        kopower_plot = np.sqrt(kopower_plot)

    # Normalize the power - preferred for plotting
    komegamap = np.clip(
        kopower_plot[1:, 1:],
        np.nanmin(kopower_plot[1:, 1:]),
        np.nanmax(kopower_plot[1:, 1:])
    )

    kopower_zscale = kopower_plot[1:,1:]

    # Rotate kopower counterclockwise by 90 degrees
    komegamap = np.rot90(komegamap, k=1)
    # Flip vertically (y-axis)
    komegamap = np.flip(komegamap, axis=0) 

    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Filtering implementation
    if params['filtering']:
        # Extract parameters from the dictionary
        k1 = params['k1']
        k2 = params['k2']
        f1 = params['f1']
        f2 = params['f2']

        if params['no_spatial_filt']:
            k1 = kopower_xscale[1]  # Default lower wavenumber
            k2 = np.nanmax(kopower_xscale)  # Default upper wavenumber
        # Ensure k1 and k2 are within valid bounds
        if k1 is None or k1 <= 0.0:
            k1 = kopower_xscale[1]
        if k2 is None or k2 > np.nanmax(kopower_xscale):
            k2 = np.nanmax(kopower_xscale)

        if params['no_temporal_filt']:
            f1 = kopower_yscale[1]
            f2 = np.nanmax(kopower_yscale)
        # Ensure f1 and f2 are within valid bounds
        if f1 is None or f1 <= 0.0:
            f1 = kopower_yscale[1] 
        if f2 is None or f2 > np.nanmax(kopower_yscale):
            f2 = np.nanmax(kopower_yscale) 

        print("Start filtering (in k-ω space) ......")
        print("")
        print(f"The preserved wavenumbers are [{k1:.3f}, {k2:.3f}] {params['xtitle_units']}")
        print(f"The preserved spatial sizes are [{(2 * np.pi) / k2:.3f}, {(2 * np.pi) / k1:.3f}] {params['x2ndaxistitle_units']}")
        print("")
        print(f"The preserved frequencies are [{f1:.3f}, {f2:.3f}] {params['yttitle_units']}")
        print(f"The preserved periods are [{int(1 / (f2))}, {int(1 / (f1))}] {params['y2ndaxistitle_units']}")
        print("")

        # Perform the 3D Fourier transform
        print("Making a 3D Fourier transform of the input datacube ..........")
        threedft = np.fft.fftshift(np.fft.fftn(signal))

        # Calculate the frequency axes for the 3D FFT
        temp_x = np.arange((nx - 1) // 2) + 1
        is_N_even = (nx % 2 == 0)
        if is_N_even:
            spatial_frequencies_orig = (
                np.concatenate([[0.0], temp_x, [nx / 2], -nx / 2 + temp_x]) / (nx * pixelsize)) * (2.0 * np.pi)
        else:
            spatial_frequencies_orig = (
                np.concatenate([[0.0], temp_x, -(nx // 2 + 1) + temp_x]) / (nx * pixelsize)) * (2.0 * np.pi)

        temp_t = np.arange((nt - 1) // 2) + 1
        is_N_even = (nt % 2 == 0)
        if is_N_even:
            temporal_frequencies_orig = (
                np.concatenate([[0.0], temp_t, [nt / 2], -nt / 2 + temp_t])) / (nt * cadence)
        else:
            temporal_frequencies_orig = (
                np.concatenate([[0.0], temp_t, -(nt // 2 + 1) + temp_t])) / (nt * cadence)

        # Compensate frequency axes for the use of FFT shift
        indices = np.where(spatial_frequencies_orig >= 0)[0]
        spatial_positive_frequencies = len(indices)
        if len(spatial_frequencies_orig) % 2 == 0:
            spatial_frequencies = np.roll(spatial_frequencies_orig, spatial_positive_frequencies - 2)
        else:
            spatial_frequencies = np.roll(spatial_frequencies_orig, spatial_positive_frequencies - 1)

        tindices = np.where(temporal_frequencies_orig >= 0)[0]
        temporal_positive_frequencies = len(tindices)
        if len(temporal_frequencies_orig) % 2 == 0:
            temporal_frequencies = np.roll(temporal_frequencies_orig, temporal_positive_frequencies - 2)
        else:
            temporal_frequencies = np.roll(temporal_frequencies_orig, temporal_positive_frequencies - 1)

        # Ensure the threedft aligns with the new frequency axes
        if len(temporal_frequencies_orig) % 2 == 0:
            for x in range(nx):
                for y in range(ny):
                    threedft[:, x, y] = np.roll(threedft[:, x, y], -1)

        if len(spatial_frequencies_orig) % 2 == 0:
            for z in range(nt):
                threedft[z, :, :] = np.roll(threedft[z, :, :], shift=(-1, -1), axis=(0, 1))

        # Convert frequencies and wavenumbers of interest into FFT datacube pixels
        pixel_k1_positive = np.argmin(np.abs(spatial_frequencies_orig - k1))
        pixel_k2_positive = np.argmin(np.abs(spatial_frequencies_orig - k2))
        pixel_f1_positive = np.argmin(np.abs(temporal_frequencies - f1))
        pixel_f2_positive = np.argmin(np.abs(temporal_frequencies - f2))
        pixel_f1_negative = np.argmin(np.abs(temporal_frequencies + f1))
        pixel_f2_negative = np.argmin(np.abs(temporal_frequencies + f2))

        torus_depth = int((pixel_k2_positive - pixel_k1_positive) / 2) * 2
        torus_center = int(((pixel_k2_positive - pixel_k1_positive) / 2) + pixel_k1_positive)

        if params['spatial_torus'] and not params['no_spatial_filt']:
            # Create a filter ring preserving equal wavenumbers for both kx and ky
            # This forms a torus to preserve an integrated Gaussian shape across the width of the annulus
            spatial_torus = np.zeros((torus_depth, nx, ny)) 
            for i in range(torus_depth // 2 + 1):
                spatial_ring = np.logical_xor(
                    (walsa_radial_distances((nx, ny)) <= (torus_center - i)),
                    (walsa_radial_distances((nx, ny)) <= (torus_center + i + 1))
                )
                spatial_ring = spatial_ring.astype(float)  # Convert True -> 1.0 and False -> 0.0
                spatial_torus[i, :, :] = spatial_ring
                spatial_torus[torus_depth - i - 1, :, :] = spatial_ring

            # Integrate through the torus to find the spatial filter
            spatial_ring_filter = np.nansum(spatial_torus, axis=0) / float(torus_depth)

            spatial_ring_filter = spatial_ring_filter / np.nanmax(spatial_ring_filter)  # Ensure the peaks are at 1.0

        if not params['spatial_torus'] and not params['no_spatial_filt']:
            spatial_ring_filter = (
                (walsa_radial_distances((nx, ny)) <= (torus_center - int(torus_depth / 2))).astype(int) -
                (walsa_radial_distances((nx, ny)) <= (torus_center + int(torus_depth / 2) + 1)).astype(int)
            )
            spatial_ring_filter = spatial_ring_filter / np.nanmax(spatial_ring_filter)  # Ensure the peaks are at 1.0
            spatial_ring_filter[spatial_ring_filter != 1] = 0

        if params['no_spatial_filt']:
            spatial_ring_filter = np.ones((nx, ny))

        if not params['no_temporal_filt'] and params['temporal_torus']:
            # Create a Gaussian temporal filter to prevent aliasing
            temporal_filter = np.zeros(nt, dtype=float)
            filter_width = pixel_f2_positive - pixel_f1_positive

            # Determine sigma from the filter width
            # (step-wise mapping, equivalent to the original if-chain)
            width_thresholds = [25, 30, 40, 45, 50, 55, 60, 65, 70, 80, 90, 100, 110, 130]
            sigma = 3 + int(np.searchsorted(width_thresholds, filter_width, side='right'))

            # Generate the Gaussian kernel
            temporal_gaussian = gaussian_function(sigma=sigma, width=filter_width)

            # Apply the Gaussian to the temporal filter
            temporal_filter[pixel_f1_positive:pixel_f2_positive] = temporal_gaussian
            temporal_filter[pixel_f2_negative:pixel_f1_negative] = temporal_gaussian
            # Normalize the filter to ensure the peaks are at 1.0
            temporal_filter /= np.nanmax(temporal_filter)

        if not params['no_temporal_filt'] and not params['temporal_torus']:
            temporal_filter = np.zeros(nt, dtype=float)
            temporal_filter[pixel_f1_positive:pixel_f2_positive + 1] = 1.0
            temporal_filter[pixel_f2_negative:pixel_f1_negative + 1] = 1.0

        if params['no_temporal_filt']:
            temporal_filter = np.ones(nt, dtype=float)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Create useful variables for plotting (if needed), demonstrating the filtering process
        if params['processing_maps']:
            # Define the spatial frequency step for plotting
            spatial_dx = spatial_frequencies[1] - spatial_frequencies[0]
            spatial_dy = spatial_frequencies[1] - spatial_frequencies[0]

            # Torus map
            torus_map = {
                'data': spatial_ring_filter,
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            # Compute the total spatial FFT
            spatial_fft = np.nansum(threedft, axis=0)

            # Spatial FFT map
            spatial_fft_map = {
                'data': np.log10(spatial_fft),
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            # Spatial FFT filtered and its map
            spatial_fft_filtered = spatial_fft * spatial_ring_filter
            spatial_fft_filtered_map = {
                'data': np.log10(np.maximum(spatial_fft_filtered, 1e-15)),
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            # Compute the total temporal FFT
            temporal_fft = np.nansum(np.nansum(threedft, axis=2), axis=1)

        else:
            spatial_fft_map = None
            torus_map = None
            spatial_fft_filtered_map = None
            temporal_fft = None
            # Keep temporal_filter: it is applied to the datacube below
            temporal_frequencies, spatial_frequencies = None, None

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        # Apply the Gaussian filters to the data to prevent aliasing
        for i in range(nt): 
            threedft[i, :, :] *= spatial_ring_filter

        for x in range(nx):  
            for y in range(ny):  
                threedft[:, x, y] *= temporal_filter

        # ALSO NEED TO ENSURE THE threedft ALIGNS WITH THE OLD FREQUENCY AXES USED BY THE /center CALL
        if len(temporal_frequencies_orig) % 2 == 0:
            for x in range(nx):  
                for y in range(ny): 
                    threedft[:, x, y] = np.roll(threedft[:, x, y], shift=1, axis=0)

        if len(spatial_frequencies_orig) % 2 == 0:
            for t in range(nt):  
                threedft[t, :, :] = np.roll(threedft[t, :, :], shift=(1, 1), axis=(0, 1))

        # Inverse FFT to get the filtered cube
        # filtered_cube = np.real(np.fft.ifftn(threedft, norm='ortho'))
        filtered_cube = np.real(np.fft.ifftn(np.fft.ifftshift(threedft)))

        print("Filtered datacube generated.")

    else:
        filtered_cube = None

    return komegamap, kopower_xscale[1:], kopower_yscale[1:], filtered_cube, spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies
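
For orientation, here is a minimal usage sketch (not part of the module; the import path, the synthetic datacube, and all parameter values are assumptions for illustration):

import numpy as np
from WaLSAtools import WaLSA_k_omega  # assumed import path

time = np.arange(200) * 2.0                  # hypothetical 2 s cadence, 200 frames
cube = np.random.randn(200, 64, 64)          # synthetic datacube in 'txy' format
komegamap, wavenumber, frequencies, filtered_cube, *rest = WaLSA_k_omega(
    cube, time=time, pixelsize=0.1, filtering=True, f1=0.005, f2=0.05
)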

WaLSA_pod.py

This module implements Proper Orthogonal Decomposition (POD), as well as Spectral POD (SPOD), for analysing multi-dimensional data and extracting dominant spatial patterns.
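
As a quick orientation, a minimal usage sketch (the import path and the synthetic data below are assumptions for illustration, not part of the listing):

import numpy as np
from WaLSAtools import WaLSA_pod  # assumed import path

time = np.arange(300) * 0.5                  # hypothetical 0.5 s cadence
cube = np.random.randn(300, 50, 50)          # synthetic (t, y, x) datacube
results = WaLSA_pod(cube, time, num_modes=10, num_modes_reconstruct=5)
print(results['spatial_mode'].shape)         # -> (Nmodes, Ny, Nx)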

WaLSA_pod.py
# --------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# --------------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by Jonathan E. Higham & Luiz A. C. A. Schiavo
# --------------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.signal import welch, find_peaks # type: ignore
from scipy.linalg import svd # type: ignore
from scipy.optimize import curve_fit # type: ignore
import math

def print_pod_results(results):
    """
    Print a summary of the results from WaLSA_pod, including parameter descriptions, types, and shapes.

    Parameters:
    -----------
    results : dict
        The dictionary containing all POD results and relevant outputs.
    """
    descriptions = {
        'input_data': 'Original input data, mean subtracted (Shape: (Nt, Ny, Nx))',
        'spatial_mode': 'Reshaped spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))',
        'temporal_coefficient': 'Temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))',
        'eigenvalue': 'Eigenvalues corresponding to singular values squared (Shape: (Nmodes))',
        'eigenvalue_contribution': 'Eigenvalue contribution of each mode (Shape: (Nmodes))',
        'cumulative_eigenvalues': 'Cumulative percentage of eigenvalues for the first "num_cumulative_modes" modes (Shape: (num_cumulative_modes))',
        'combined_welch_psd': 'Combined Welch power spectral density for the temporal coefficients of the first "num_modes" modes (Shape: (Nf))',
        'frequencies': 'Frequencies identified in the Welch spectrum (Shape: (Nf))',
        'combined_welch_significance': 'Significance threshold of the combined Welch spectrum (Shape: (Nf,))',
        'reconstructed': 'Reconstructed frame at the specified timestep (or for the entire time series) using the top "num_modes" modes (Shape: (Ny, Nx))',
        'sorted_frequencies': 'Frequencies identified in the Welch combined power spectrum (Shape: (Nfrequencies))',
        'frequency_filtered_modes': 'Frequency-filtered spatial POD modes for the first "num_top_frequencies" frequencies (Shape: (Nt, Ny, Nx, num_top_frequencies))',
        'frequency_filtered_modes_frequencies': 'Frequencies corresponding to the frequency-filtered modes (Shape: (num_top_frequencies))',
        'SPOD_spatial_modes': 'SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))',
        'SPOD_temporal_coefficients': 'SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))',
        'p': 'Left singular vectors (spatial modes) from SVD (Shape: (Nx, Nmodes))',
        's': 'Singular values from SVD (Shape: (Nmodes))',
        'a': 'Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))'
    }

    print("\n---- POD/SPOD Results Summary ----\n")
    for key, value in results.items():
        desc = descriptions.get(key, 'No description available')
        shape = np.shape(value) if value is not None else 'None'
        dtype = type(value).__name__
        print(f"{key} ({dtype}, Shape: {shape}): {desc}")
    print("\n----------------------------------")

def spod(data, **kwargs):
    """
    Perform Spectral Proper Orthogonal Decomposition (SPOD) analysis on input data.
    Steps:
    1. Load the data.
    2. Compute a time average.
    3. Build the correlation matrix to compute the POD using the snapshot method.
    4. Compute eigenvalue decomposition for the correlation matrix for the fluctuation field. 
    The eigenvalues and eigenvectors can be computed by any eigenvalue function or by an SVD procedure, since the correlation matrix is symmetric positive-semidefinite.
    5. After obtaining the eigenvalues and eigenvectors, compute the temporal and spatial modes according to the snapshot method.

    Parameters:
    -----------
    data : np.ndarray
        3D data array with shape (time, x, y) or similar. 
    **kwargs : dict, optional
        Additional keyword arguments to configure the analysis.

    Returns:
    --------
    SPOD_spatial_modes : np.ndarray
    SPOD_temporal_coefficients : np.ndarray
    """
    # Set default parameter values
    defaults = {
        'silent': False,
        'num_modes': None,
        'filter_size': None,
        'periodic_matrix': True
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    if params['num_modes'] is None:
        params['num_modes'] = data.shape[0]
    if params['filter_size'] is None:
        params['filter_size'] = params['num_modes']

    if not params['silent']:
        print("Starting SPOD analysis ....")

    nsnapshots, ny, nx = data.shape

    # Center the data by subtracting the mean
    time_average = data.mean(axis=0)
    fluctuation_field = np.zeros_like(data, dtype=float)  # Initialize fluctuation field (float, to avoid truncation)
    for n in range(nsnapshots):
        fluctuation_field[n, :, :] = data[n, :, :] - time_average  # Subtract time-average from each snapshot

    # Build the correlation matrix (snapshot method)
    correlation_matrix = np.zeros((nsnapshots, nsnapshots))
    for i in range(nsnapshots):
        for j in range(i, nsnapshots):
            correlation_matrix[i, j] = (fluctuation_field[i, :, :] * fluctuation_field[j, :, :]).sum() / (nsnapshots * nx * ny)
            correlation_matrix[j, i] = correlation_matrix[i, j]

    # SPOD correlation matrix with periodic boundary conditions
    nmatrix = np.zeros((3 * nsnapshots, 3 * nsnapshots))
    nmatrix[nsnapshots:2 * nsnapshots, nsnapshots:2 * nsnapshots] = correlation_matrix[0:nsnapshots, 0:nsnapshots]

    if params['periodic_matrix']:
        for i in range(3):
            for j in range(3):
                xs = i * nsnapshots  # x-offset for periodic positioning
                ys = j * nsnapshots  # y-offset for periodic positioning
                nmatrix[0 + xs:nsnapshots + xs, 0 + ys:nsnapshots + ys] = correlation_matrix[0:nsnapshots, 0:nsnapshots]

    # Apply Gaussian filter
    gk = np.zeros((2 * params['filter_size'] + 1))  # Create 1D Gaussian kernel
    esp = 8.0  # Exponential parameter controlling the spread of the filter
    sumgk = 0.0  # Sum of the Gaussian kernel for normalization

    for n in range(2 * params['filter_size'] + 1):
        k = -params['filter_size'] + n  # Offset from the center of the kernel
        gk[n] = math.exp(-esp * k**2.0 / (params['filter_size']**2.0))  # Gaussian filter formula
        sumgk += gk[n]  # Sum of kernel values for normalization

    # Filter the extended correlation matrix
    for j in range(nsnapshots, 2 * nsnapshots):
        for i in range(nsnapshots, 2 * nsnapshots):
            aux = 0.0  # Initialize variable for the weighted sum
            for n in range(2 * params['filter_size'] + 1):
                k = -params['filter_size'] + n  # Offset from the current index
                aux += nmatrix[i + k, j + k] * gk[n]  # Apply Gaussian weighting
            nmatrix[i, j] = aux / sumgk  # Normalize by the sum of the Gaussian kernel

    # Extract the SPOD correlation matrix from the central part of the filtered matrix
    spectral_matrix = nmatrix[nsnapshots:2 * nsnapshots, nsnapshots:2 * nsnapshots]

    # Perform SVD to compute SPOD
    UU, SS, VV = svd(spectral_matrix, full_matrices=True, compute_uv=True)

    # Temporal coefficients
    SPOD_temporal_coefficients = np.zeros((nsnapshots, nsnapshots))
    for k in range(nsnapshots):
        SPOD_temporal_coefficients[:, k] = np.sqrt(SS[k] * nsnapshots) * UU[:, k]

    # Extract spatial SPOD modes
    SPOD_spatial_modes = np.zeros((params['num_modes'], ny, nx))
    for m in range(params['num_modes']):
        for t in range(nsnapshots):
            SPOD_spatial_modes[m, :, :] += fluctuation_field[t, :, :] * SPOD_temporal_coefficients[t, m] / (SS[m] * nsnapshots)

    if not params['silent']:
        print(f"SPOD analysis completed.")

    return SPOD_spatial_modes, SPOD_temporal_coefficients
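
# Example (sketch, not part of the module): spod() can also be called directly
# on a (nt, ny, nx) cube; the settings below are hypothetical:
#   SPOD_spatial_modes, SPOD_temporal_coefficients = spod(cube, num_modes=10, filter_size=5)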


# -------------------------------------- Main Function ----------------------------------------
def WaLSA_pod(signal, time, **kwargs):
    """
    Perform Proper Orthogonal Decomposition (POD) analysis on input data.

    Parameters:
        signal (array): 3D data cube with shape (time, x, y) or similar.
        time (array): 1D array representing the time points for each time step in the data.
        num_modes (int, optional): Number of top modes to compute. Default is None (all modes).
        num_top_frequencies (int, optional): Number of top frequencies to consider. Default is None (all frequencies).
        top_frequencies (list, optional): List of top frequencies to consider. Default is None.
        num_cumulative_modes (int, optional): Number of cumulative modes to consider. Default is None (all modes).
        welch_nperseg (int, optional): Number of samples per segment for Welch's method. Default is 150.
        welch_noverlap (int, optional): Number of overlapping samples for Welch's method. Default is 25.
        welch_nfft (int, optional): Number of points for the FFT. Default is 2^14.
        welch_fs (int, optional): Sampling frequency for the data. Default is 2.
        nperm (int, optional): Number of permutations for significance testing. Default is 1000.
        siglevel (float, optional): Significance level for the Welch spectrum. Default is 0.95.
        timestep_to_reconstruct (int, optional): Timestep of the datacube to reconstruct using the top modes. Default is 0.
        num_modes_reconstruct (int, optional): Number of modes to use for reconstruction. Default is None (all modes).
        reconstruct_all (bool, optional): If True, reconstruct the entire time series using the top modes. Default is False.
        spod (bool, optional): If True, perform Spectral Proper Orthogonal Decomposition (SPOD) analysis. Default is False.
        spod_filter_size (int, optional): Filter size for SPOD analysis. Default is None.
        spod_num_modes (int, optional): Number of SPOD modes to compute. Default is None.
        print_results (bool, optional): If True, print a summary of results. Default is True.

    **kwargs : Additional keyword arguments to configure the analysis.

    Returns:
    results : dict
        A dictionary containing all computed POD results and relevant outputs. 
        See 'descriptions' in the 'print_pod_results' function on top of this page.
    """
    # Set default parameter values
    defaults = {
        'silent': False,
        'num_modes': None,
        'num_top_frequencies': None,
        'top_frequencies': None,
        'num_cumulative_modes': None,
        'welch_nperseg': 150,
        'welch_noverlap': 25,
        'welch_nfft': 2**14,
        'welch_fs': 2,
        'nperm': 1000,
        'siglevel': 0.95,
        'timestep_to_reconstruct': 0,
        'reconstruct_all': False,
        'num_modes_reconstruct': None,
        'spod': False,
        'spod_filter_size': None,
        'spod_num_modes': None,
        'print_results': True  # Print summary of results by default
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    data = signal

    if params['num_modes'] is None:
        params['num_modes'] = data.shape[0]

    if params['num_top_frequencies'] is None:
        params['num_top_frequencies'] = min(params['num_modes'], 10)

    if params['num_cumulative_modes'] is None:
        params['num_cumulative_modes'] = min(params['num_modes'], 10)

    if params['num_modes_reconstruct'] is None:
        params['num_modes_reconstruct'] = min(params['num_modes'], 10)

    if not params['silent']:
        print("Starting POD analysis ....")
        print(f"Processing a 3D cube with shape {data.shape}.")

    # The first step is to read in the data and then perform the SVD (or POD). Before we do this, 
    # we need to reshape the matrix such that it is an N x T matrix where N is the column vectorized set of each spatial image and T is the temporal domain. 
    # We also need to ensure that the mean is subtracted from the data; this will ensure we are looking only at the variance of the data, and mode 1 will not be contaminated with the mean image.
    # Reshape the 3D data into 2D array where each row is a vector from the original 3D data
    inp = np.reshape(data, (data.shape[0], data.shape[1] * data.shape[2])).astype(np.float32)
    inp = inp.T  # Transpose the matrix to have spatial vectors as columns

    # Center the data by subtracting the mean
    mean_per_row = np.nanmean(inp, axis=1, keepdims=True)
    mean_replicated = np.tile(mean_per_row, (1, data.shape[0]))
    inp_centered = inp - mean_replicated

    # Input data, mean subtracted
    input_dat = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
    for im in range(data.shape[0]):
        input_dat[im, :, :] = np.reshape(inp_centered[:, im], (data.shape[1], data.shape[2]))

    # Perform SVD to compute POD
    p, s, a = svd(inp_centered, full_matrices=False)
    sorg = s.copy()  # Store original singular values
    eigenvalue = s**2  # Convert singular values to eigenvalues

    # Reshape spatial modes to have the same shape as the input data
    num_modes = p.shape[1]
    spatial_mode = np.zeros((num_modes, data.shape[1], data.shape[2]))
    for m in range(num_modes):
        spatial_mode[m, :, :] = np.reshape(p[:, m], (data.shape[1], data.shape[2]))

    # Extract temporal coefficients
    temporal_coefficient = a[:num_modes, :]

    # Calculate eigenvalue contributions
    eigenvalue_contribution = eigenvalue / np.sum(eigenvalue)

    # Calculate cumulative eigenvalues
    cumulative_eigenvalues = []
    for m in range(params['num_cumulative_modes']):
        contm = 100 * eigenvalue[0:m + 1] / np.sum(eigenvalue)
        cumulative_eigenvalues.append(np.sum(contm))

    if params['reconstruct_all']:
        # Reconstruct the entire time series using the top 'num_modes_reconstruct' modes
        reconstructed = np.zeros((data.shape[0], data.shape[1], data.shape[2]))
        for tindex in range(data.shape[0]):
            reconim = np.zeros((data.shape[1], data.shape[2]))
            for i in range(params['num_modes_reconstruct']):
                reconim += np.reshape(p[:, i], (data.shape[1], data.shape[2])) * a[i, tindex] * sorg[i]
            reconstructed[tindex, :, :] = reconim
    else:
        # Reconstruct the specified timestep using the top 'num_modes_reconstruct' modes
        reconstructed = np.zeros((data.shape[1], data.shape[2]))
        for i in range(params['num_modes_reconstruct']):
            reconstructed += np.reshape(p[:, i], (data.shape[1], data.shape[2])) * a[i, params['timestep_to_reconstruct']] * sorg[i]

    #---------------------------------------------------------------------------------
    # Combined Welch power spectrum and its significance
    #---------------------------------------------------------------------------------
    # Compute Welch power spectrum for each mode and combine them for 'num_modes' modes
    combined_welch_psd = []
    for m in range(params['num_modes']):
        frequencies, px = welch(a[m, :] - np.mean(a[m, :]), 
                      nperseg=params['welch_nperseg'], 
                      noverlap=params['welch_noverlap'], 
                      nfft=params['welch_nfft'], 
                      fs=params['welch_fs'])
        if m == 0:
            combined_welch_psd = np.zeros((len(frequencies),))
        combined_welch_psd += eigenvalue_contribution[m] * (px / np.sum(px))

    # Generate resampled peaks to compute significance threshold
    resampled_peaks = np.zeros((params['nperm'], len(frequencies)))
    for i in range(params['nperm']):
        resampled_data = np.random.randn(*a.shape)
        resampled_peak = np.zeros((len(frequencies),))
        for m in range(params['num_modes']):
            _, px = welch(resampled_data[m, :] - np.mean(resampled_data[m, :]), 
                          nperseg=params['welch_nperseg'], 
                          noverlap=params['welch_noverlap'], 
                          nfft=params['welch_nfft'], 
                          fs=params['welch_fs'])
            resampled_peak += eigenvalue_contribution[m] * (px / np.sum(px))
        resampled_peak /= (np.max(resampled_peak) + 1e-30) # a small epsilon added to avoid division by zero
        resampled_peaks[i, :] = resampled_peak
    # Calculate significance threshold (siglevel-th percentile of the permuted spectra)
    combined_welch_significance = np.percentile(resampled_peaks, params['siglevel'] * 100, axis=0)
    #---------------------------------------------------------------------------------

    # Find peaks in the combined spectrum and sort them in descending order
    normalized_peak = combined_welch_psd / (np.max(combined_welch_psd) + 1e-30)
    peak_indices, _ = find_peaks(normalized_peak)
    sorted_indices = np.argsort(normalized_peak[peak_indices])[::-1]
    sorted_frequencies = frequencies[peak_indices][sorted_indices]

    # Generate a table of the top N frequencies
    # Cleaner single-line list comprehension
    table_data = [
        [float(np.round(freq, 2)), float(np.round(pwr, 2))] 
        for freq, pwr in zip(sorted_frequencies[:params['num_top_frequencies']], 
                            normalized_peak[peak_indices][sorted_indices][:params['num_top_frequencies']])
    ]

    # The frequencies that the POD is able to capture have now been identified (in the top 'num_top_frequencies' modes). 
    # These frequencies can be fitted to the temporal coefficients of the top 'num_top_frequencies' modes, 
    # allowing for a representation of the data described solely by these "pure" frequencies. 
    # This approach enables the reconstruction of the original data using the identified dominant frequencies,
    # resulting in frequency-filtered spatial POD modes.
    if params['top_frequencies'] is None:
        top_frequencies = sorted_frequencies[:params['num_top_frequencies']]
    else:
        top_frequencies = params['top_frequencies']
        params['num_top_frequencies'] = len(params['top_frequencies'])

    clean_data = np.zeros((inp.shape[0], inp.shape[1], params['num_top_frequencies']))

    for i in range(params['num_top_frequencies']):
        def model_fun(t, amplitude, phase):
            """
            Generate a sinusoidal model function.

            Parameters:
            t (array-like): The time variable.
            amplitude (float): The amplitude of the sinusoidal function.
            phase (float): The phase shift of the sinusoidal function.

            Returns:
            array-like: The computed sinusoidal values at each time point `t`.
            """
            return amplitude * np.sin(2 * np.pi * top_frequencies[i] * t + phase)

        for j in range(params['num_modes_reconstruct']):
            csignal = a[j,:]
            # Initial Guesses for Parameters: [Amplitude, Phase]
            initial_guess = [1, 0]
            # Nonlinear Fit
            fit_params, _ = curve_fit(model_fun, time, csignal, p0=initial_guess)
            # Clean Signal from Fit
            if j == 0:
                clean_signal=np.zeros((params['num_modes_reconstruct'],len(csignal)))
            clean_signal[j,:] = model_fun(time, *fit_params)
        # forming a set of clean data (reconstructed data at the fitted frequencies; frequency filtered reconstructed data)
        clean_data[:,:,i] = p[:,0:params['num_modes_reconstruct']]@np.diag(sorg[0:params['num_modes_reconstruct']])@clean_signal

    frequency_filtered_modes = np.zeros((data.shape[0], data.shape[1], data.shape[2], params['num_top_frequencies']))
    for jj in range(params['num_top_frequencies']):
        for frame_index in range(data.shape[0]):
            frequency_filtered_modes[frame_index, :, :, jj] = np.reshape(clean_data[:, frame_index, jj], (data.shape[1], data.shape[2]))

    if params['spod']:
        SPOD_spatial_modes, SPOD_temporal_coefficients = spod(data, num_modes=params['spod_num_modes'], filter_size=params['spod_filter_size'])
    else:
        SPOD_spatial_modes = None
        SPOD_temporal_coefficients = None

    results = {
        'input_data': input_dat, # Original input data, mean subtracted (Shape: (Nt, Ny, Nx))
        'spatial_mode': spatial_mode,  # POD spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))
        'temporal_coefficient': temporal_coefficient,  # POD temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))
        'eigenvalue': eigenvalue,  # Eigenvalues corresponding to singular values squared (Shape: (Nmodes))
        'eigenvalue_contribution': eigenvalue_contribution,  # Eigenvalue contribution of each mode (Shape: (Nmodes))
        'cumulative_eigenvalues': cumulative_eigenvalues,  # Cumulative percentage of eigenvalues for the first 'num_cumulative_modes' modes (Shape: (Ncumulative_modes))
        'combined_welch_psd': combined_welch_psd,  # Combined Welch power spectral density for the temporal coefficients (Shape: (Nf))
        'frequencies': frequencies,  # Frequencies identified in the Welch spectrum (Shape: (Nf))
        'combined_welch_significance': combined_welch_significance,  # Significance threshold of the combined Welch spectrum (Shape: (Nf))
        'reconstructed': reconstructed,  # Reconstructed frame at the specified timestep (or for the entire time series) using the top modes (Shape: (Ny, Nx))
        'sorted_frequencies': sorted_frequencies,  # Frequencies identified in the Welch spectrum (Shape: (Nfrequencies,))
        'frequency_filtered_modes': frequency_filtered_modes,  # Frequency-filtered spatial POD modes (Shape: (Nt, Ny, Nx, Ntop_frequencies))
        'frequency_filtered_modes_frequencies': top_frequencies,  # Frequencies corresponding to the frequency-filtered modes (Shape: (Ntop_frequencies))
        'SPOD_spatial_modes': SPOD_spatial_modes,  # SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))
        'SPOD_temporal_coefficients': SPOD_temporal_coefficients,  # SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))
        'p': p,  # Left singular vectors (spatial modes) from SVD (Shape: (Nx, Nmodes))
        's': sorg,  # Singular values from SVD (Shape: (Nmodes,))
        'a': a  # Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))
    }

    if not params['silent']:
        print(f"POD analysis completed.")
        print(f"Top {params['num_top_frequencies']} frequencies and normalized power values:\n{table_data}")
        print(f"Total variance contribution of the first {params['num_modes']} modes: {np.sum(100 * eigenvalue[:params['num_modes']] / np.sum(eigenvalue)):.2f}%")

    if params['print_results']:
        print_pod_results(results)

    return results

WaLSA_cross_spectra.py

This module implements cross-correlation analysis techniques, resulting in cross-spectrum, coherence, and phase relationships, for investigating correlations between two time series.
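
A minimal usage sketch (the import path and the two synthetic signals below are assumptions for illustration):

import numpy as np
from WaLSAtools import WaLSA_cross_spectra  # assumed import path

time = np.arange(1024) * 0.5                 # hypothetical 0.5 s cadence
signal1 = np.sin(2 * np.pi * 0.1 * time) + 0.2 * np.random.randn(time.size)
signal2 = np.sin(2 * np.pi * 0.1 * time + 0.8) + 0.2 * np.random.randn(time.size)
frequencies, cospectrum, phase_angle, p1, p2, freq_coh, coh = WaLSA_cross_spectra(
    signal=signal1, time=time, method='welch', data1=signal1, data2=signal2
)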

WaLSA_cross_spectra.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.signal import coherence, csd # type: ignore
from .WaLSA_speclizer import WaLSA_speclizer # type: ignore
from .WaLSA_wavelet import xwt, wct # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_wavelet_confidence import WaLSA_wavelet_confidence # type: ignore

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_cross_spectra(signal=None, time=None, method=None, **kwargs):
    """
    Compute the cross-spectra between two time series.

    Parameters:
        data1 (array): The first time series (1D).
        data2 (array): The second time series (1D).
        time (array): The time array of the signals.
        method (str): The method to use for the analysis. Options are 'welch' and 'wavelet'.
    """
    if method == 'welch':
        return getcross_spectrum_Welch(signal, time=time, **kwargs)
    elif method == 'wavelet':
        return getcross_spectrum_Wavelet(signal, time=time, **kwargs)
    else:
        raise ValueError(f"Unknown method: {method}")


# -------------------------------------- Welch ----------------------------------------
def getcross_spectrum_Welch(signal, time, **kwargs):
    """
    Calculate cross-spectral relationships of two time series whose amplitudes and powers
    are computed using the WaLSA_speclizer routine.
    The cross-spectrum is complex valued, thus its magnitude is returned as the
    co-spectrum. The phase lags between the two time series are estimated
    from the imaginary and real arguments of the complex cross-spectrum.
    The coherence is calculated from the normalized square of the amplitude
    of the complex cross-spectrum.

    Parameters:
        data1 (array): The first 1D time series signal.
        data2 (array): The second 1D time series signal.
        time (array): The time array corresponding to the signals.
        nperseg (int, optional): Length of each segment for analysis. Default: 256.
        noverlap (int, optional): Number of points to overlap between segments. Default: 128.
        window (str, optional): Type of window function used in the Welch method. Default: 'hann'.
        siglevel (float, optional): Significance level for confidence intervals. Default: 0.95.
        nperm (int, optional): Number of permutations for significance testing. Default: 1000.
        silent (bool, optional): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh
    """
    # Define default values for the optional parameters
    defaults = {
        'data1': None,          # First data array
        'data2': None,          # Second data array
        'nperseg': 256,         # Number of points per segment to average
        'apod': 0.1,            # Extent of apodization edges
        'nodetrendapod': None,  # No detrending or apodization applied
        'pxdetrend': 2,         # Detrend parameter
        'meandetrend': None,    # Detrend parameter
        'polyfit': None,        # Detrend parameter
        'meantemporal': None,   # Detrend parameter
        'recon': None           # Detrend parameter
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    data1 = params['data1']
    data2 = params['data2']

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    silent = kwargs.pop('silent', True)

    # Power spectrum for data1
    power_data1, frequencies, _ = WaLSA_speclizer(signal=data1, time=time, method='welch', 
                                               amplitude=True, nosignificance=True, silent=silent, **kwargs)

    # Power spectrum for data2
    power_data2, _, _ = WaLSA_speclizer(signal=data2, time=time, method='welch', 
                                               amplitude=True, nosignificance=True, silent=silent, **kwargs)

    # Calculate cross-spectrum
    _, crosspower = csd(data1, data2, fs=1.0/cadence, window='hann', nperseg=params['nperseg'])

    cospectrum = np.abs(crosspower)

    # Calculate phase lag
    phase_angle = np.angle(crosspower, deg=True)

    # Calculate coherence
    freq_coh, coh = coherence(data1, data2, 1.0/cadence, nperseg=params['nperseg'])

    return frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh
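
# Cross-check (sketch, not part of the module): the same quantities can be
# verified directly with SciPy on two hypothetical 1D signals x and y:
#   _, Pxy = csd(x, y, fs=fs, window='hann', nperseg=256)
#   cospectrum = np.abs(Pxy)            # co-spectrum = magnitude of the cross-spectrum
#   phase = np.angle(Pxy, deg=True)     # phase lag from the Re/Im parts
#   _, Cxy = coherence(x, y, fs=fs, nperseg=256)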


# -------------------------------------- Wavelet ----------------------------------------
def getcross_spectrum_Wavelet(signal, time, **kwargs):
    """
    Calculate the cross-power spectrum, coherence spectrum, and phase angles 
    between two signals using Wavelet Transform.

    Parameters:
        data1 (array): The first 1D time series signal.
        data2 (array): The second 1D time series signal.
        time (array): The time array corresponding to the signals.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
        RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
        dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
        s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
        J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
        lag1 (float): Lag-1 autocorrelation. Default: 0.0.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.

    Returns:
    cross_power : np.ndarray
        Cross power spectrum between `data1` and `data2`.

    cross_periods : np.ndarray
        Periods corresponding to the cross-power spectrum.

    cross_sig : np.ndarray
        Significance of the cross-power spectrum.

    cross_coi : np.ndarray
        Cone of influence for the cross-power spectrum.

    coherence : np.ndarray
        Coherence spectrum between `data1` and `data2`.

    coh_periods : np.ndarray
        Periods corresponding to the coherence spectrum.

    coh_sig : np.ndarray
        Significance of the coherence spectrum.

    corr_coi : np.ndarray
        Cone of influence for the coherence spectrum.

    phase_angle : np.ndarray
        2D array containing x- and y-components of the phase direction 
        arrows for each frequency and time point.

    Notes
    -----
    - Cross-power, coherence, and phase angles are calculated using 
      **cross-wavelet transform (XWT)** and **wavelet coherence transform (WCT)**.
    - Arrows for phase direction are computed such that:
        - Arrows pointing downwards indicate anti-phase.
        - Arrows pointing upwards indicate in-phase.
        - Arrows pointing right indicate `data1` leading `data2`.
        - Arrows pointing left indicate `data2` leading `data1`.

    Examples
    --------
    >>> cross_power, cross_periods, cross_sig, cross_coi, coherence, coh_periods, \
    ...     coh_sig, corr_coi, phase_angle = getcross_spectrum_Wavelet(signal, time, data1=signal1, data2=signal2)
    """

    # Define default values for optional parameters
    defaults = {
        'data1': None,               # First data array
        'data2': None,                # Second data array
        'mother': 'morlet',          # Type of mother wavelet
        'siglevel': 0.95,  # Significance level for cross-spectral analysis
        'dj': 0.025,                 # Spacing between scales
        's0': -1,  # Initial scale
        'J': -1,  # Number of scales
        'cache': False,               # Cache results to avoid recomputation
        'apod': 0.1,
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'nperm': 1000               # Number of permutations for significance testing
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    data1 = params['data1']
    data2 = params['data2']

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    data1orig = data1.copy()
    data2orig = data2.copy()

    data1 = WaLSA_detrend_apod(
        data1, 
        apod=params['apod'], 
        meandetrend=params['meandetrend'], 
        pxdetrend=params['pxdetrend'], 
        polyfit=params['polyfit'], 
        meantemporal=params['meantemporal'], 
        recon=params['recon'], 
        cadence=cadence, 
        resample_original=params['resample_original'], 
        silent=True
    )

    data2 = WaLSA_detrend_apod(
        data2, 
        apod=params['apod'], 
        meandetrend=params['meandetrend'], 
        pxdetrend=params['pxdetrend'], 
        polyfit=params['polyfit'], 
        meantemporal=params['meantemporal'], 
        recon=params['recon'], 
        cadence=cadence, 
        resample_original=params['resample_original'], 
        silent=True
    )

    # Calculate signal properties
    nt = len(params['data1'])  # Number of time points

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * cadence

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(nt) * cadence / params['s0']) / np.log(2)) / params['dj'])

    # ----------- CROSS-WAVELET TRANSFORM (XWT) -----------
    W12, cross_coi, freq, signif = xwt(
        data1, data2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'], 
        no_default_signif=True, wavelet=params['mother'], normalize=False
    )
    print('Wavelet cross-power spectrum calculated.')
    # Calculate the cross-power spectrum
    cross_power = np.abs(W12) ** 2
    cross_periods = 1 / freq  # Periods corresponding to the frequency axis

    #----------------------------------------------------------------------
    # Calculate significance levels using Monte Carlo randomization method
    #----------------------------------------------------------------------
    nxx, nyy = cross_power.shape
    cross_perm = np.zeros((nxx, nyy, params['nperm']))
    print('\nCalculating wavelet cross-power significance:')
    total_iterations = params['nperm']  # Total number of (x, y) combinations
    iteration_count = 0  # To keep track of progress
    for ip in range(params['nperm']):
        iteration_count += 1
        progress = (iteration_count / total_iterations) * 100
        print(f"\rProgress: {progress:.2f}%", end="")

        y_perm1 = np.random.permutation(data1orig)  # Permuting the original signal
        apocube1 = WaLSA_detrend_apod(
            y_perm1, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=True
        )
        y_perm2 = np.random.permutation(data2orig)  # Permuting the original signal
        apocube2 = WaLSA_detrend_apod(
            y_perm2, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=True
        )
        W12s, _, _, _ = xwt(
            apocube1, apocube2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'], 
            no_default_signif=True, wavelet=params['mother'], normalize=False
        )
        cross_power_sig = np.abs(W12s) ** 2
        cross_perm[:, :, ip] = cross_power_sig

    signifn = WaLSA_wavelet_confidence(cross_perm, siglevel=params['siglevel'])

    cross_sig = cross_power / signifn  # Ratio > 1 means cross-power is significant

    # ----------- WAVELET COHERENCE TRANSFORM (WCT) -----------
    WCT, aWCT, coh_coi, freq, sig = wct(
        data1, data2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'], 
        sig=False, wavelet=params['mother'], normalize=False, cache=params['cache']
    )
    print('Wavelet coherence calculated.')
    # Calculate the coherence spectrum
    coh_periods = 1 / freq  # Periods corresponding to the frequency axis
    coherence = WCT

    #----------------------------------------------------------------------
    # Calculate significance levels using Monte Carlo randomization method
    #----------------------------------------------------------------------
    nxx, nyy = coherence.shape
    coh_perm = np.zeros((nxx, nyy, params['nperm']))
    print('\nCalculating wavelet coherence significance:')
    total_iterations = params['nperm']  # Total number of permutations
    iteration_count = 0  # To keep track of progress
    for ip in range(params['nperm']):
        iteration_count += 1
        progress = (iteration_count / total_iterations) * 100
        print(f"\rProgress: {progress:.2f}%", end="")

        y_perm1 = np.random.permutation(data1orig)  # Permuting the original signal
        apocube1 = WaLSA_detrend_apod(
            y_perm1, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=True
        )
        y_perm2 = np.random.permutation(data2orig)  # Permuting the original signal
        apocube2 = WaLSA_detrend_apod(
            y_perm2, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=True
        )
        WCTs, _, _, _, _ = wct(
            apocube1, apocube2, cadence, dj=params['dj'], s0=params['s0'], J=params['J'], 
            sig=False, wavelet=params['mother'], normalize=False, cache=params['cache']
        )
        coh_perm[:, :, ip] = WCTs

    sig_coh = WaLSA_wavelet_confidence(coh_perm, siglevel=params['siglevel'])

    coh_sig = coherence / sig_coh  # Ratio > 1 means coherence is significant

    # --------------- PHASE ANGLES ---------------
    phase_angle = aWCT

    return (
        cross_power, cross_periods, cross_sig, cross_coi, 
        coherence, coh_periods, coh_sig, coh_coi, 
        phase_angle
    )
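
For orientation, here is a minimal usage sketch of the cross-wavelet routine above. This is a sketch under assumptions: the direct import is hypothetical (in practice the routine is reached through the main WaLSAtools interface), and nperm is reduced for a quick test. Both significance outputs are ratios, so values greater than 1 mark significant cross-power or coherence.

import numpy as np
from WaLSA_speclizer import getcross_spectrum_Wavelet  # hypothetical direct import

# Two synthetic co-temporal signals sharing a 0.2 Hz component
t = np.arange(400) * 0.5  # 0.5 s cadence
s1 = np.sin(2 * np.pi * 0.2 * t) + 0.5 * np.random.randn(t.size)
s2 = np.sin(2 * np.pi * 0.2 * t + np.pi / 4) + 0.5 * np.random.randn(t.size)

(cross_power, cross_periods, cross_sig, cross_coi,
 coherence, coh_periods, coh_sig, coh_coi, phase_angle) = getcross_spectrum_Wavelet(
    signal=s1, time=t, data1=s1, data2=s2, nperm=100)  # fewer permutations for speed

significant_cross_power = cross_sig > 1  # boolean mask over (period, time)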

WaLSA_detrend_apod.py

This module provides functions for detrending and apodizing time series data to mitigate trends and edge effects.

WaLSA_detrend_apod.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore

# -------------------------------------- Wavelet ----------------------------------------
def getWavelet(signal, time, **kwargs):
    """
    Perform wavelet analysis using the pycwt package.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
        RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
        dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 1/32.
        s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
        J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
        lag1 (float): Lag-1 autocorrelation. Default: 0.0.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        power: The wavelet power spectrum.
        periods: Corresponding periods.
        coi: The cone of influence.
        scales: The wavelet scales.
        W: The complex wavelet transform.
    """
    # Define default values for the optional parameters similar to IDL
    defaults = {
        'siglevel': 0.95,
        'mother': 'morlet',  # Morlet wavelet as the mother function
        'dj': 1/32.,  # Scale spacing
        's0': -1,  # Initial scale
        'J': -1,  # Number of scales
        'lag1': 0.0,  # Lag-1 autocorrelation
        'silent': False,
        'nperm': 1000,   # Number of permutations for significance calculation
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    n = len(signal)
    dt = cadence

    # Standardize the signal before the wavelet transform
    std_signal = signal.std()
    norm_signal = signal / std_signal

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * dt

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])

    # Perform wavelet transform
    W, scales, frequencies, coi, _, _ = cwt(
        norm_signal,
        dt,
        dj=params['dj'],
        s0=params['s0'],
        J=params['J'],
        wavelet=params['mother']
    )

    power = np.abs(W) ** 2  # Wavelet power spectrum
    periods = 1 / frequencies  # Convert frequencies to periods

    return power, periods, coi, scales, W

# -----------------------------------------------------------------------------------------------------

# Linear detrending function for curve fitting
def linear(x, a, b):
    return a + b * x

# Custom Tukey window implementation
def custom_tukey(nt, apod=0.1):
    apodrim = int(apod * nt)
    apodt = np.ones(nt)  # Initialize with ones (rectangular window if apod=0)

    if apodrim > 0:
        # Apply sine-squared taper to the first 'apodrim' points
        taper = (np.sin(np.pi / 2. * np.arange(apodrim) / apodrim)) ** 2

        # Apply taper symmetrically at both ends
        apodt[:apodrim] = taper
        apodt[-apodrim:] = taper[::-1]  # Reverse taper for the last points

    return apodt

# Main detrending and apodization function
# apod=0: The Tukey window becomes a rectangular window (no tapering).
# apod=1: The Tukey window becomes a Hann window (fully tapered with a cosine function).
def WaLSA_detrend_apod(cube, apod=0.1, meandetrend=False, pxdetrend=2, polyfit=None, meantemporal=False,
                       recon=False, cadence=None, resample_original=False, min_resample=None, 
                       max_resample=None, silent=False, dj=1/32., lo_cutoff=None, hi_cutoff=None, upper=False):

    nt = len(cube)  # Assume input is a 1D signal
    cube = cube - np.mean(cube) # Remove the mean of the input signal
    apocube = np.copy(cube)  # Create a copy of the input signal
    t = np.arange(nt)  # Time array

    # Apply Tukey window (apodization)
    if apod > 0:
        tukey_window = custom_tukey(nt, apod)
        apocube = apocube * tukey_window  # Apodize the signal

    # Mean detrend (optional): subtract a fitted linear trend
    # (for a 1D series, the spatial mean at each time step is the signal value itself)
    if meandetrend:
        time_idx = np.arange(nt)
        mean_fit_params, _ = curve_fit(linear, time_idx, apocube)
        mean_trend = linear(time_idx, *mean_fit_params)
        apocube -= mean_trend

    # Wavelet-based Fourier reconstruction (optional)
    if recon and cadence:
        apocube = WaLSA_wave_recon(apocube, cadence, dj=dj, lo_cutoff=lo_cutoff, hi_cutoff=hi_cutoff, upper=upper)

    # Pixel-based detrending (temporal detrend)
    if pxdetrend > 0:
        mean_val = np.mean(apocube)
        if meantemporal:
            # Simple temporal detrend by subtracting the mean
            apocube -= mean_val
        else:
            # Advanced detrend (linear or polynomial fit)
            if polyfit is not None:
                poly_coeffs = np.polyfit(t, apocube, polyfit)
                trend = np.polyval(poly_coeffs, t)
            else:
                popt, _ = curve_fit(linear, t, apocube, p0=[mean_val, 0])
                trend = linear(t, *popt)
            apocube -= trend

    # Resampling to preserve amplitudes (optional)
    if resample_original:
        if min_resample is None:
            min_resample = np.min(apocube)
        if max_resample is None:
            max_resample = np.max(apocube)
        apocube = np.interp(apocube, (np.min(apocube), np.max(apocube)), (min_resample, max_resample))

    if not silent:
        print("Detrending and apodization complete.")

    return apocube

# Wavelet-based reconstruction function (optional)
def WaLSA_wave_recon(ts, delt, dj=1/32., lo_cutoff=None, hi_cutoff=None, upper=False):
    """
    Reconstructs the wavelet-filtered time series between the given period cutoffs.
    """

    # Define duration based on the time series length
    dur = (len(ts) - 1) * delt

    # Assign default values if lo_cutoff or hi_cutoff is None
    if lo_cutoff is None:
        lo_cutoff = 0.0  # Default to 0 as in IDL

    if hi_cutoff is None:
        hi_cutoff = dur / (3.0 * np.sqrt(2)) 

    mother = 'morlet'
    num_points = len(ts)
    time_array = np.linspace(0, (num_points - 1) * delt, num_points)

    # Note: dj is passed through so that the normalization below matches the CWT scale spacing
    _, period, coi, scales, wave = getWavelet(signal=ts, time=time_array, method='wavelet', siglevel=0.99, apod=0.1, mother=mother, dj=dj)

    # Ensure good_per_idx and bad_per_idx are properly assigned
    if upper:
        good_per_idx = np.where(period > hi_cutoff)[0][0]
        bad_per_idx = len(period)
    else:
        good_per_idx = np.where(period > lo_cutoff)[0][0]
        bad_per_idx = np.where(period > hi_cutoff)[0][0]

    # Exclude points inside the CoI (i.e., those subject to edge effects):
    # only coefficients with periods shorter than the CoI at each time step are kept
    iampl = np.zeros((len(ts), len(period)), dtype=float)
    for i in range(len(ts)):
        pcol = np.real(wave[:, i])  # Extract real part of wavelet transform for this time index
        ii = np.where(period < coi[i])[0]  # Find indices where period is less than COI at time i
        if ii.size > 0:
            iampl[i, ii] = pcol[ii]  # Assign values where condition is met

    # Initialize reconstructed signal array
    recon_sum = np.zeros(len(ts))
    # Summation over valid period indices
    for i in range(good_per_idx, bad_per_idx):
        recon_sum += iampl[:, i] / np.sqrt(scales[i])

    # Apply normalization factor (Torrence & Compo 1998, with Cdelta = 0.776
    # and psi0(0) = pi^(-1/4) for the Morlet wavelet)
    recon_all = dj * np.sqrt(delt) * recon_sum / (0.776 * (np.pi ** -0.25))

    return recon_all
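
As a quick illustration of the detrend/apodization entry point above (a sketch with synthetic data; the direct import is hypothetical): the defaults remove a linear per-pixel trend and taper 10% of each edge with a Tukey window.

import numpy as np
from WaLSA_detrend_apod import WaLSA_detrend_apod  # hypothetical direct import

t = np.arange(300) * 2.0  # 2 s cadence
signal = np.sin(2 * np.pi * t / 180.0) + 0.05 * t + 0.1 * np.random.randn(t.size)

# Linear per-pixel detrend (pxdetrend=2) and 10% Tukey apodization (apod=0.1)
clean = WaLSA_detrend_apod(signal, apod=0.1, pxdetrend=2, cadence=2.0, silent=True)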

WaLSA_confidence.py

This module implements statistical significance testing for the spectral analysis results using various methods.
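
Its source is omitted here. Conceptually, significance levels are estimated in the same spirit as the Monte Carlo loops above: the input is randomly permuted many times, the spectrum of each permutation is computed, and the chosen quantile of the permuted spectra defines the confidence threshold. A minimal sketch of that idea (not the actual implementation):

import numpy as np

def permutation_confidence(signal, power_func, nperm=1000, siglevel=0.95):
    """Permutation-based significance threshold for a 1D power spectrum.

    power_func maps a signal to its power spectrum. Shuffling destroys
    temporal coherence while preserving the value distribution, so the
    per-frequency siglevel quantile of the permuted spectra serves as the
    confidence level.
    """
    perm_powers = np.array([power_func(np.random.permutation(signal))
                            for _ in range(nperm)])
    return np.quantile(perm_powers, siglevel, axis=0)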

WaLSA_wavelet_confidence.py

This module implements statistical significance testing for the wavelet analysis results.
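
Its source is omitted here. From its usage in getcross_spectrum_Wavelet above, it reduces a cube of permuted wavelet spectra to a per-(scale, time) significance threshold. A minimal sketch of that reduction (not the actual implementation):

import numpy as np

def wavelet_confidence_sketch(perm_cube, siglevel=0.95):
    """Significance threshold from a (n_scales, n_times, n_permutations) cube,
    as assembled by the Monte Carlo loops in getcross_spectrum_Wavelet."""
    # Observed power divided by this threshold exceeds 1 where significant
    return np.quantile(perm_cube, siglevel, axis=2)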

WaLSA_io.py

This module provides functions for input/output operations, such as saving images as PDF (in both RGB and CMYK formats) and image contrast enhancements.
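
Its source is omitted here. As an illustration of one operation it covers, a common contrast-enhancement step is percentile clipping followed by rescaling (a generic sketch, not the module's actual code):

import numpy as np

def clip_and_rescale(image, lo=1.0, hi=99.0):
    """Clip an image to its [lo, hi] percentile range and rescale to [0, 1]."""
    vmin, vmax = np.percentile(image, [lo, hi])
    return (np.clip(image, vmin, vmax) - vmin) / (vmax - vmin)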

WaLSA_interactive.py

This module implements the interactive interface of WaLSAtools, guiding users through the analysis process.

WaLSA_interactive.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, in press.
# -----------------------------------------------------------------------------------------------------

import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output, HTML # type: ignore
import os # type: ignore

# Function to detect if it's running in a notebook or terminal
def is_notebook():
    try:
        # Check if IPython is in the environment and if we're using a Jupyter notebook
        from IPython import get_ipython # type: ignore
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True  # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # IPython terminal
        else:
            return False  # Other types
    except (NameError, ImportError):
        # NameError: get_ipython is not defined
        # ImportError: IPython is not installed
        return False  # Standard Python interpreter or shell

# Function to print logo and credits for both environments
def print_logo_and_credits():
    logo_terminal = r"""
        __          __          _          _____            
        \ \        / /         | |        / ____|     /\    
         \ \  /\  / /  ▄▄▄▄▄   | |       | (___      /  \   
          \ \/  \/ /   ▀▀▀▀██  | |        \___ \    / /\ \  
           \  /\  /   ▄██▀▀██  | |____    ____) |  / ____ \ 
            \/  \/    ▀██▄▄██  |______|  |_____/  /_/    \_\
    """
    credits_terminal = """
        © WaLSA Team (www.WaLSA.team)
        -----------------------------------------------------------------------
        WaLSAtools v1.0 - Wave analysis tools
        Documentation: www.WaLSA.tools
        GitHub repository: www.github.com/WaLSAteam/WaLSAtools
        -----------------------------------------------------------------------
        If you use WaLSAtools in your research, please cite:
        Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025,
        Nature Reviews Methods Primers, in press
        -----------------------------------------------------------------------
        Choose a category, data type, and analysis method from the list below,
        to get hints on the calling sequence and parameters:
        """

    credits_notebook = """
        <div style="margin-left: 30px; margin-top: 20px; font-size: 1.1em; line-height: 0.8;">
            <p>© WaLSA Team (<a href="https://www.WaLSA.team" target="_blank">www.WaLSA.team</a>)</p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
            <p><strong>WaLSAtools</strong> v1.0 - Wave analysis tools</p>
            <p>Documentation: <a href="https://www.WaLSA.tools" target="_blank">www.WaLSA.tools</a></p>
            <p>GitHub repository: <a href="https://www.github.com/WaLSAteam/WaLSAtools" target="_blank">www.github.com/WaLSAteam/WaLSAtools</a></p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
            <p>If you use <strong>WaLSAtools</strong> in your research, please cite:</p>
            <p>Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, <em>Nature Reviews Methods Primers</em>, in press</p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 15px;">
            <p>Choose a category, data type, and analysis method from the list below,</p>
            <p>to get hints on the calling sequence and parameters:</p>
        </div>
    """

    if is_notebook():
        try:
            # For scripts
            current_dir = os.path.dirname(os.path.abspath(__file__))
        except NameError:
            # For Jupyter notebooks
            current_dir = os.getcwd()
        img_path = os.path.join(current_dir, '..', 'assets', 'WaLSAtools_black.png')
        # display(HTML(f'<img src="{img_path}" style="margin-left: 40px; margin-top: 20px; width:300px; height: auto;">')) # not shown in Jupyter notebooks; works in VS Code only
        import base64
        # Convert the image to Base64
        with open(img_path, "rb") as img_file:
            encoded_img = base64.b64encode(img_file.read()).decode('utf-8')
        # Embed the Base64 image in the HTML
        html_code_logo = f"""
        <div style="margin-left: 30px; margin-top: 20px;">
            <img src="data:image/png;base64,{encoded_img}" style="width: 300px; height: auto;">
        </div>
        """
        display(HTML(html_code_logo))
        display(HTML(credits_notebook))
    else:
        print(logo_terminal)
        print(credits_terminal)

from .parameter_definitions import display_parameters_text, single_series_parameters, cross_correlation_parameters

# Terminal-based interactive function
def walsatools_terminal():
    """Main interactive function for terminal version of WaLSAtools."""
    print_logo_and_credits()

    # Step 1: Select Category
    while True:
        print("\n    Category:")
        print("    (a) Single time series analysis")
        print("    (b) Cross-correlation between two time series")
        category = input("    --- Select a category (a/b): ").strip().lower()

        if category not in ['a', 'b']:
            print("    Invalid selection. Please enter either 'a' or 'b'.\n")
            continue

        # Step 2: Data Type
        if category == 'a':
            while True:
                print("\n    Data Type:")
                print("    (1) 1D signal")
                print("    (2) 3D datacube")
                method = input("    --- Select a data type (1/2): ").strip()

                if method not in ['1', '2']:
                    print("    Invalid selection. Please enter either '1' or '2'.\n")
                    continue

                # Step 3: Analysis Method
                if method == '1':  # 1D Signal
                    while True:
                        print("\n    Analysis Method:")
                        print("    (1) FFT")
                        print("    (2) Wavelet")
                        print("    (3) Lomb-Scargle")
                        print("    (4) Welch")
                        print("    (5) EMD")
                        analysis_type = input("    --- Select an analysis method (1-5): ").strip()

                        method_map_1d = {
                            '1': 'fft',
                            '2': 'wavelet',
                            '3': 'lombscargle',
                            '4': 'welch',
                            '5': 'emd'
                        }

                        selected_method = method_map_1d.get(analysis_type, 'unknown')

                        if selected_method == 'unknown':
                            print("    Invalid selection. Please select a valid analysis method (1-5).\n")
                            continue

                        # Generate and display the calling sequence for 1D signal
                        print("\n    Calling sequence:\n")
                        return_values = single_series_parameters[selected_method]['return_values']
                        command = f"    >>> {return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
                        print(command)

                        # Display parameter hints
                        display_parameters_text(selected_method, category='a')
                        return  # Exit the function after successful output

                elif method == '2':  # 3D Datacube
                    while True:
                        print("\n    Analysis Method:")
                        print("    (1) k-omega")
                        print("    (2) POD")
                        print("    (3) Dominant Freq / Mean Power Spectrum")
                        analysis_type = input("    --- Select an analysis method (1-3): ").strip()

                        if analysis_type == '3':  # Sub-method required
                            while True:
                                print("\n    Analysis method for Dominant Freq / Mean Power Spectrum:")
                                print("    (1) FFT")
                                print("    (2) Wavelet")
                                print("    (3) Lomb-Scargle")
                                print("    (4) Welch")
                                sub_method = input("    --- Select a method (1-4): ").strip()

                                sub_method_map = {'1': 'fft', '2': 'wavelet', '3': 'lombscargle', '4': 'welch'}
                                selected_method = sub_method_map.get(sub_method, 'unknown')

                                if selected_method == 'unknown':
                                    print("    Invalid selection. Please select a valid sub-method (1-4).\n")
                                    continue

                                # Generate and display the calling sequence for sub-method
                                print("\n    Calling sequence:\n")
                                return_values = 'dominant_frequency, mean_power, frequency, power_map'
                                command = f"    >>> {return_values} = WaLSAtools(data=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"
                                print(command)

                                # Display parameter hints
                                display_parameters_text(selected_method, category='a')
                                return  # Exit the function after successful output

                        method_map_3d = {
                            '1': 'k-omega',
                            '2': 'pod'
                        }

                        selected_method = method_map_3d.get(analysis_type, 'unknown')

                        if selected_method == 'unknown':
                            print("    Invalid selection. Please select a valid analysis method (1-3).\n")
                            continue

                        # Generate and display the calling sequence for k-omega/POD
                        print("\n    Calling sequence:\n")
                        return_values = single_series_parameters[selected_method]['return_values']
                        command = f"    >>> {return_values} = WaLSAtools(data1=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
                        print(command)

                        # Display parameter hints
                        display_parameters_text(selected_method, category='a')
                        return  # Exit the function after successful output

        elif category == 'b':  # Cross-correlation
            while True:
                print("\n    Data Type:")
                print("    (1) 1D signal")
                method = input("    --- Select a data type (1): ").strip()

                if method != '1':
                    print("    Invalid selection. Please enter '1'.\n")
                    continue

                while True:
                    print("\n    Analysis Method:")
                    print("    (1) Wavelet")
                    print("    (2) Welch")
                    analysis_type = input("    --- Select an analysis method (1-2): ").strip()

                    cross_correlation_map = {
                        '1': 'wavelet',
                        '2': 'welch'
                    }

                    selected_method = cross_correlation_map.get(analysis_type, 'unknown')

                    if selected_method == 'unknown':
                        print("    Invalid selection. Please select a valid analysis method (1-2).\n")
                        continue

                    # Generate and display the calling sequence for cross-correlation
                    print("\n    Calling sequence:\n")
                    return_values = cross_correlation_parameters[selected_method]['return_values']
                    command = f">>> {return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
                    print(command)

                    # Display parameter hints
                    display_parameters_text(selected_method, category='b')
                    return  # Exit the function after successful output


# Jupyter-based interactive function
from .parameter_definitions import display_parameters_html, single_series_parameters, cross_correlation_parameters # type: ignore

# Global flag to prevent multiple observers
is_observer_attached = False

def walsatools_jupyter():
    """Main interactive function for Jupyter Notebook version of WaLSAtools."""
    global is_observer_attached, category, method, analysis_type, sub_method  # Declare global variables for reuse

    # Detach any existing observers and reset the flag
    try:
        detach_observers()
    except NameError:
        pass  # `detach_observers` hasn't been defined yet

    is_observer_attached = False  # Reset observer flag

    # Clear any previous output
    clear_output(wait=True)

    print_logo_and_credits()

    # Recreate widgets to reset state
    category = widgets.Dropdown(
        options=['Select Category', 'Single time series analysis', 'Cross-correlation between two time series'],
        value='Select Category',
        description='Category:'
    )
    method = widgets.Dropdown(
        options=['Select Data Type'],
        value='Select Data Type',
        description='Data Type:'
    )
    analysis_type = widgets.Dropdown(
        options=['Select Method'],
        value='Select Method',
        description='Method:'
    )
    sub_method = widgets.Dropdown(
        options=['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch'],
        value='Select Sub-method',
        description='Sub-method:',
        layout=widgets.Layout(display='none')  # Initially hidden
    )

    # Persistent output widget
    output = widgets.Output()

    def clear_output_if_unselected(change=None):
        """Clear the output widget if any dropdown menu is unselected."""
        with output:
            if (
                category.value == 'Select Category'
                or method.value == 'Select Data Type'
                or analysis_type.value == 'Select Method'
                or (
                    analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
                    and sub_method.layout.display == 'block'
                    and sub_method.value == 'Select Sub-method'
                )
            ):
                clear_output(wait=True)
                warn = '<div style="font-size: 1.1em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Please select appropriate options from all dropdown menus.</div>'
                display(HTML(warn))

    def update_method_options(change=None):
        """Update available Method and Sub-method options."""
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        sub_method.layout.display = 'none'
        sub_method.value = 'Select Sub-method'
        sub_method.options = ['Select Sub-method']

        if category.value == 'Single time series analysis':
            method.options = ['Select Data Type', '1D signal', '3D datacube']
            if method.value == '1D signal':
                analysis_type.options = ['Select Method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch', 'EMD']
            elif method.value == '3D datacube':
                analysis_type.options = ['Select Method', 'k-omega', 'POD', 'Dominant Freq / Mean Power Spectrum']
            else:
                analysis_type.options = ['Select Method']
        elif category.value == 'Cross-correlation between two time series':
            method.options = ['Select Data Type', '1D signal']
            if method.value == '1D signal':
                analysis_type.options = ['Select Method', 'Wavelet', 'Welch']
            else:
                analysis_type.options = ['Select Method']
        else:
            method.options = ['Select Data Type']
            analysis_type.options = ['Select Method']

    def update_sub_method_visibility(change=None):
        """Show or hide the Sub-method dropdown based on conditions."""
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        if (
            category.value == 'Single time series analysis'
            and method.value == '3D datacube'
            and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
        ):
            sub_method.options = ['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch']
            sub_method.layout.display = 'block'
        else:
            sub_method.options = ['Select Sub-method']
            sub_method.value = 'Select Sub-method'
            sub_method.layout.display = 'none'

    def update_calling_sequence(change=None):
        """Update the function calling sequence based on user's selection."""
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        if (
            category.value == 'Select Category'
            or method.value == 'Select Data Type'
            or analysis_type.value == 'Select Method'
            or (
                analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
                and sub_method.layout.display == 'block'
                and sub_method.value == 'Select Sub-method'
            )
        ):
            return  # Do nothing until all required fields are properly selected

        with output:
            clear_output(wait=True)

            # Handle Dominant Frequency / Mean Power Spectrum with Sub-method
            if (
                category.value == 'Single time series analysis'
                and method.value == '3D datacube'
                and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
            ):
                if sub_method.value == 'Select Sub-method':
                    print("Please select a Sub-method.")
                    return

                sub_method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch'}
                selected_method = sub_method_map.get(sub_method.value, 'unknown')
                return_values = 'dominant_frequency, mean_power, frequency, power_map'
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"

            # Handle k-omega and POD
            elif (
                category.value == 'Single time series analysis'
                and method.value == '3D datacube'
                and analysis_type.value in ['k-omega', 'POD']
            ):
                method_map = {'k-omega': 'k-omega', 'POD': 'pod'}
                selected_method = method_map.get(analysis_type.value, 'unknown')
                parameter_definitions = single_series_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            # Handle Cross-correlation
            elif category.value == 'Cross-correlation between two time series':
                cross_correlation_map = {'Wavelet': 'wavelet', 'Welch': 'welch'}
                selected_method = cross_correlation_map.get(analysis_type.value, 'unknown')
                parameter_definitions = cross_correlation_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            # Handle Single 1D signal analysis
            elif category.value == 'Single time series analysis' and method.value == '1D signal':
                method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch', 'EMD': 'emd'}
                selected_method = method_map.get(analysis_type.value, 'unknown')
                parameter_definitions = single_series_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            else:
                print("Invalid configuration.")
                return

            # Generate and display the command in HTML
            html_code = f"""
            <div style="font-size: 1.2em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Calling sequence:</div>
            <div style="display: flex; margin-left: 30px; margin-bottom: 3ch;">
                <span style="color: #222; min-width: 4ch;">>>> </span>
                <pre style="
                    white-space: pre-wrap; 
                    word-wrap: break-word;  
                    color: #01016D; 
                    margin: 0;
                ">{command}</pre>
            </div>
            """
            display(HTML(html_code))
            display_parameters_html(selected_method, category=category.value)

    def attach_observers():
        """Attach observers to the dropdown widgets and ensure no duplicates."""
        global is_observer_attached
        if not is_observer_attached:
            detach_observers()
            category.observe(clear_output_if_unselected, names='value')
            method.observe(clear_output_if_unselected, names='value')
            analysis_type.observe(clear_output_if_unselected, names='value')
            category.observe(update_method_options, names='value')
            method.observe(update_method_options, names='value')
            analysis_type.observe(update_sub_method_visibility, names='value')
            sub_method.observe(update_calling_sequence, names='value')
            analysis_type.observe(update_calling_sequence, names='value')
            is_observer_attached = True

    def detach_observers():
        """Detach all observers to prevent multiple triggers."""
        handler_pairs = [
            (category, clear_output_if_unselected),
            (method, clear_output_if_unselected),
            (analysis_type, clear_output_if_unselected),
            (category, update_method_options),
            (method, update_method_options),
            (analysis_type, update_sub_method_visibility),
            (sub_method, update_calling_sequence),
            (analysis_type, update_calling_sequence),
        ]
        for widget, handler in handler_pairs:
            try:
                widget.unobserve(handler, names='value')
            except ValueError:
                pass  # This handler was not attached

    attach_observers()
    display(category, method, analysis_type, sub_method, output)


# Unified interactive function for both terminal and Jupyter environments
def interactive():
    if is_notebook():
        walsatools_jupyter()
    else:
        walsatools_terminal()
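
A minimal usage sketch (the direct import is hypothetical; in the released package the same menu is normally reached through the WaLSAtools entry point):

from WaLSA_interactive import interactive  # hypothetical direct import

interactive()  # Jupyter: dropdown widgets; terminal: text menus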

WaLSA_plot_k_omega.py

This module provides functions for plotting k-ω diagrams and filtered datacubes.
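
Its source is omitted here. As a generic illustration of how a k-ω diagram is formed (a sketch, not the module's actual code): Fourier-transform a (t, y, x) datacube, keep the positive temporal frequencies, and azimuthally average the spatial power over rings of constant wavenumber.

import numpy as np
import matplotlib.pyplot as plt

def komega_power(cube, dt, dx):
    """Azimuthally averaged k-omega power of a (t, y, x) datacube."""
    nt, ny, nx = cube.shape
    power = np.abs(np.fft.fftn(cube)) ** 2
    power = np.fft.fftshift(power[:nt // 2], axes=(1, 2))  # positive freqs; centre k=0
    freqs = np.fft.fftfreq(nt, dt)[:nt // 2]
    kx = np.fft.fftshift(np.fft.fftfreq(nx, dx))
    ky = np.fft.fftshift(np.fft.fftfreq(ny, dx))
    kxx, kyy = np.meshgrid(kx, ky)
    kr = np.hypot(kxx, kyy)  # radial wavenumber of each spatial Fourier pixel
    k_edges = np.linspace(0, kr.max(), min(nx, ny) // 2)
    k_centres = 0.5 * (k_edges[1:] + k_edges[:-1])
    ko = np.zeros((nt // 2, k_centres.size))
    for i in range(k_centres.size):
        ring = (kr >= k_edges[i]) & (kr < k_edges[i + 1])
        if ring.any():
            ko[:, i] = power[:, ring].mean(axis=1)
    return k_centres, freqs, ko

# k, f, ko = komega_power(datacube, dt=2.0, dx=0.1)  # hypothetical cadence / pixel size
# plt.pcolormesh(k, f, np.log10(ko + 1e-12), shading='auto')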

WaLSA_plot_wavelet_spectrum.py

This module provides functions for plotting wavelet power spectra and related visualizations.

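Its source is omitted here. A generic sketch of the kind of figure it produces (not the module's actual code): wavelet power versus time and period, with the significance contour drawn at ratio 1 and the cone of influence hatched out.

import numpy as np
import matplotlib.pyplot as plt

def plot_wavelet_spectrum(time, periods, power, coi, sig=None):
    """Plot wavelet power of shape (n_periods, n_times) with the CoI marked."""
    fig, ax = plt.subplots()
    mesh = ax.pcolormesh(time, periods, np.log10(power + 1e-12), shading='auto')
    if sig is not None:
        ax.contour(time, periods, sig, levels=[1.0], colors='k')  # significance ratio = 1
    # Hatch periods longer than the CoI (edge-affected region)
    ax.fill_between(time, coi, periods.max(), facecolor='none',
                    edgecolor='gray', hatch='x')
    ax.set_yscale('log')
    ax.invert_yaxis()
    ax.set_xlabel('Time')
    ax.set_ylabel('Period')
    fig.colorbar(mesh, ax=ax, label='log10(power)')
    return fig, ax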