Skip to content

Under the Hood

We strongly recommend that everyone follow the procedure described here when using WaLSAtools — a user-friendly tool that gives you all the information you need to carry out your analysis. However, for experts who wish to familiarize themselves with the techniques and code under the hood — to inspect, modify, develop, or improve them — some of the main codes are also provided below. Please note that all codes and their dependencies are available in the GitHub repository.

Analysis Modules

WaLSAtools is built upon a collection of analysis modules, each designed for a specific aspect of wave analysis. These modules are combined and accessed through the main WaLSAtools interface, providing a streamlined and user-friendly experience.

Here's a brief overview of the core analysis modules:

WaLSA_speclizer.py

This module provides a collection of spectral analysis techniques, including FFT, Lomb-Scargle, Wavelet, Welch, and EMD/HHT.

WaLSA_speclizer.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from astropy.timeseries import LombScargle # type: ignore
from tqdm import tqdm # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_confidence import WaLSA_confidence # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore
from scipy.signal import welch # type: ignore

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_speclizer(signal=None, time=None, method=None, 
                    dominantfreq=False, averagedpower=False, **kwargs):
    """
    Route a signal to the requested spectral-analysis routine.

    Parameters:
        signal (array): The input signal (1D or 3D).
        time (array): The time array of the signal.
        method (str): Analysis method: 'fft', 'lombscargle', 'wavelet', 'welch' or 'emd'.
        dominantfreq (bool): Together with averagedpower, selects the
            dominant-frequency / averaged-power path. Default: False.
        averagedpower (bool): See dominantfreq. Only when BOTH flags are truthy is
            get_dominant_averaged used. Default: False.
        **kwargs: Forwarded unchanged to the selected routine.

    Returns:
        Whatever the selected routine returns (power spectrum, frequencies,
        and significance where applicable).

    Raises:
        ValueError: If `method` is not recognized (checked only on the
            single-method path).
    """
    # Both flags must be set to take the combined dominant-frequency path.
    if dominantfreq and averagedpower:
        return get_dominant_averaged(signal, time=time, method=method, **kwargs)

    if method == 'fft':
        return getpowerFFT(signal, time=time, **kwargs)
    if method == 'lombscargle':
        return getpowerLS(signal, time=time, **kwargs)
    if method == 'wavelet':
        return getpowerWavelet(signal, time=time, **kwargs)
    if method == 'welch':
        return welch_psd(signal, time=time, **kwargs)
    if method == 'emd':
        return getEMD_HHT(signal, time=time, **kwargs)

    raise ValueError(f"Unknown method: {method}")

# -------------------------------------- FFT ----------------------------------------
def getpowerFFT(signal, time, **kwargs):
    """
    Perform FFT analysis on the given signal.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        nosignificance (bool): If True, skip significance calculation. Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed
            (neither on the data nor on the permuted series used for significance). Default: False.
        amplitude (bool): If True, return the complex Fourier coefficients. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        powermap (ndarray): One-sided power 2*|FFT|**2 at the positive frequencies
            (zero frequency excluded), divided by the frequency resolution frequencies[0].
        frequencies (ndarray): The positive frequencies, in units of 1 / time.
        significance (ndarray or None): Permutation-based significance levels with the
            same normalization as powermap, or None if nosignificance=True.
        amplitudes (ndarray or None): Complex FFT coefficients at `frequencies`
            if amplitude=True, else None.
    """
    # Define default values for the optional parameters
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,
        'nosignificance': False,
        'apod': 0.1,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'nodetrendapod': False,
        'amplitude': False,
        'silent': False
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    # Different convention: WaLSA_confidence is given 1 - siglevel (e.g. 0.05 for 95%).
    params['siglevel'] = 1 - params['siglevel']

    cadence = np.median(np.diff(time))

    def _prepare(series, silent):
        """Detrend/apodize `series` with the user's settings (unless nodetrendapod)."""
        if params['nodetrendapod']:
            return series
        return WaLSA_detrend_apod(
            series,
            apod=params['apod'],
            meandetrend=params['meandetrend'],
            pxdetrend=params['pxdetrend'],
            polyfit=params['polyfit'],
            meantemporal=params['meantemporal'],
            recon=params['recon'],
            cadence=cadence,
            resample_original=params['resample_original'],
            silent=silent
        )

    apocube = _prepare(signal, params['silent'])

    nt = len(apocube)  # Length of the time series (1D)

    # Positive frequencies (the zero frequency is dropped); frequencies[0] is the resolution.
    frequencies = 1. / (cadence * 2) * np.arange(nt // 2 + 1) / (nt // 2)
    frequencies = frequencies[1:]
    nf = len(frequencies)
    df = frequencies[0]

    spec = np.fft.fft(apocube)
    powermap = np.zeros(nf)
    powermap[:] = 2 * np.abs(spec[1:nf + 1]) ** 2 / df

    amplitudes = spec[1:nf + 1] if params['amplitude'] else None

    # Calculate significance if requested
    if not params['nosignificance']:
        ps_perm = np.zeros((nf, params['nperm']))
        for ip in range(params['nperm']):
            # Shuffle the raw input and push it through exactly the same
            # detrending/apodization as the observed series, then FFT the
            # *processed* permutation (previously the processed result was discarded).
            perm_signal = _prepare(np.random.permutation(signal), True)
            perm_spec = np.fft.fft(perm_signal)
            ps_perm[:, ip] = 2 * np.abs(perm_spec[1:nf + 1]) ** 2

        significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=nf)
        significance = significance / df
    else:
        significance = None

    if not params['silent']:
        print("FFT processed.")

    return powermap, frequencies, significance, amplitudes

# -------------------------------------- Lomb-Scargle ----------------------------------------
def getpowerLS(signal, time, **kwargs):
    """
    Perform Lomb-Scargle analysis on the given signal.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        dy (array): Errors or observational uncertainties associated with the time series.
            An array-valued dy is permuted together with the signal in the significance test.
        fit_mean (bool): If True, include a constant offset as part of the model at each frequency. This improves accuracy, especially for incomplete phase coverage.
        center_data (bool): If True, pre-center the data by subtracting the weighted mean of the input data. This is especially important if fit_mean=False.
        nterms (int): Number of terms to use in the Fourier fit. Default: 1.
        normalization (str): The normalization method for the periodogram. Options: 'standard', 'model', 'log', 'psd'. Default: 'standard'.
        nosignificance (bool): If True, skip significance calculation. Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed
            (neither on the data nor on the permuted series). Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        power (ndarray): Lomb-Scargle power on the automatically chosen frequency grid.
        frequencies (ndarray): That frequency grid (from LombScargle.autopower).
        significance (ndarray or None): Permutation-based significance levels on the
            same grid, or None if nosignificance=True.
    """
    # Define default values for the optional parameters
    defaults = {
        'siglevel': 0.95,  # same convention as getpowerFFT (was 0.05, contradicting the docs)
        'nperm': 1000,
        'nosignificance': False,
        'apod': 0.1,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'silent': False,
        'nodetrendapod': False,
        'dy': None,  # Error or sequence of observational errors associated with times t
        'fit_mean': False,  # Include a constant offset in the model at each frequency
        'center_data': True,  # Pre-center the data by subtracting the weighted mean
        'nterms': 1,  # Number of terms to use in the Fourier fit
        'normalization': 'standard'  # 'standard', 'model', 'log' or 'psd'
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    # Different convention: WaLSA_confidence is given 1 - siglevel (e.g. 0.05 for 95%).
    params['siglevel'] = 1 - params['siglevel']

    cadence = np.median(np.diff(time))

    def _prepare(series, silent):
        """Detrend/apodize `series` with the user's settings (unless nodetrendapod)."""
        if params['nodetrendapod']:
            return series
        return WaLSA_detrend_apod(
            series,
            apod=params['apod'],
            meandetrend=params['meandetrend'],
            pxdetrend=params['pxdetrend'],
            polyfit=params['polyfit'],
            meantemporal=params['meantemporal'],
            recon=params['recon'],
            cadence=cadence,
            resample_original=params['resample_original'],
            silent=silent
        )

    def _periodogram(series, dy):
        """Build a LombScargle model for `series` sampled at `time`."""
        return LombScargle(time, series, dy=dy, fit_mean=params['fit_mean'],
                           center_data=params['center_data'], nterms=params['nterms'],
                           normalization=params['normalization'])

    apocube = _prepare(signal, params['silent'])
    frequencies, power = _periodogram(apocube, params['dy']).autopower()

    # Calculate significance if needed
    if not params['nosignificance']:
        signal_arr = np.asarray(signal)
        dy = params['dy']
        dy_is_array = dy is not None and np.ndim(dy) > 0
        ps_perm = np.zeros((len(frequencies), params['nperm']))
        for ip in range(params['nperm']):
            order = np.random.permutation(len(signal_arr))
            # Keep each sample's uncertainty attached to it when shuffling.
            perm_dy = np.asarray(dy)[order] if dy_is_array else dy
            apocubep = _prepare(signal_arr[order], True)
            # Evaluate on the observed grid so every column of ps_perm matches `frequencies`.
            ps_perm[:, ip] = _periodogram(apocubep, perm_dy).power(
                frequencies, assume_regular_frequency=True)
        significance = WaLSA_confidence(ps_perm, siglevel=params['siglevel'], nf=len(frequencies))
    else:
        significance = None

    if not params['silent']:
        print("Lomb-Scargle processed.")

    return power, frequencies, significance

# -------------------------------------- Wavelet ----------------------------------------
def getpowerWavelet(signal, time, **kwargs):
    """
    Perform continuous wavelet analysis on the given signal.

    The transform and significance test use the `cwt` and `significance`
    routines imported from WaLSA_wavelet (a pycwt-style interface).

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal.
        siglevel (float): Significance level for the confidence intervals. Default: 0.95.
        nperm (int): Accepted for interface consistency; not used by this routine. Default: 1000.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        GWS (bool): If True, calculate the Global Wavelet Spectrum. Default: False.
        RGWS (bool): If True, calculate the Refined Global Wavelet Spectrum (time-integrated power, excluding COI and insignificant areas). Default: False.
        dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
        s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
        J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
        lag1 (float): Lag-1 autocorrelation. Default: 0.0.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range.  Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        power: Wavelet power |W|**2 of the unit-variance signal, one row per scale.
        periods: Period of each scale (1 / frequency).
        sig_slevel: power divided by the significance threshold of its scale
            (values >= 1 reach the requested significance level).
        coi: The cone of influence, as returned by `cwt`.
        global_power: Time-averaged power (Global Wavelet Spectrum) if GWS=True, else None.
        global_conf: Significance threshold for global_power if GWS=True, else None.
        rgws_power: Refined Global Wavelet Spectrum if RGWS=True, else None: for each
            period, the sum over time of the significant power (sig_slevel >= 1) at
            times where the period is below the cone of influence.
    """
    # Define default values for the optional parameters similar to IDL
    defaults = {
        'siglevel': 0.95,
        'mother': 'morlet',  # Morlet wavelet as the mother function
        'dj': 0.025,  # Scale spacing
        's0': -1,  # Initial scale (-1: derive from the cadence)
        'J': -1,  # Number of scales (-1: derive from the series length)
        'lag1': 0.0,  # Lag-1 autocorrelation
        'apod': 0.1,  # Tukey window apodization
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'GWS': False,  # If True, calculate global wavelet spectrum
        'RGWS': False,  # If True, calculate refined global wavelet spectrum (excluding COI)
        'nperm': 1000,   # Number of permutations (unused here)
        'nodetrendapod': False
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    cadence = np.median(np.diff(time))

    # Perform detrending and apodization
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(
            signal, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=params['silent']
        )
    else:
        apocube = signal

    n = len(apocube)
    dt = cadence

    # Standardize the signal; the significance test below therefore uses a
    # normalized variance of 1.0.
    norm_signal = apocube / apocube.std()

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * dt

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])

    # Perform wavelet transform
    W, scales, frequencies, coi, _, _ = cwt(
        norm_signal,
        dt,
        dj=params['dj'],
        s0=params['s0'],
        J=params['J'],
        wavelet=params['mother']
    )

    power = np.abs(W) ** 2  # Wavelet power spectrum
    periods = 1 / frequencies  # Convert frequencies to periods

    # Significance threshold for each scale
    signif, _ = significance(
        1.0,  # Normalized variance
        dt,
        scales,
        0,  # sigtest = 0: test each power value against the background
        params['lag1'],
        significance_level=params['siglevel'],
        wavelet=params['mother']
    )

    # Ratio of each power value to its scale's threshold
    sig_matrix = np.ones([len(scales), n]) * signif[:, None]
    sig_slevel = power / sig_matrix

    # Calculate global power if requested
    if params['GWS']:
        global_power = np.mean(power, axis=1)  # Average along the time axis

        dof = n - scales   # the -scale corrects for padding at edges
        global_conf, _ = significance(
            norm_signal,     # time series data
            dt,              # time step between values
            scales,          # scale vector
            1,               # sigtest = 1 for "time-average" test
            0.0,
            significance_level=params['siglevel'],
            dof=dof,
            wavelet=params['mother']
        )
    else:
        global_power = None
        global_conf = None

    # Calculate refined global wavelet spectrum (RGWS) if requested
    if params['RGWS']:
        # Mask a copy: the returned `power` must not be altered by this step.
        masked_power = power.copy()
        masked_power[sig_slevel < 1.0] = np.nan
        ipower = np.full_like(masked_power, np.nan)
        for i, coi_period in enumerate(coi):
            inside = periods < coi_period  # periods not affected by edge effects
            ipower[inside, i] = masked_power[inside, i]
        rgws_power = np.nansum(ipower, axis=1)
    else:
        rgws_power = None

    if not params['silent']:
        print("Wavelet (" + params['mother'] + ") processed.")

    return power, periods, sig_slevel, coi, global_power, global_conf, rgws_power

# -------------------------------------- Welch ----------------------------------------
def welch_psd(signal, time, **kwargs):
    """
    Calculate Welch Power Spectral Density (PSD) and permutation-based significance levels.

    Parameters:
        signal (array): The 1D time series signal.
        time (array): The time array corresponding to the signal; the sampling
            frequency is 1 / median(diff(time)).
        nperseg (int, optional): Length of each segment for analysis. Default: 256.
            Values longer than the signal are reduced to the signal length.
        noverlap (int, optional): Number of points to overlap between segments. Default: 128.
            If left at the default and it is not smaller than the (possibly reduced)
            nperseg, nperseg // 2 is used instead. An explicitly given value is passed through.
        window (str, optional): Type of window function used in the Welch method. Default: 'hann'.
        siglevel (float, optional): Percentile (as a fraction) of the permuted-PSD
            distribution used as the significance level. Default: 0.95.
        nperm (int, optional): Number of permutations for significance testing. Default: 1000.
        nosignificance (bool, optional): If True, skip the permutation test. Default: False.
        silent (bool, optional): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        psd: Power Spectral Density values.
        frequencies: Frequencies at which the PSD is estimated.
        significance: Significance levels for the PSD (None if nosignificance=True).
    """
    # Define default values for the optional parameters similar to FFT
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,  # Number of permutations for significance calculation
        'window': 'hann',  # Window type for Welch method
        'nperseg': 256,  # Length of each segment
        'silent': False,
        'noverlap': 128,  # Number of points to overlap between segments
        'nosignificance': False
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    cadence = np.median(np.diff(time))
    fs = 1.0 / cadence

    # scipy shrinks an over-long nperseg to the signal length anyway (with a
    # warning); doing it here lets the default overlap be validated against it.
    nperseg = params['nperseg']
    if nperseg is not None:
        nperseg = min(int(nperseg), len(signal))

    # The default overlap of 128 only fits the default segment length: for short
    # signals scipy would raise "noverlap must be less than nperseg".
    noverlap = params['noverlap']
    if 'noverlap' not in kwargs and nperseg is not None and noverlap >= nperseg:
        noverlap = nperseg // 2

    def _psd(series):
        """Welch estimate of `series` with the resolved settings."""
        return welch(series, fs=fs, window=params['window'], nperseg=nperseg, noverlap=noverlap)

    # Calculate Welch PSD
    frequencies, psd = _psd(signal)

    if params['nosignificance']:
        significance = None
    else:
        # Null distribution from randomly permuted copies of the signal
        ps_perm = np.zeros((len(frequencies), params['nperm']))
        for ip in range(params['nperm']):
            _, perm_psd = _psd(np.random.permutation(signal))
            ps_perm[:, ip] = perm_psd
        significance = np.percentile(ps_perm, params['siglevel'] * 100, axis=1)

    if not params['silent']:
        print("Welch processed.")

    return psd, frequencies, significance

# -------------------------------------- EMD / EEMD ----------------------------------------
import numpy as np # type: ignore
from .PyEMD import EMD, EEMD # type: ignore
from scipy.stats import norm # type: ignore
from scipy.signal import hilbert, welch # type: ignore
from scipy.signal import find_peaks # type: ignore
from scipy.fft import fft, fftfreq # type: ignore

# Function to apply EMD and return Intrinsic Mode Functions (IMFs)
def apply_emd(signal, time):
    """
    Decompose `signal` into Intrinsic Mode Functions with the bundled EMD.

    FIXE_H is set to 5 (in PyEMD this configures the sifting stopping
    criterion — confirm against the bundled PyEMD version).

    Returns:
        The IMFs as returned by EMD.emd(signal, time).
    """
    decomposer = EMD()
    decomposer.FIXE_H = 5
    # Alternative settings tried previously: max_imfs=7, FIXE_H = 10
    return decomposer.emd(signal, time)

# Function to apply EEMD and return Intrinsic Mode Functions (IMFs)
def apply_eemd(signal, time, noise_std=0.2, num_realizations=1000):
    """
    Decompose `signal` into Intrinsic Mode Functions with the bundled EEMD.

    The EEMD noise generator is seeded with 12345, so repeated calls with the
    same inputs are reproducible.

    Parameters:
        noise_std (float): Value assigned to EEMD.noise_width. Default: 0.2.
        num_realizations (int): Value assigned to EEMD.ensemble_size. Default: 1000.

    Returns:
        The IMFs as returned by EEMD.eemd(signal, time).
    """
    ensemble = EEMD()
    ensemble.FIXE_H = 5
    ensemble.noise_seed(12345)
    ensemble.noise_width = noise_std
    ensemble.ensemble_size = num_realizations
    return ensemble.eemd(signal, time)

# Function to generate white noise IMFs for significance testing
def generate_white_noise_imfs(signal_length, time, num_realizations, use_eemd=None, noise_std=0.2):
    """
    Decompose `num_realizations` standard-normal white-noise series into IMFs.

    Parameters:
        signal_length (int): Number of samples in each noise series.
        time (array): Time array passed to the decomposition.
        num_realizations (int): Number of noise series to generate.
        use_eemd (bool): If truthy, decompose with EEMD (seeded with 12345),
            otherwise with EMD. Default: None (EMD).
        noise_std (float): EEMD noise_width, used only with EEMD. Default: 0.2.

    Returns:
        list: One IMF set per noise realization.
    """
    if use_eemd:
        decomposer = EEMD()
        decomposer.FIXE_H = 5
        decomposer.noise_seed(12345)
        decomposer.noise_width = noise_std
        # NOTE(review): ensemble_size is tied to num_realizations, so if each
        # eemd() call runs `ensemble_size` trials the total cost grows as
        # num_realizations**2 — confirm this is intended.
        decomposer.ensemble_size = num_realizations
        decompose = decomposer.eemd
    else:
        decomposer = EMD()
        decomposer.FIXE_H = 5
        decompose = decomposer.emd

    # Each realization draws fresh noise from the global NumPy RNG.
    return [
        decompose(np.random.normal(size=signal_length), time)
        for _ in range(num_realizations)
    ]

# Function to calculate the energy of each IMF
def calculate_imf_energy(imfs):
    """Return a list with the energy (sum of squared samples) of each IMF."""
    energies = []
    for component in imfs:
        energies.append(np.sum(component ** 2))
    return energies

# Function to determine the significance of the IMFs
def test_imf_significance(imfs, white_noise_imfs):
    """
    Compare each IMF's energy with white-noise IMFs of the same order.

    For IMF i, the energies of the i-th IMF of every white-noise realization
    that has at least i+1 IMFs form a reference distribution. The observed
    energy is standardized against its mean and standard deviation, and the
    upper-tail probability under a normal approximation is returned.

    Parameters:
        imfs: IMFs of the signal under test.
        white_noise_imfs: List of IMF sets from white-noise realizations.

    Returns:
        list: One upper-tail probability per IMF (small values = more energy
        than expected from noise). The value is NaN when no reference
        realization has that IMF order or when the energy equals the mean of a
        zero-spread reference.
    """
    observed_energies = calculate_imf_energy(imfs)
    noise_energies = [calculate_imf_energy(imf_set) for imf_set in white_noise_imfs]

    significance_levels = []
    for i, energy in enumerate(observed_energies):
        # Realizations may yield different numbers of IMFs; use those reaching order i.
        reference = [realization[i] for realization in noise_energies if i < len(realization)]
        z_score = (energy - np.mean(reference)) / np.std(reference)
        # norm.sf equals 1 - norm.cdf but keeps precision far into the upper tail.
        significance_levels.append(norm.sf(z_score))

    return significance_levels

# Function to compute instantaneous frequency of IMFs using Hilbert transform
def compute_instantaneous_frequency(imfs, time):
    """
    Estimate the instantaneous frequency of each IMF via the Hilbert transform.

    The unwrapped phase of the analytic signal is differentiated with respect
    to `time`, so each returned array has one sample fewer than its IMF.
    Absolute values are taken so that all frequencies are non-negative.

    Returns:
        list: One instantaneous-frequency array per IMF.
    """
    def _inst_freq(component):
        phase = np.unwrap(np.angle(hilbert(component)))
        return np.abs(np.diff(phase) / (2.0 * np.pi * np.diff(time)))

    return [_inst_freq(component) for component in imfs]

# Function to compute HHT power spectrum
def compute_hht_power_spectrum(imfs, instantaneous_frequencies, freq_bins):
    """
    Accumulate Hilbert-Huang power into frequency bins.

    For every IMF, the squared Hilbert envelope |hilbert(imf)|**2 at sample j is
    added to bin np.searchsorted(freq_bins, f_j), where f_j is the IMF's
    instantaneous frequency at j. Samples whose frequency falls beyond the
    last bin edge are dropped.

    Parameters:
        imfs: Sequence of IMFs.
        instantaneous_frequencies: Per-IMF instantaneous frequencies (may be one
            sample shorter than the IMF, as produced by np.diff-based estimates).
        freq_bins: Sorted 1D array of bin positions.

    Returns:
        ndarray: Float power spectrum with one value per entry of freq_bins
        (float even when freq_bins is an integer array, so power is not truncated).
    """
    freq_bins = np.asarray(freq_bins)
    power_spectrum = np.zeros(freq_bins.shape, dtype=float)
    nbins = len(freq_bins)

    for imf, inst_freq in zip(imfs, instantaneous_frequencies):
        inst_freq = np.asarray(inst_freq)
        power = np.abs(hilbert(imf)) ** 2
        bin_idx = np.searchsorted(freq_bins, inst_freq)
        in_range = bin_idx < nbins
        # np.add.at accumulates correctly when several samples share a bin.
        np.add.at(power_spectrum, bin_idx[in_range], power[:len(inst_freq)][in_range])

    return power_spectrum

# Function to compute significance level for HHT power spectrum
def compute_significance_level(white_noise_imfs, freq_bins, time, percentile=95):
    """
    Compute an HHT significance curve from white-noise IMF sets.

    The HHT power spectrum of each white-noise IMF set is computed on
    `freq_bins`. The requested percentile across realizations is returned per bin.

    Parameters:
        white_noise_imfs: List of IMF sets from white-noise realizations.
        freq_bins: Frequency bins passed to compute_hht_power_spectrum.
        time: Time array used for the instantaneous frequencies.
        percentile (float): Percentile (0-100) of the noise spectra used as the
            significance level. Default: 95 (the previously hard-coded value).

    Returns:
        ndarray: Significance level per frequency bin.
    """
    noise_spectra = []
    for imfs in white_noise_imfs:
        inst_freqs = compute_instantaneous_frequency(imfs, time)
        noise_spectra.append(compute_hht_power_spectrum(imfs, inst_freqs, freq_bins))

    return np.percentile(noise_spectra, percentile, axis=0)

## Custom rounding function
def custom_round(freq):
    """Round to one decimal below 1, otherwise to the nearest integer (Python round semantics)."""
    return round(freq, 1) if freq < 1 else round(freq)

def smooth_power_spectrum(power_spectrum, window_size=5):
    """Boxcar-smooth a spectrum with a `window_size`-point moving average (output same length)."""
    kernel = np.ones(window_size) / window_size
    smoothed = np.convolve(power_spectrum, kernel, mode='same')
    return smoothed

# Function to identify and print significant peaks
def identify_significant_peaks(freq_bins, power_spectrum, significance_level):
    """
    Find spectral peaks above `significance_level` and print their frequencies.

    Peaks are located with scipy.signal.find_peaks using `significance_level`
    as the height threshold. Their frequencies are printed twice: rounded to
    one decimal, then with custom_round.

    Returns:
        ndarray: The (unrounded) frequencies of the significant peaks.
    """
    peak_indices = find_peaks(power_spectrum, height=significance_level)[0]
    significant_frequencies = freq_bins[peak_indices]

    custom_rounded = [custom_round(value) for value in significant_frequencies]
    one_decimal = [round(value, 1) for value in significant_frequencies]

    print("Significant Frequencies (Hz):", one_decimal)
    print("Significant Frequencies (Hz):", custom_rounded)
    return significant_frequencies

# Function to generate randomized signals
def generate_randomized_signals(signal, num_realizations):
    """Return `num_realizations` random permutations of `signal` (global NumPy RNG)."""
    return [np.random.permutation(signal) for _ in range(num_realizations)]

# Function to calculate the FFT power spectrum for each IMF
def calculate_fft_psd_spectra(imfs, time):
    """
    Compute an FFT-based power spectrum for each IMF.

    The sampling interval is taken as time[1] - time[0] (uniform sampling
    assumed). For N samples the first N // 2 FFT bins are kept and scaled as
    (2 / N) * |X|**2 / (N * dt).

    Returns:
        list: One (frequencies, psd) tuple per IMF.
    """
    def _spectrum(component):
        n_samples = len(component)
        dt = time[1] - time[0]  # Assuming uniform sampling
        coeffs = fft(component)
        half = n_samples // 2
        freqs = fftfreq(n_samples, dt)[:half]
        density = (2.0 / n_samples) * (np.abs(coeffs[:half]) ** 2) / (n_samples * dt)
        return freqs, density

    return [_spectrum(component) for component in imfs]

# Function to calculate the FFT power spectrum for randomized signals
def calculate_fft_psd_spectra_randomized(imfs, time, num_realizations):
    """
    FFT power spectra of `num_realizations` random permutations of every IMF.

    Uses the same scaling as calculate_fft_psd_spectra. All spectra are stacked
    row-wise, IMF by IMF.

    Returns:
        ndarray: Array of shape (len(imfs) * num_realizations, N // 2).
    """
    spectra = []
    for component in imfs:
        for shuffled in generate_randomized_signals(component, num_realizations):
            n_samples = len(shuffled)
            dt = time[1] - time[0]  # Assuming uniform sampling
            coeffs = fft(shuffled)
            spectra.append((2.0 / n_samples) * (np.abs(coeffs[:n_samples // 2]) ** 2) / (n_samples * dt))
    return np.array(spectra)

# Function to calculate the 95th percentile confidence level
def calculate_confidence_level(imfs, time, num_realizations, percentile=95):
    """
    Per-IMF FFT confidence levels from randomly permuted copies of each IMF.

    Parameters:
        imfs: Sequence of IMFs.
        time: Time array (uniform sampling assumed).
        num_realizations (int): Number of permutations per IMF.
        percentile (float): Percentile (0-100) of the permuted spectra used as
            the confidence level. Default: 95 (the previously hard-coded value).

    Returns:
        list: One confidence-level array (per frequency) for each IMF.
    """
    return [
        np.percentile(
            calculate_fft_psd_spectra_randomized([imf], time, num_realizations),
            percentile,
            axis=0,
        )
        for imf in imfs
    ]

# Function to calculate the Welch PSD for each IMF
def calculate_welch_psd_spectra(imfs, fs):
    """
    Welch PSD (scipy defaults apart from `fs`) of each IMF.

    Returns:
        list: One (frequencies, psd) tuple per IMF.
    """
    def _estimate(component):
        freqs, density = welch(component, fs=fs)
        return freqs, density

    return [_estimate(component) for component in imfs]

# Function to calculate the Welch PSD for randomized signals
def calculate_welch_psd_spectra_randomized(imfs, fs, num_realizations):
    """
    Welch PSDs of randomly permuted copies of each IMF.

    Parameters:
        imfs (iterable of 1D arrays): Intrinsic mode functions.
        fs (float): Sampling frequency passed to scipy.signal.welch.
        num_realizations (int): Number of permutations generated per IMF.

    Returns:
        numpy array with one PSD row per permuted signal (all permutations of
        the first IMF, then of the second, ...). Frequencies are discarded.
    """
    return np.array([
        welch(shuffled, fs=fs)[1]
        for component in imfs
        for shuffled in generate_randomized_signals(component, num_realizations)
    ])

# Function to calculate the 95th percentile confidence level
def calculate_confidence_level_welch(imfs, fs, num_realizations):
    """
    Per-IMF 95th-percentile Welch PSD of permuted copies of that IMF.

    Parameters:
        imfs (iterable of 1D arrays): Intrinsic mode functions.
        fs (float): Sampling frequency.
        num_realizations (int): Number of permutations per IMF.

    Returns:
        list of 1D arrays, one per IMF: for every Welch frequency bin, the
        95th percentile of the PSD across the permutations.
    """
    return [
        np.percentile(
            calculate_welch_psd_spectra_randomized([component], fs, num_realizations),
            95,
            axis=0,
        )
        for component in imfs
    ]

# ************** Main EMD/EEMD routine **************
def getEMD_HHT(signal, time, **kwargs):
    """
    Calculate EMD/EEMD and HHT

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array of the signal.
        siglevel (float): Threshold compared against the per-IMF values returned by
            test_imf_significance; IMFs whose value is below it are treated as significant. Default: 0.95.
        nperm (int): Number of permutations / white-noise realizations for significance testing. Default: 1000.
        EEMD (bool): If True, use Ensemble Empirical Mode Decomposition (EEMD) instead of Empirical Mode Decomposition (EMD). Default: False.
        Welch_psd (bool): If True, calculate Welch PSD spectra instead of FFT PSD spectra (for the psd_spectra and psd_confidence_levels). Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
        resample_original (bool): If True, and if recon is set to True, approximate values close to the original are returned for comparison. Default: False.
        nodetrendapod (bool): If True, neither detrending nor apodization is performed. Default: False.
        significant_imfs (bool): If True, keep only the significant IMFs for all subsequent calculations. Default: False.
        silent (bool): If True, suppress print statements. Default: False.
        **kwargs: Additional parameters for the analysis method.

    Returns:
        HHT_power_spectrum: 'Marginal' Hilbert-Huang spectrum of the (significant) IMFs on HHT_freq_bins.
        HHT_significance_level: Significance level of the HHT spectrum, derived from the white-noise IMFs.
        HHT_freq_bins: 100 linearly spaced bins from 0 to the largest instantaneous frequency.
        psd_spectra: List of (frequencies, psd) tuples, one per (significant) IMF (FFT or Welch).
        psd_confidence_levels: Per-IMF 95th-percentile PSD of permuted copies of the IMF.
        imfs: The (significant) IMFs.
        IMF_significance_levels: Significance values of ALL IMFs (computed before any filtering).
        instantaneous_frequencies: Instantaneous frequencies of the (significant) IMFs.

    Raises:
        ValueError: If no IMFs remain to analyse (e.g. significant_imfs=True and none pass siglevel).
    """
    # Define default values for the optional parameters similar to FFT
    defaults = {
        'siglevel': 0.95,
        'nperm': 1000,  # Number of permutations for significance calculation
        'apod': 0.1,  # Tukey window apodization
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'EEMD': False,  # If True, calculate Ensemble Empirical Mode Decomposition (EEMD)
        'nodetrendapod': False,
        'Welch_psd': False,  # If True, calculate Welch PSD spectra
        'significant_imfs': False  # If True, return only significant IMFs (and for associated calculations)
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    tdiff = np.diff(time)
    cadence = np.median(tdiff)

    # Perform detrending and apodization.
    # NOTE(review): `apocube` is not used below -- the (E)EMD is applied to the
    # raw `signal`, so the detrending/apodization options currently do not
    # affect the results. Confirm whether decomposing `apocube` was intended.
    # (The former `n = len(apocube)` was removed: it was unused and raised
    # NameError whenever nodetrendapod=True.)
    if not params['nodetrendapod']:
        apocube = WaLSA_detrend_apod(  # noqa: F841 -- see NOTE above
            signal, 
            apod=params['apod'], 
            meandetrend=params['meandetrend'], 
            pxdetrend=params['pxdetrend'], 
            polyfit=params['polyfit'], 
            meantemporal=params['meantemporal'], 
            recon=params['recon'], 
            cadence=cadence, 
            resample_original=params['resample_original'], 
            silent=params['silent']
        )

    dt = cadence
    fs = 1 / dt  # Sampling frequency

    if params['EEMD']:
        imfs = apply_eemd(signal, time)
        use_eemd = True
    else:
        imfs = apply_emd(signal, time)
        use_eemd = False

    # Reference decompositions of white noise, used for IMF and HHT significance.
    white_noise_imfs = generate_white_noise_imfs(len(signal), time, params['nperm'], use_eemd=use_eemd)

    IMF_significance_levels = test_imf_significance(imfs, white_noise_imfs)

    # Determine significant IMFs
    significant_imfs_val = [imf for imf, sig_level in zip(imfs, IMF_significance_levels) if sig_level < params['siglevel']]
    if params['significant_imfs']:
        imfs = significant_imfs_val

    # Without any IMF, max() over the instantaneous frequencies below would fail
    # with an opaque "empty sequence" error; report the actual cause instead.
    if len(imfs) == 0:
        raise ValueError("No IMFs available for the HHT analysis "
                         "(with significant_imfs=True, none passed the significance threshold).")

    # Compute instantaneous frequencies of IMFs
    instantaneous_frequencies = compute_instantaneous_frequency(imfs, time)

    if params['Welch_psd']:
        # Calculate Welch PSD spectra of (significant) IMFs
        psd_spectra = calculate_welch_psd_spectra(imfs, fs)
        # Calculate 95th percentile confidence levels for Welch PSD spectra
        psd_confidence_levels = calculate_confidence_level_welch(imfs, fs, params['nperm'])
    else:
        # Calculate FFT PSD spectra of (significant) IMFs
        psd_spectra = calculate_fft_psd_spectra(imfs, time)
        # Calculate 95th percentile confidence levels for FFT PSD spectra
        psd_confidence_levels = calculate_confidence_level(imfs, time, params['nperm'])

    # Define frequency bins for HHT power spectrum
    max_freq = max([freq.max() for freq in instantaneous_frequencies])
    HHT_freq_bins = np.linspace(0, max_freq, 100)

    # Compute HHT power spectrum of (significant) IMFs
    HHT_power_spectrum = compute_hht_power_spectrum(imfs, instantaneous_frequencies, HHT_freq_bins)
    # Compute significance level for HHT power spectrum
    HHT_significance_level = compute_significance_level(white_noise_imfs, HHT_freq_bins, time)

    if not params['silent']:
        if params['EEMD']:
            print("EEMD processed.")
        else:
            print("EMD processed.")

    return HHT_power_spectrum, HHT_significance_level, HHT_freq_bins, psd_spectra, psd_confidence_levels, imfs, IMF_significance_levels, instantaneous_frequencies


# ----------------------- Dominant Frequency & Averaged Power -------------------------------
def _pixel_power_spectrum(method, signal, time, params, spec_kwargs):
    """Return (power, frequencies) of a single time series for the given method."""
    if method == 'fft':
        # nosignificance is forced to True here; drop any user-supplied value
        # so the keyword is not passed twice.
        fft_opts = {k: v for k, v in spec_kwargs.items() if k != 'nosignificance'}
        power, freqs, _, _ = getpowerFFT(signal, time, silent=True, nosignificance=True, **fft_opts)
    elif method == 'lombscargle':
        power, freqs, _ = getpowerLS(signal, time, silent=True, **spec_kwargs)
    elif method == 'welch':
        power, freqs, _ = welch_psd(signal, time, silent=True, **spec_kwargs)
    else:  # 'wavelet' (validated by the caller)
        _, periods, _, _, gws_power, _, rgws_power = getpowerWavelet(signal, time, silent=True, **spec_kwargs)
        # GWS takes precedence over RGWS when both are requested.
        power = gws_power if params['GWS'] else rgws_power
        freqs = 1. / periods
    return power, freqs


def get_dominant_averaged(cube, time, **kwargs):
    """
    Analyze a 3D data cube to compute the dominant frequency and averaged power.

    Parameters:
        cube (3D array): Input data cube, expected in either 'txy' or 'xyt' format.
        time (array): Time array of the data cube.
        method (str): Analysis method: 'fft', 'lombscargle', 'wavelet' or 'welch'. Default is 'fft'.
        format (str): Format of the data cube ('txy' or 'xyt'). Default is 'txy'.
        silent (bool): If True, suppress print statements. Default is False.
        GWS (bool): For 'wavelet', use the global wavelet spectrum. Default is False.
        RGWS (bool): For 'wavelet', use the refined global wavelet spectrum (used when GWS is False). Default is True.
        **kwargs: Additional parameters forwarded to the per-pixel analysis method.

    Returns:
        dominantfreq (2D array): Dominant frequency of each (x, y) pixel.
        averagedpower (1D array): Power spectrum averaged over all pixels.
        frequencies (1D array): Frequencies of the power spectra.
        powermap (3D array): Power spectrum of every pixel, shape (nx, ny, nfreq).

    Raises:
        ValueError: For an unsupported format or method, a wavelet run with both
            GWS and RGWS disabled, or a cube with an empty spatial dimension.
    """
    defaults = {
        'method': 'fft',
        'format': 'txy',
        'silent': False,
        'GWS': False,  # If True, calculate global wavelet spectrum
        'RGWS': True,  # If True, calculate refined global wavelet spectrum
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    # Check and adjust format if necessary
    if params['format'] == 'xyt':
        cube = np.transpose(cube, (2, 0, 1))  # Convert 'xyt' to 'txy'
    elif params['format'] != 'txy':
        raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")

    # Validate the method up front (it is used in the status messages below).
    method_names = {
        'fft': 'FFT',
        'lombscargle': 'Lomb-Scargle',
        'wavelet': 'Wavelet',
        'welch': 'Welch',
    }
    method = params['method']
    if method not in method_names:
        raise ValueError("Unsupported method. Choose 'fft', 'lombscargle', 'wavelet', or 'welch'.")
    if method == 'wavelet' and not (params['GWS'] or params['RGWS']):
        raise ValueError("For method 'wavelet', set either GWS or RGWS to True.")
    method_name = method_names[method]

    _, nx, ny = cube.shape
    if nx == 0 or ny == 0:
        raise ValueError("The data cube must have non-empty spatial dimensions.")

    if not params['silent']:
        print(f"Processing {method_name} for a 3D cube with format '{params['format']}' and shape {cube.shape}.")
        print(f"Calculating Dominant frequencies and/or averaged power spectrum ({method_name}) ....")

    # 'silent' is always forced to True for the per-pixel calls; drop it from
    # the forwarded options so it is not passed twice.
    spec_kwargs = {k: v for k, v in kwargs.items() if k != 'silent'}

    # Iterate over spatial coordinates and apply the chosen analysis method
    dominantfreq = np.zeros((nx, ny))
    powermap = None
    frequencies = None
    for x in tqdm(range(nx), desc="Processing x", leave=True):
        for y in range(ny):
            power, frequencies = _pixel_power_spectrum(method, cube[:, x, y], time, params, spec_kwargs)
            if powermap is None:
                # Frequency grid is taken from the first pixel.
                powermap = np.zeros((nx, ny, len(frequencies)))
            powermap[x, y, :] = power
            # Determine the dominant frequency for this pixel
            dominantfreq[x, y] = frequencies[np.argmax(power)]

    # Calculate the averaged power over all pixels
    averagedpower = np.mean(powermap, axis=(0, 1))

    if not params['silent']:
        print("\nAnalysis completed.")

    return dominantfreq, averagedpower, frequencies, powermap

WaLSA_wavelet.py

This module implements the continuous wavelet transform and related functionality.

WaLSA_wavelet.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools: - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

# The following code is a modified version of the PyCWT package.
# The original code and licence:
# PyCWT is released under the open source 3-Clause BSD license:
#
# Copyright (c) 2023 Sebastian Krieger, Nabil Freij, and contributors. All rights
# reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#----------------------------------------------------------------------------------
# Modifications by Shahin Jafarzadeh, 2024
#----------------------------------------------------------------------------------

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numpy as np # type: ignore
from tqdm import tqdm # type: ignore
from scipy.stats import chi2 # type: ignore

# Try to import the Python wrapper for FFTW.
try:
    import pyfftw.interfaces.scipy_fftpack as fft # type: ignore
    from multiprocessing import cpu_count

    # Fast planning, use all available threads.
    _FFTW_KWARGS_DEFAULT = {'planner_effort': 'FFTW_ESTIMATE',
                            'threads': cpu_count()}

    def fft_kwargs(signal, **kwargs):
        """Return keyword arguments for an FFTW-backed transform of `signal`.

        Extra keyword arguments are kept, the FFTW planning defaults are
        added, and ``n`` is set to ``len(signal)`` so no padding is applied.
        """
        kwargs.update(_FFTW_KWARGS_DEFAULT)
        kwargs['n'] = len(signal)  # do not pad
        return kwargs

# Otherwise, fall back to 2 ** n padded scipy FFTPACK
except ImportError:
    import scipy.fftpack as fft # type: ignore
    # Can be turned off, e.g. for MKL optimizations
    _FFT_NEXT_POW2 = True

    def fft_kwargs(signal, **kwargs):
        """Return ``{'n': n}``, the FFT length to use for `signal`.

        With ``_FFT_NEXT_POW2`` enabled, ``n`` is the next power of two
        >= len(signal) (speeds up FFTPACK and pads the series); otherwise
        ``n = len(signal)`` (no padding). A dict is always returned because
        callers both unpack it (``**fft_kwargs(...)``) and index ``['n']``.
        Extra keyword arguments are accepted for interface compatibility
        with the FFTW variant but are not forwarded.
        """
        n = len(signal)
        if _FFT_NEXT_POW2 and n > 0:
            n = int(2 ** np.ceil(np.log2(n)))
        return {'n': n}

from scipy.signal import lfilter # type: ignore
from os import makedirs
from os.path import exists, expanduser


def find(condition):
    """Return the flat indices at which `condition` is truthy.

    The input is flattened first, so the indices refer to positions in
    ``np.ravel(condition)``.
    """
    return np.flatnonzero(condition)


def ar1(x):
    """
    Allen and Smith autoregressive lag-1 autocorrelation coefficient.

    In an AR(1) model

        x(t) - <x> = gamma(x(t-1) - <x>) + \\alpha z(t) ,

    where <x> is the process mean, gamma and \\alpha are process
    parameters and z(t) is a Gaussian unit-variance white noise.

    Parameters
    ----------
    x : numpy.ndarray, list
        Univariate time series.

    Returns
    -------
    g : float
        Estimate of the lag-one autocorrelation.
    a : float
        Estimate of the noise variance [var(x) ~= a**2/(1-g**2)].
    mu2 : float
        Estimated square of the mean of a finite segment of AR(1)
        noise, normalized by the process variance.

    Raises
    ------
    Warning
        If the discriminant of the estimating quadratic is not positive
        (no unbiased estimate can be bounded).

    References
    ----------
    [1] Allen, M. R. and Smith, L. A. Monte Carlo SSA: detecting
        irregular oscillations in the presence of colored noise.
        *Journal of Climate*, **1996**, 9(12), 3373-3404.
        <http://dx.doi.org/10.1175/1520-0442(1996)009<3373:MCSDIO>2.0.CO;2>
    [2] http://www.madsci.org/posts/archives/may97/864012045.Eg.r.html

    """
    series = np.asarray(x)
    N = series.size
    series = series - series.mean()

    # Lag-zero and lag-one autocovariance estimates of the demeaned series.
    c0 = series.transpose().dot(series) / N
    c1 = series[0:N-1].transpose().dot(series[1:N]) / (N - 1)

    # Coefficients of the quadratic in g (A. Grinsted's substitutions).
    A = c0 * N**2
    B = -c1 * N - c0 * N**2 - 2 * c0 + 2 * c1 - c1 * N**2 + c0 * N
    C = N * (c0 + c1 * N - c1)
    D = B**2 - 4 * A * C

    if not D > 0:
        raise Warning('Cannot place an upperbound on the unbiased AR(1). '
                      'Series is too short or trend is to large.')
    # Smaller root of the quadratic.
    g = (-B - D**0.5) / (2 * A)

    # Bias of the segment mean, Allen & Smith (1996), footnote 4.
    mu2 = -1 / N + (2 / N**2) * ((N - g**N) / (1 - g) -
                                 g * (1 - g**(N - 1)) / (1 - g)**2)
    variance_corrected = c0 / (1 - mu2)
    a = ((1 - g**2) * variance_corrected) ** 0.5

    return g, a, mu2


def ar1_spectrum(freqs, ar1=0.):
    """
    Theoretical power spectrum of a lag-1 autoregressive process.

    Evaluates

        P_k = (1 - ar1**2) / |1 - ar1 * exp(-2 i pi f)|**2

    at each frequency f. Frequencies enter the exponent directly, so they
    are presumably normalized by the sampling rate (cycles per sample).

    Parameters
    ----------
    freqs : numpy.ndarray, list
        Frequencies at which to evaluate the spectrum.
    ar1 : float
        Autoregressive lag-1 correlation coefficient.

    Returns
    -------
    Pk : numpy.ndarray
        Theoretical discrete Fourier power spectrum of the noise signal.

    References
    ----------
    [1] http://www.madsci.org/posts/archives/may97/864012045.Eg.r.html

    """
    f = np.asarray(freqs)
    denominator = np.abs(1 - ar1 * np.exp(-2 * np.pi * 1j * f)) ** 2
    return (1 - ar1 ** 2) / denominator


def rednoise(N, g, a=1.):
    """
    Red noise generator using filter.

    Parameters
    ----------
    N : int
        Length of the desired time series.
    g : float
        Lag-1 autocorrelation coefficient; must satisfy |g| < 1.
    a : float, optional
        Noise innovation variance parameter.

    Returns
    -------
    y : numpy.ndarray
        Red noise time series of shape (N,).

    Raises
    ------
    ValueError
        If |g| >= 1 (non-stationary process; the burn-in length below
        would be undefined or negative).

    """
    if abs(g) >= 1:
        raise ValueError('The lag-1 autocorrelation g must satisfy |g| < 1.')
    if g == 0:
        yr = np.random.randn(N, 1) * a
    else:
        # Twice the decorrelation time; this burn-in is generated and then
        # discarded.
        tau = int(np.ceil(-2 / np.log(np.abs(g))))
        # Filter a 1D series: lfilter works along the last axis, so an
        # (N, 1) input would filter length-1 rows and return white noise.
        yr = lfilter([1, 0], [1, -g], np.random.randn(N + tau) * a)
        yr = yr[tau:]

    return yr.flatten()


def rect(x, normalize=False):
    """Return a boxcar window whose first and last entries are 0.5.

    Parameters
    ----------
    x : int, numpy integer, list, tuple or numpy.ndarray
        Either the window length (a Python or NumPy integer), a shape given
        as a list/tuple, or an array whose ``shape`` is used.
    normalize : bool, optional
        If True, scale the window so that its elements sum to one.

    Returns
    -------
    X : numpy.ndarray
        Array of ones, except the first and last entries along the first
        axis, which are 0.5.

    Raises
    ------
    TypeError
        If `x` is not one of the supported types (e.g. a float).
    """
    if isinstance(x, (int, np.integer)) and not isinstance(x, bool):
        shape = [x, ]
    elif isinstance(x, (list, tuple)):
        shape = x
    elif isinstance(x, np.ndarray):  # includes numpy masked arrays
        shape = x.shape
    else:
        raise TypeError('rect() expects an integer length, a shape '
                        '(list/tuple) or an array, got {}.'.format(type(x).__name__))
    X = np.zeros(shape)
    X[0] = X[-1] = 0.5
    X[1:-1] = 1

    if normalize:
        X /= X.sum()

    return X


def boxpdf(x):
    """
    Map data onto a "boxed" (uniform) distribution between zero and one.

    Each distinct value is assigned the mid-point of the fraction of
    samples it spans in the sorted data, and every sample is then mapped by
    linear interpolation on that table.

    Parameters
    ----------
    x (array like) :
        Input data.

    Returns
    -------
    bX (array like) :
        Boxed data varying between zero and one.
    X, Y (array like) :
        Lookup table: the distinct sorted values (X) and their mid-point
        ranks (Y).

    """
    data = np.asarray(x)
    count = data.size

    order = np.argsort(data)
    ordered = data[order]
    # Position of the last occurrence of each distinct value in sorted order.
    last = np.flatnonzero(np.append(np.diff(ordered) != 0, True))
    unique_vals = ordered[last]

    edges = np.concatenate([[0], last + 1])
    mid_rank = 0.5 * (edges[:-1] + edges[1:]) / count

    return np.interp(data, unique_vals, mid_rank), unique_vals, mid_rank


def get_cache_dir():
    """Return the location of the cache directory, creating it if needed.

    The directory is ``<home>/.cache/pycwt/``. ``exist_ok=True`` avoids the
    race between checking for the directory and creating it (another process
    creating it in between would otherwise raise FileExistsError).

    Returns
    -------
    str
        Path of the cache directory, with a trailing slash.
    """
    cache_dir = '{}/.cache/pycwt/'.format(expanduser('~'))
    makedirs(cache_dir, exist_ok=True)
    return cache_dir


import numpy as np
from scipy.special import gamma
from scipy.signal import convolve2d
from scipy.special.orthogonal import hermitenorm


class Morlet(object):
    """Implements the Morlet wavelet class.

    Note that the input parameters f and f0 are angular frequencies.
    f0 should be more than 0.8 for this function to be correct, its
    default value is f0 = 6.

    """

    def __init__(self, f0=6):
        self._set_f0(f0)
        self.name = 'Morlet'

    def psi_ft(self, f):
        """Fourier transform of the approximate Morlet wavelet."""
        return (np.pi ** -0.25) * np.exp(-0.5 * (f - self.f0) ** 2)

    def psi(self, t):
        """Morlet wavelet as described in Torrence and Compo (1998)."""
        return (np.pi ** -0.25) * np.exp(1j * self.f0 * t - t ** 2 / 2)

    def flambda(self):
        """Fourier wavelength as of Torrence and Compo (1998)."""
        return (4 * np.pi) / (self.f0 + np.sqrt(2 + self.f0 ** 2))

    def coi(self):
        """e-Folding Time as of Torrence and Compo (1998)."""
        return 1. / np.sqrt(2)

    def sup(self):
        """Wavelet support defined by the e-Folding time."""
        # coi is a method and must be called (dividing by the bound method
        # itself raises TypeError).
        return 1. / self.coi()

    def _set_f0(self, f0):
        # Sets the Morlet wave number, the degrees of freedom and the
        # empirically derived factors for the wavelet bases C_{\delta},
        # gamma, \delta j_0 (Torrence and Compo, 1998, Table 2).
        # Only f0 == 6 has tabulated factors; other values get -1 sentinels.
        self.f0 = f0             # Wave number
        self.dofmin = 2          # Minimum degrees of freedom
        if self.f0 == 6:
            self.cdelta = 0.776  # Reconstruction factor
            self.gamma = 2.32    # Decorrelation factor for time averaging
            self.deltaj0 = 0.60  # Factor for scale averaging
        else:
            self.cdelta = -1
            self.gamma = -1
            self.deltaj0 = -1

    def smooth(self, W, dt, dj, scales):
        """Smoothing function used in coherence analysis.

        Parameters
        ----------
        W : numpy.ndarray
            2D array with one row per scale and one column per time step
            (rows are filtered along axis 1, scales along axis 0).
        dt : float
            Sampling interval; scales are divided by it.
        dj : float
            Spacing between discrete scales; sets the width of the boxcar
            used for smoothing across scales.
        scales : numpy.ndarray
            Wavelet scales, one per row of W.

        Returns
        -------
        T : numpy.ndarray
            Smoothed array with the same shape as W (real if W is real).

        """
        # The smoothing is performed by using a filter given by the absolute
        # value of the wavelet function at each scale, normalized to have a
        # total weight of unity, according to suggestions by Torrence &
        # Webster (1999) and by Grinsted et al. (2004).
        m, n = W.shape

        # Filter in time.
        k = 2 * np.pi * fft.fftfreq(fft_kwargs(W[0, :])['n'])
        k2 = k ** 2
        snorm = scales / dt
        # Smoothing by Gaussian window (absolute value of wavelet function)
        # using the convolution theorem: multiplication by Gaussian curve in
        # Fourier domain for each scale, outer product of scale and frequency
        F = np.exp(-0.5 * (snorm[:, np.newaxis] ** 2) * k2)  # Outer product
        smooth = fft.ifft(F * fft.fft(W, axis=1, **fft_kwargs(W[0, :])),
                          axis=1,  # Along Fourier frequencies
                          **fft_kwargs(W[0, :], overwrite_x=True))
        T = smooth[:, :n]  # Remove possibly padded region due to FFT

        if np.isreal(W).all():
            T = T.real

        # Filter in scale. For the Morlet wavelet it's simply a boxcar with
        # 0.6 width.
        wsize = self.deltaj0 / dj * 2
        win = rect(int(np.round(wsize)), normalize=True)
        T = convolve2d(T, win[:, np.newaxis], 'same')  # Scales are "vertical"

        return T


class Paul(object):
    """Implements the Paul wavelet class.

    Note that the input parameter f is the angular frequency and that
    the default order for this wavelet is m=4.

    """
    def __init__(self, m=4):
        self._set_m(m)
        self.name = 'Paul'

    def psi_ft(self, f): # modified by SJ
        """Fourier transform of the Paul wavelet, with limits to prevent under/overflow.

        Only f > 0 contributes (Heaviside factor). The exponent is zeroed for
        f <= 0, since otherwise exp(-f) overflows to inf for large negative f
        and inf * 0 yields NaN. For f > 0 it is clipped at -100.
        Accepts scalars as well as arrays.
        """
        expnt = np.maximum(-f * (f > 0), -100)
        return (2 ** self.m /
                np.sqrt(self.m * np.prod(range(2, 2 * self.m))) *
                f ** self.m * np.exp(expnt) * (f > 0))

    def psi(self, t):
        """Paul wavelet as described in Torrence and Compo (1998), Table 1:

            psi_0(t) = 2**m * i**m * m! / sqrt(pi * (2m)!) * (1 - i t)**(-(m + 1))

        """
        # m! = prod(range(2, m + 1)); (2m)! = prod(range(2, 2m + 1)).
        return (2 ** self.m * 1j ** self.m * np.prod(range(2, self.m + 1)) /
                np.sqrt(np.pi * np.prod(range(2, 2 * self.m + 1))) *
                (1 - 1j * t) ** (-(self.m + 1)))

    def flambda(self):
        """Fourier wavelength as of Torrence and Compo (1998)."""
        return 4 * np.pi / (2 * self.m + 1)

    def coi(self):
        """e-Folding Time as of Torrence and Compo (1998)."""
        return np.sqrt(2)

    def sup(self):
        """Wavelet support defined by the e-Folding time."""
        # coi is a method and must be called.
        return 1 / self.coi()

    def _set_m(self, m):
        # Sets the wavelet order, the degrees of freedom and the
        # empirically derived factors for the wavelet bases C_{\delta},
        # gamma, \delta j_0 (Torrence and Compo, 1998, Table 2).
        # Only m == 4 has tabulated factors; other orders get -1 sentinels.
        self.m = m               # Wavelet order
        self.dofmin = 2          # Minimum degrees of freedom
        if self.m == 4:
            self.cdelta = 1.132  # Reconstruction factor
            self.gamma = 1.17    # Decorrelation factor for time averaging
            self.deltaj0 = 1.50  # Factor for scale averaging
        else:
            self.cdelta = -1
            self.gamma = -1
            self.deltaj0 = -1


class DOG(object):
    """Implements the derivative of a Gaussian wavelet class.

    Note that the input parameter f is the angular frequency and that
    for m=2 the DOG becomes the Mexican hat wavelet, which is then
    default.

    """
    def __init__(self, m=2):
        self._set_m(m)
        self.name = 'DOG'

    def psi_ft(self, f):
        """Fourier transform of the DOG wavelet."""
        return (- 1j ** self.m / np.sqrt(gamma(self.m + 0.5)) * f ** self.m *
                np.exp(- 0.5 * f ** 2))

    def psi(self, t):
        """DOG wavelet as described in Torrence and Compo (1998).

        The derivative of a Gaussian of order `n` can be determined using
        the probabilists' Hermite polynomials. They are explicitly
        written as:
            Hn(x) = 2 ** (-n / 2) * n! * sum_m ((-1) ** m) *
                    (2 ** 0.5 * x) ** (n - 2 * m) / (m! * (n - 2*m)!)
        or in the recursive form:
            H(n+1)(x) = x * Hn(x) - n * H(n-1)(x)

        """
        p = hermitenorm(self.m)
        return ((-1) ** (self.m + 1) * np.polyval(p, t) *
                np.exp(-t ** 2 / 2) / np.sqrt(gamma(self.m + 0.5)))

    def flambda(self):
        """Fourier wavelength as of Torrence and Compo (1998)."""
        return (2 * np.pi / np.sqrt(self.m + 0.5))

    def coi(self):
        """e-Folding Time as of Torrence and Compo (1998)."""
        return 1 / np.sqrt(2)

    def sup(self):
        """Wavelet support defined by the e-Folding time."""
        # coi is a method and must be called.
        return 1 / self.coi()

    def _set_m(self, m):
        # Sets the m derivative of a Gaussian, the degrees of freedom and the
        # empirically derived factors for the wavelet bases C_{\delta},
        # gamma, \delta j_0 (Torrence and Compo, 1998, Table 2).
        # Only m == 2 and m == 6 have tabulated factors; others get -1 sentinels.
        self.m = m               # m-derivative
        self.dofmin = 1          # Minimum degrees of freedom
        if self.m == 2:
            self.cdelta = 3.541  # Reconstruction factor
            self.gamma = 1.43    # Decorrelation factor for time averaging
            self.deltaj0 = 1.40  # Factor for scale averaging
        elif self.m == 6:
            self.cdelta = 1.966
            self.gamma = 1.37
            self.deltaj0 = 0.97
        else:
            self.cdelta = -1
            self.gamma = -1
            self.deltaj0 = -1


class MexicanHat(DOG):
    """Mexican hat wavelet, i.e. the DOG wavelet fixed at order m=2.

    Only the order and the display name differ from DOG; all other
    behaviour is inherited.

    """
    def __init__(self):
        # Order 2 sets the DOG tabulated factors for the Mexican hat.
        self._set_m(2)
        self.name = 'Mexican Hat'




def cwt(signal, dt, dj=1/12, s0=-1, J=-1, wavelet='morlet', freqs=None, pad=True):
    """Continuous wavelet transform of the signal at specified scales.

    Parameters
    ----------
    signal : numpy.ndarray, list
        One-dimensional input signal. Its mean is removed before the
        transform.
    dt : float
        Sampling interval.
    dj : float, optional
        Spacing between discrete scales. Default value is 1/12.
        Smaller values will result in better scale resolution, but
        slower calculation and plot.
    s0 : float, optional
        Smallest scale of the wavelet. The default (-1) selects 2*dt.
        Ignored when `freqs` is given.
    J : int, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        The default (-1) selects J = int(log2(N * dt / s0) / dj), where N
        is the signal length *after* the optional zero-padding.
        Ignored when `freqs` is given.
    wavelet : instance of Wavelet class, or string
        Mother wavelet class. Default is Morlet wavelet.
    freqs : numpy.ndarray or list, optional
        Custom frequencies to use instead of the ones corresponding
        to the scales described above. Corresponding scales are
        calculated using the wavelet Fourier wavelength.
    pad : bool, optional
        If True (default), zero-pad the mean-removed time series up to the
        next power of two (no padding when its length already is a power
        of two). This reduces wraparound from the end of the time series
        to the beginning, and also speeds up the FFTs used to do the
        wavelet transform. This will not eliminate all edge effects.
        (added by SJ)

    Returns
    -------
    W : numpy.ndarray
        Wavelet transform according to the selected mother wavelet.
        Has (J+1) x N dimensions, with N the original signal length.
    sj : numpy.ndarray
        Vector of scale indices given by sj = s0 * 2**(j * dj),
        j={0, 1, ..., J}.
    freqs : array like
        Vector of Fourier frequencies (in 1 / time units) that
        corresponds to the wavelet scales.
    coi : numpy.ndarray
        Returns the cone of influence, which is a vector of N
        points containing the maximum Fourier period of useful
        information at that particular time. Periods greater than
        those are subject to edge effects.
    fft : numpy.ndarray
        Normalized fast Fourier transform of the input signal.
    fftfreqs : numpy.ndarray
        Fourier frequencies (in 1/time units) for the calculated
        FFT spectrum.

    Example
    -------
    >> mother = wavelet.Morlet(6.)
    >> wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(signal,
           0.25, 0.25, 0.5, 28, mother)

    """
    wavelet = _check_parameter_wavelet(wavelet)

    # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact while
    # turning plain lists into arrays.
    signal = np.asanyarray(signal)
    n0 = len(signal)  # Original signal length

    # Remove the mean and, if enabled, zero-pad up to the next power of two
    # (added + the entire code further modified by SJ).
    if pad:
        next_power_of_two = int(2 ** np.ceil(np.log2(n0)))
        signal_padded = np.zeros(next_power_of_two)
        signal_padded[:n0] = signal - np.mean(signal)
    else:
        signal_padded = signal - np.mean(signal)

    N = len(signal_padded)  # Length of the (possibly padded) signal

    # If no custom frequencies are set, derive the scales from `dj`, `s0`
    # and `J`; otherwise derive them from the Fourier-equivalent frequencies.
    if freqs is None:
        if s0 == -1:
            s0 = 2 * dt
        if J == -1:
            J = int((np.log(float(N) * dt / s0) / np.log(2)) / dj)
        # The scales as of Mallat 1999
        sj = s0 * 2 ** (np.arange(0, J + 1) * dj)
        # Fourier equivalent frequencies
        freqs = 1 / (wavelet.flambda() * sj)
    else:
        # Accept lists too: `freqs` is boolean-indexed further below.
        freqs = np.asarray(freqs)
        sj = 1 / (wavelet.flambda() * freqs)

    # Fourier transform of the (padded) signal
    signal_ft = fft.fft(signal_padded, **fft_kwargs(signal_padded))
    # Fourier angular frequencies
    ftfreqs = 2 * np.pi * fft.fftfreq(N, dt)

    # Creates wavelet transform matrix as outer product of scaled transformed
    # wavelets and transformed signal according to the convolution theorem.
    # (i)   Transform scales to column vector for outer product;
    # (ii)  Calculate 2D matrix [s, f] for each scale s and Fourier angular
    #       frequency f;
    # (iii) Calculate wavelet transform;
    sj_col = sj[:, np.newaxis]
    psi_ft_bar = ((sj_col * ftfreqs[1] * N) ** .5 *
                  np.conjugate(wavelet.psi_ft(sj_col * ftfreqs)))
    W = fft.ifft(signal_ft * psi_ft_bar, axis=1,
                 **fft_kwargs(signal_ft, overwrite_x=True))

    # Trim the padding so W has exactly n0 columns (N == n0 without padding).
    if pad:
        W = W[:, :n0]

    # Remove scales (and the matching frequencies) whose transform is
    # entirely NaN.
    sel = np.invert(np.isnan(W).all(axis=1))
    if np.any(sel):
        sj = sj[sel]
        freqs = freqs[sel]
        W = W[sel, :]

    # Determines the cone-of-influence. Note that it is returned as a function
    # of time in Fourier periods. Uses triangular Bartlett window with
    # non-zero end-points.
    coi = (n0 / 2 - np.abs(np.arange(0, n0) - (n0 - 1) / 2))
    coi = wavelet.flambda() * wavelet.coi() * dt * coi

    return (W, sj, freqs, coi, signal_ft[1:N//2] / N ** 0.5,
            ftfreqs[1:N//2] / (2 * np.pi))


def icwt(W, sj, dt, dj=1/12, wavelet='morlet'):
    """Inverse continuous wavelet transform.

    Parameters
    ----------
    W : numpy.ndarray
        Wavelet transform, the result of the `cwt` function. Either
        (scales, time), as returned by `cwt`, or its transpose.
    sj : numpy.ndarray, list
        Vector of scale indices as returned by the `cwt` function.
    dt : float
        Sample spacing.
    dj : float, optional
        Spacing between discrete scales as used in the `cwt`
        function. Default value is 1/12.
    wavelet : instance of Wavelet class, or string
        Mother wavelet class. Default is Morlet

    Returns
    -------
    iW : numpy.ndarray
        Inverse wavelet transform (one value per time step).

    Raises
    ------
    Warning
        If neither dimension of `W` matches the number of scales.

    Example
    -------
    >> mother = wavelet.Morlet()
    >> wave, scales, freqs, coi, fft, fftfreqs = wavelet.cwt(var,
           0.25, 0.25, 0.5, 28, mother)
    >> iwave = wavelet.icwt(wave, scales, 0.25, 0.25, mother)

    """
    wavelet = _check_parameter_wavelet(wavelet)

    W = np.asarray(W)
    sj = np.asarray(sj)
    n_rows, n_cols = W.shape
    n_scales = sj.size
    if n_rows == n_scales:
        # Layout returned by `cwt`: scales along axis 0.
        scale_axis = 0
        sj_grid = sj.reshape(-1, 1)
    elif n_cols == n_scales:
        # Transposed layout: scales along axis 1. The reconstruction must
        # sum over this axis, not over time.
        scale_axis = 1
        sj_grid = sj.reshape(1, -1)
    else:
        # Kept as `Warning` for backward compatibility with callers.
        raise Warning('Input array dimensions do not match.')

    # As of Torrence and Compo (1998), eq. (11): sum of Re(W) / sqrt(s) over
    # the scales, times dj * sqrt(dt) / (C_delta * psi(0)).
    iW = (dj * np.sqrt(dt) / (wavelet.cdelta * wavelet.psi(0)) *
          (np.real(W) / np.sqrt(sj_grid)).sum(axis=scale_axis))
    return iW


def significance(signal, dt, scales, sigma_test=0, alpha=None,
                 significance_level=0.95, dof=-1, wavelet='morlet'):
    """Significance test for the one dimensional wavelet transform.

    Parameters
    ----------
    signal : array like, float
        Input signal array. If a float number is given, then the
        variance is assumed to have this value. If an array is
        given, then its variance is automatically computed.
    dt : float
        Sample spacing.
    scales : array like
        Vector of scale indices given returned by `cwt` function.
    sigma_test : int, optional
        Sets the type of significance test to be performed.
        Accepted values are 0 (default), 1 or 2. See notes below for
        further details.
    alpha : float, optional
        Lag-1 autocorrelation, used for the significance levels.
        If None (default), it is estimated from `signal` with `ar1`
        (which requires `signal` to be an array).
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    dof : variant, optional
        Degrees of freedom for significance test to be set
        according to the type set in sigma_test. The default (-1)
        selects the wavelet's minimum degrees of freedom. The caller's
        array is never modified.
    wavelet : instance of Wavelet class, or string
        Mother wavelet class. Default is Morlet

    Returns
    -------
    signif : array like
        Significance levels as a function of scale (a scalar for
        sigma_test=2).
    fft_theor : array like
        Theoretical red-noise spectrum as a function of period (for
        sigma_test=2, the scalar scale-averaged spectrum).

    Raises
    ------
    Exception
        If sigma_test=2 and `dof` is not a pair [s1, s2].
    ValueError
        If sigma_test=2 and the wavelet has no tabulated Cdelta, or no
        scale lies in [s1, s2]; or if sigma_test is not 0, 1 or 2.

    Notes
    -----
    The red-noise spectrum follows Torrence and Compo (1998), equation 16,
    evaluated at the normalized frequency dt / period.

    If sigma_test is set to 0, performs a regular chi-square test,
    according to Torrence and Compo (1998) equation 18.

    If set to 1, performs a time-average test (equation 23). In
    this case, dof should be set to the number of local wavelet
    spectra that where averaged together. For the global
    wavelet spectra it would be dof=N, the number of points in
    the time-series.

    If set to 2, performs a scale-average test (equations 25 to
    28). In this case dof should be set to a two element vector
    [s1, s2], which gives the scale range that were averaged
    together. If, for example, the average between scales 2 and
    8 was taken, then dof=[2, 8].

    """
    wavelet = _check_parameter_wavelet(wavelet)

    try:
        n0 = len(signal)
    except TypeError:
        n0 = 1
    scales = np.asarray(scales)
    J = len(scales) - 1
    dj = np.log2(scales[1] / scales[0])

    if n0 == 1:
        variance = signal
    else:
        variance = np.std(signal) ** 2

    if alpha is None:
        alpha, _, _ = ar1(signal)

    period = scales * wavelet.flambda()  # Fourier equivalent periods
    freq = dt / period                   # Normalized frequency (cycles/sample)
    dofmin = wavelet.dofmin              # Degrees of freedom with no smoothing
    Cdelta = wavelet.cdelta              # Reconstruction factor
    gamma_fac = wavelet.gamma            # Time-decorrelation factor
    dj0 = wavelet.deltaj0                # Scale-decorrelation factor

    # Theoretical discrete Fourier power spectrum of the noise signal
    # following Gilman et al. (1963) and Torrence and Compo (1998),
    # equation 16. `freq` is already k/N, so the cosine argument is
    # 2*pi*freq. Dividing by len(signal) again would flatten the spectrum
    # to its zero-frequency value for array inputs.
    fft_theor = (1 - alpha ** 2) / (1 + alpha ** 2 -
                                    2 * alpha * np.cos(2 * np.pi * freq))
    fft_theor = variance * fft_theor     # Including time-series variance

    # -1 is the "use the wavelet's minimum dof" sentinel.
    if np.size(dof) == 1 and np.ravel(dof)[0] == -1:
        dof = dofmin

    if sigma_test == 0:  # No smoothing, dof=dofmin, TC98 sec. 4
        dof = dofmin
        # As in Torrence and Compo (1998), equation 18.
        chisquare = chi2.ppf(significance_level, dof) / dof
        signif = fft_theor * chisquare
    elif sigma_test == 1:  # Time-averaged significance
        # Float copy: never mutate the caller's array, and accept scalars.
        dof = np.array(dof, dtype=float, ndmin=1)
        if dof.size == 1:
            dof = np.full(J + 1, dof[0])
        dof[dof < 1] = 1
        # As in Torrence and Compo (1998), equation 23:
        dof = dofmin * (1 + (dof * dt / gamma_fac / scales) ** 2) ** 0.5
        dof[dof < dofmin] = dofmin  # Minimum dof is dofmin
        chisquare = chi2.ppf(significance_level, dof) / dof
        # New array: `fft_theor` itself is returned unchanged.
        signif = fft_theor * chisquare
    elif sigma_test == 2:  # Scale-averaged significance
        if np.size(dof) != 2:
            raise Exception('DOF must be set to [s1, s2], '
                            'the range of scale-averages')
        if Cdelta == -1:
            # Morlet-like wavelets are parameterized by f0, DOG/Paul by m.
            if hasattr(wavelet, 'f0'):
                param = 'f0={}'.format(wavelet.f0)
            else:
                param = 'm={}'.format(getattr(wavelet, 'm', None))
            raise ValueError('Cdelta and dj0 not defined '
                             'for {} with {}'.format(wavelet.name, param))
        s1, s2 = np.ravel(dof)
        sel = np.flatnonzero((scales >= s1) & (scales <= s2))
        navg = sel.size
        if navg == 0:
            raise ValueError('No valid scales between {} and {}.'.format(s1,
                                                                         s2))
        # As in Torrence and Compo (1998), equation 25.
        Savg = 1 / np.sum(1. / scales[sel])
        # Power-of-two mid point:
        Smid = np.exp((np.log(s1) + np.log(s2)) / 2.)
        # As in Torrence and Compo (1998), equation 28.
        dof = (dofmin * navg * Savg / Smid) * \
              ((1 + (navg * dj / dj0) ** 2) ** 0.5)
        # As in Torrence and Compo (1998), equation 27.
        fft_theor = Savg * np.sum(fft_theor[sel] / scales[sel])
        chisquare = chi2.ppf(significance_level, dof) / dof
        # As in Torrence and Compo (1998), equation 26.
        signif = (dj * dt / Cdelta / Savg) * fft_theor * chisquare
    else:
        raise ValueError('sigma_test must be either 0, 1, or 2.')

    return signif, fft_theor


def xwt(y1, y2, dt, dj=1/12, s0=-1, J=-1, significance_level=0.95,
        wavelet='morlet', normalize=False, no_default_signif=False):
    """Cross wavelet transform (XWT) of two signals.

    The XWT finds regions in time frequency space where the time series
    show high common power.

    Parameters
    ----------
    y1, y2 : numpy.ndarray, list
        Input signal arrays to calculate cross wavelet transform. They
        must have the same length.
    dt : float
        Sample spacing.
    dj : float, optional
        Spacing between discrete scales. Default value is 1/12.
        Smaller values will result in better scale resolution, but
        slower calculation and plot.
    s0 : float, optional
        Smallest scale of the wavelet. Default value (-1) selects 2*dt.
    J : int, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default (-1) lets `cwt` choose J = int(log2(N*dt/s0)/dj).
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    wavelet : instance of a wavelet class, or string, optional
        Mother wavelet class. Default is Morlet wavelet.
    normalize : bool, optional
        If set to true, normalizes CWT by the standard deviation of
        the signals.
    no_default_signif : bool, optional
        If True, skip the red-noise significance estimate and return
        ``numpy.asarray([0])`` in its place. Default is False.

    Returns
    -------
    W12 : numpy.ndarray
        Cross wavelet transform W1 * conj(W2) according to the selected
        mother wavelet, with scales along axis 0 and time along axis 1.
    coi : numpy.ndarray
        Cone of influence, which is a vector of N points containing
        the maximum Fourier period of useful information at that
        particular time. Periods greater than those are subject to
        edge effects.
    freq : numpy.ndarray
        Vector of Fourier equivalent frequencies (in 1 / time units)
        that correspond to the wavelet scales.
    signif : numpy.ndarray
        Significance levels as a function of scale, or
        ``numpy.asarray([0])`` if `no_default_signif` is True.

    Raises
    ------
    ValueError
        If `y1` and `y2` have different lengths.

    Notes
    -----
    Torrence and Compo (1998) state that the percent point function
    (PPF) -- inverse of the cumulative distribution function -- of a
    chi-square distribution at 95% confidence and two degrees of
    freedom is Z2(95%)=3.999. However, calculating the PPF using
    chi2.ppf gives Z2(95%)=5.991. To ensure similar significance
    intervals as in Grinsted et al. (2004), one has to use confidence
    of 86.46%.

    """
    wavelet = _check_parameter_wavelet(wavelet)

    # Makes sure input signal are numpy arrays.
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)
    if len(y1) != len(y2):
        raise ValueError('y1 and y2 must have the same length '
                         '({} != {}).'.format(len(y1), len(y2)))
    # Calculates the standard deviation of both input signals.
    std1 = y1.std()
    std2 = y2.std()
    # Normalizes both signals, if appropriate.
    if normalize:
        y1_normal = (y1 - y1.mean()) / std1
        y2_normal = (y2 - y2.mean()) / std2
    else:
        y1_normal = y1
        y2_normal = y2

    # Calculates the CWT of the time-series making sure the same parameters
    # are used in both calculations.
    _kwargs = dict(dj=dj, s0=s0, J=J, wavelet=wavelet)
    W1, sj, freq, coi, _, _ = cwt(y1_normal, dt, **_kwargs)
    W2, sj, freq, coi, _, _ = cwt(y2_normal, dt, **_kwargs)

    # Calculates the cross CWT of y1 and y2.
    W12 = W1 * W2.conj()

    # And the significance tests. Note that the confidence level is calculated
    # using the percent point function (PPF) of the chi-squared cumulative
    # distribution function (CDF) instead of using Z1(95%) = 2.182 and
    # Z2(95%)=3.999 as suggested by Torrence & Compo (1998) and Grinsted et
    # al. (2004). If the CWT has been normalized, then std1 and std2 are
    # reset to unity; otherwise the raw standard deviations are used.
    if not no_default_signif:
        if normalize:
            std1 = std2 = 1.
        a1, _, _ = ar1(y1)
        a2, _, _ = ar1(y2)
        Pk1 = ar1_spectrum(freq * dt, a1)
        Pk2 = ar1_spectrum(freq * dt, a2)
        dof = wavelet.dofmin
        PPF = chi2.ppf(significance_level, dof)
        signif = (std1 * std2 * (Pk1 * Pk2) ** 0.5 * PPF / dof)
    else:
        signif = np.asarray([0])

    # The results:
    return W12, coi, freq, signif


def wct(y1, y2, dt, dj=1/12, s0=-1, J=-1, sig=False,
        significance_level=0.95, wavelet='morlet', normalize=False, **kwargs):
    """Wavelet coherence transform (WCT).

    The WCT finds regions in time frequency space where the two time
    series co-vary, but do not necessarily have high power.

    Parameters
    ----------
    y1, y2 : numpy.ndarray, list
        Input signals. They must have the same length.
    dt : float
        Sample spacing.
    dj : float, optional
        Spacing between discrete scales. Default value is 1/12.
        Smaller values will result in better scale resolution, but
        slower calculation and plot.
    s0 : float, optional
        Smallest scale of the wavelet. Default value (-1) selects 2*dt.
    J : int, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default (-1) selects J = int(log2(N*dt/s0)/dj), with N the
        length of `y1` (before any padding done by `cwt`).
    sig : bool, optional
        If True, compute the Monte Carlo significance of the coherence
        with `wct_significance`. Default is False.
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    wavelet : instance of a wavelet class, or string, optional
        Mother wavelet class. Default is Morlet wavelet. It must provide
        a `smooth` method.
    normalize : bool, optional
        If set to true, normalizes CWT by the standard deviation of
        the signals.
    **kwargs
        Extra keyword arguments forwarded to `wct_significance`
        (e.g. `mc_count`, `progress`, `cache`).

    Returns
    -------
    WCT : numpy.ndarray
        Magnitude-squared coherence,
        |S(W12/s)|**2 / (S(|W1|**2/s) * S(|W2|**2/s)), where S is
        `wavelet.smooth` and s the scale.
    aWCT : numpy.ndarray
        Phase angle (radians) of the unsmoothed cross wavelet transform
        W1 * conj(W2).
    coi : numpy.ndarray
        Cone of influence, which is a vector of N points containing
        the maximum Fourier period of useful information at that
        particular time. Periods greater than those are subject to
        edge effects.
    freq : numpy.ndarray
        Vector of Fourier equivalent frequencies (in 1 / time units)
        that correspond to the wavelet scales.
    sig : numpy.ndarray
        Significance levels as a function of scale if `sig` is True,
        otherwise ``numpy.asarray([0])``.

    Raises
    ------
    ValueError
        If `y1` and `y2` have different lengths.

    See also
    --------
    cwt, xwt

    """
    wavelet = _check_parameter_wavelet(wavelet)

    # Makes sure input signals are numpy arrays.
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)
    if len(y1) != len(y2):
        raise ValueError('y1 and y2 must have the same length '
                         '({} != {}).'.format(len(y1), len(y2)))
    nt = len(y1)

    # Resolve the default scale parameters here so that the CWTs and the
    # significance test below all use the same values.
    if s0 == -1:
        s0 = 2 * dt
    if J == -1:
        J = int((np.log(float(nt) * dt / s0) / np.log(2)) / dj)

    # Calculates the standard deviation of both input signals.
    std1 = y1.std()
    std2 = y2.std()
    # Normalizes both signals, if appropriate.
    if normalize:
        y1_normal = (y1 - y1.mean()) / std1
        y2_normal = (y2 - y2.mean()) / std2
    else:
        y1_normal = y1
        y2_normal = y2

    # Calculates the CWT of the time-series making sure the same parameters
    # are used in both calculations.
    _kwargs = dict(dj=dj, s0=s0, J=J, wavelet=wavelet)
    W1, sj, freq, coi, _, _ = cwt(y1_normal, dt, **_kwargs)
    W2, sj, freq, coi, _, _ = cwt(y2_normal, dt, **_kwargs)

    # Scales broadcast along time, used for the 1/s normalization.
    scales = np.ones([1, nt]) * sj[:, None]

    # Smooth the normalized wavelet power spectra -- Time Smoothing
    S1 = wavelet.smooth(np.abs(W1) ** 2 / scales, dt, dj, sj)
    S2 = wavelet.smooth(np.abs(W2) ** 2 / scales, dt, dj, sj)

    # Now the wavelet transform coherence
    W12 = W1 * W2.conj()
    # -- Normalization by Scale and Scale Smoothing
    S12 = wavelet.smooth(W12 / scales, dt, dj, sj)
    WCT = np.abs(S12) ** 2 / (S1 * S2)
    aWCT = np.angle(W12)

    # Calculates the significance using Monte Carlo simulations as a
    # function of scale.
    if sig:
        a1, _, _ = ar1(y1)
        a2, _, _ = ar1(y2)
        sig = wct_significance(a1, a2, dt=dt, dj=dj, s0=s0, J=J,
                               significance_level=significance_level,
                               wavelet=wavelet, **kwargs)
    else:
        sig = np.asarray([0])

    return WCT, aWCT, coi, freq, sig


def wct_significance(al1, al2, dt, dj, s0, J, significance_level=0.95,
                     wavelet='morlet', mc_count=50, progress=True,
                     cache=True):
    """Wavelet coherence transform significance.

    Calculates the WCT significance by Monte Carlo simulation. The function
    transforms `mc_count` pairs of red-noise series with lag-1
    autoregressive coefficients `al1` and `al2`. It then histograms their
    wavelet coherence per scale, using only points outside the cone of
    influence. The `significance_level` percentile of each histogram is
    returned.

    Parameters
    ----------
    al1, al2: float
        Lag-1 autoregressive coefficients of both time series.
    dt : float
        Sample spacing.
    dj : float
        Spacing between discrete scales.
    s0 : float
        Smallest scale of the wavelet.
    J : int
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    wavelet : instance of a wavelet class, or string, optional
        Mother wavelet class. Default is Morlet wavelet. It must provide a
        `smooth` method.
    mc_count : integer, optional
        Number of Monte Carlo simulations. Default is 50.
    progress : bool, optional
        If `True` (default), shows progress bar on screen.
    cache : bool, optional
        If `True` (default), results are looked up in, and saved to, a
        gzipped text file in the directory returned by `get_cache_dir()`.

    Returns
    -------
    sig95 : numpy.ndarray
        Coherence significance level for each of the J + 1 scales. Scales
        with no point outside the cone of influence are NaN.

    """
    # Accept wavelet names too; `.name` and `.smooth` are used below.
    wavelet = _check_parameter_wavelet(wavelet)

    if cache:
        # Load cache if previously calculated. It is assumed that wavelet
        # analysis is performed using the wavelet's default parameters.
        # The AR(1) coefficients are quantized as round(arctanh(a) * 4).
        # Applying the factor 4 *inside* arctanh made every |a| > 0.25 map
        # to NaN, so all such coefficients shared one cache entry.
        aa = np.round(np.arctanh(np.array([al1, al2])) * 4)
        aa = np.abs(aa) + 0.5 * (aa < 0)
        cache_file = 'wct_sig_{:0.5f}_{:0.5f}_{:0.5f}_{:0.5f}_{:d}_{}'\
            .format(aa[0], aa[1], dj, s0 / dt, J, wavelet.name)
        cache_dir = get_cache_dir()
        try:
            dat = np.loadtxt('{}/{}.gz'.format(cache_dir, cache_file),
                             unpack=True)
            print('NOTE: WCT significance loaded from cache.\n')
            return dat
        except IOError:
            pass

    # Some output to the screen
    print('Calculating wavelet coherence significance')

    # Choose N so that largest scale has at least some part outside the COI
    ms = s0 * (2 ** (J * dj)) / dt
    N = int(np.ceil(ms * 6))
    noise1 = rednoise(N, al1, 1)
    nW1, sj, freq, coi, _, _ = cwt(noise1, dt=dt, dj=dj, s0=s0, J=J,
                                   wavelet=wavelet)

    period = np.ones([1, N]) / freq[:, None]
    coi = np.ones([J + 1, 1]) * coi[None, :]
    outsidecoi = (period <= coi)
    scales = np.ones([1, N]) * sj[:, None]
    # Only scales with at least one point outside the COI can be assessed;
    # the others are reported as NaN.
    has_outside = outsidecoi.any(axis=1)
    sig95 = np.zeros(J + 1)
    sig95[~has_outside] = np.nan
    maxscale = np.flatnonzero(has_outside)[-1]  # last assessable scale index

    nbins = 1000
    wlc = np.zeros([J + 1, nbins])  # coherence histogram per scale
    cwt_kwargs = dict(dt=dt, dj=dj, s0=s0, J=J, wavelet=wavelet)
    # Displays progress bar with tqdm
    for _ in tqdm(range(mc_count), disable=not progress):
        # Generates two red-noise signals with lag-1 autoregressive
        # coefficients given by al1 and al2
        noise1 = rednoise(N, al1, 1)
        noise2 = rednoise(N, al2, 1)
        # Calculate the cross wavelet transform of both red-noise signals
        nW1, sj, freq, coi, _, _ = cwt(noise1, **cwt_kwargs)
        nW2, sj, freq, coi, _, _ = cwt(noise2, **cwt_kwargs)
        nW12 = nW1 * nW2.conj()
        # Smooth wavelet transforms and calculate wavelet coherence
        # between both signals.
        S1 = wavelet.smooth(np.abs(nW1) ** 2 / scales, dt, dj, sj)
        S2 = wavelet.smooth(np.abs(nW2) ** 2 / scales, dt, dj, sj)
        S12 = wavelet.smooth(nW12 / scales, dt, dj, sj)
        R2 = np.abs(S12) ** 2 / (S1 * S2)
        # Accumulate the coherence histogram of every assessable scale
        # (inclusive of `maxscale`), using points outside the COI only.
        for s in range(maxscale + 1):
            values = R2[s, outsidecoi[s]]
            values = values[np.isfinite(values)]
            bins = np.floor(values * nbins).astype(int)
            # R2 == 1 (or rounding just above it) would index past the end.
            bins = np.clip(bins, 0, nbins - 1)
            wlc[s] += np.bincount(bins, minlength=nbins)

    # After many, many, many Monte Carlo simulations, determine the
    # significance using the coherence coefficient counter percentile.
    R2y = (np.arange(nbins) + 0.5) / nbins
    for s in range(maxscale + 1):
        sel = wlc[s] > 0
        if not sel.any():
            sig95[s] = np.nan
            continue
        P = wlc[s, sel].cumsum()
        P = (P - 0.5) / P[-1]
        sig95[s] = np.interp(significance_level, P, R2y[sel])

    if cache:
        # Save the results on cache to avoid too many computations in the future
        np.savetxt('{}/{}.gz'.format(cache_dir, cache_file), sig95)

    # And returns the results
    return sig95


def _check_parameter_wavelet(wavelet):
    mothers = {'morlet': Morlet, 'paul': Paul, 'dog': DOG,
               'mexicanhat': MexicanHat}
    # Checks if input parameter is a string. For backwards
    # compatibility with Python 2 we check if instance is a
    # `str`.
    try:
        if isinstance(wavelet, str):
            return mothers[wavelet]()
    except NameError:
        if isinstance(wavelet, str):
            return mothers[wavelet]()
    # Otherwise, return itself.
    return wavelet

WaLSA_k_omega.py

This module provides functions for performing k-ω analysis and filtering in spatio-temporal datasets.

WaLSA_k_omega.py
# --------------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# --------------------------------------------------------------------------------------------------------------
# The following code is based on routines originally written by Rob Rutten, David B. Jess, and Samuel D. T. Grant
# --------------------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from scipy.signal import convolve # type: ignore
# --------------------------------------------------------------------------------------------

def gaussian_function(sigma, width=None):
    """
    Create a Gaussian kernel that closely matches the IDL implementation.

    The kernel is exp(-sum_i x_i**2 / (2 * sigma_i**2)). Along each
    dimension it is evaluated on a grid centred on zero, which has
    half-integer offsets when the width is even. The kernel is not
    normalized (its peak value is 1 for odd widths).

    Parameters:
        sigma (float or list of floats): Standard deviation(s) of the Gaussian,
            one per dimension (at most 8 dimensions).
        width (int or list of ints, optional): Number of samples along each
            dimension. A single value is used for all dimensions. Values are
            truncated to integers and clamped to at least 1. If omitted,
            ``2 * ceil(3 * sigma) + 1`` is used for each dimension.

    Returns:
        np.ndarray: The Gaussian kernel; its shape is given by the widths.

    Raises:
        ValueError: If any sigma is not positive, more than 8 sigmas are given,
            the number of widths is neither 1 nor the number of sigmas, or
            the resulting kernel is all zeros.
    """
    sigma = np.atleast_1d(np.asarray(sigma, dtype=np.float64)).ravel()
    if np.any(sigma <= 0):
        raise ValueError("Sigma must be greater than 0.")

    nSigma = sigma.size
    if nSigma > 8:
        raise ValueError('Sigma can have no more than 8 elements')

    if width is None:
        # Previously `None` crashed in astype(); default to +/- 3 sigma.
        width = 2 * np.ceil(3 * sigma) + 1
    width = np.atleast_1d(np.array(width)).ravel()
    width = np.maximum(width.astype(np.int64), 1)  # integers, at least 1

    nWidth = width.size
    if nWidth == 1 and nSigma > 1:
        width = np.full(nSigma, width[0])
    elif nWidth != nSigma:
        raise ValueError('Incorrect width specification')

    # Build the exponent by broadcasting one centred offset vector per
    # dimension; this works for any number of dimensions (1..8).
    exponent = np.zeros(tuple(width), dtype=np.float64)
    for axis, (w, s) in enumerate(zip(width, sigma)):
        offsets = np.arange(w) - w // 2 + (0 if w % 2 else 0.5)
        shape = [1] * nSigma
        shape[axis] = w
        exponent = exponent + (offsets.reshape(shape) ** 2) / (2 * s ** 2)

    # Clipping keeps exp() away from overflow/underflow warnings.
    kernel = np.exp(-np.clip(exponent, -1e3, 1e3))
    kernel = np.nan_to_num(kernel, nan=0.0, posinf=0.0, neginf=0.0)

    if np.sum(kernel) == 0:
        raise ValueError("Generated Gaussian kernel is invalid (all zeros).")

    return kernel


def walsa_radial_distances(shape):
    """
    Distance of every pixel from the geometric centre of a 2-D grid.

    Parameters:
        shape (tuple): Two-element tuple, typically (ny, nx).

    Returns:
        numpy.ndarray: Array of the given shape holding the Euclidean
        distance of each element from ((ny - 1) / 2, (nx - 1) / 2).

    Raises:
        ValueError: If `shape` is not a two-element tuple.
    """
    is_pair = isinstance(shape, tuple) and len(shape) == 2
    if not is_pair:
        raise ValueError("Shape must be a tuple with two elements, e.g., (ny, nx).")

    rows, cols = np.indices(shape)
    centre_row, centre_col = (np.array(shape) - 1) / 2
    dx = cols - centre_col
    dy = rows - centre_row
    return np.sqrt(dx ** 2 + dy ** 2)

def avgstd(array):
    """
    Return the mean and the population standard deviation of an array.

    Both statistics are taken over all elements, with the standard deviation
    normalized by the element count (not count - 1).
    """
    count = array.size
    mean = np.sum(array) / count
    deviation = array - mean
    return mean, np.sqrt(np.sum(deviation ** 2) / count)

def linear(x, *p):
    """
    Evaluate the straight line ``p[0] + x * p[1]``.

    Used as the model function for the temporal detrending fit.

    Parameters:
        x: Abscissa value(s).
        *p: Coefficients; p[0] is the intercept and p[1] the slope. Any
            further coefficients are ignored.

    Raises:
        ValueError: If fewer than two coefficients are supplied.
    """
    if len(p) >= 2:
        intercept, slope = p[0], p[1]
        return intercept + x * slope
    raise ValueError("Parameters p[0] and p[1] must be provided.")


def gradient(xy, *p):
    """
    Planar model ``p[0] + x * p[1] + y * p[2]`` for fitting spatial trends.

    Parameters:
        xy: Pair of grid coordinates (x, y); scalars or arrays.
        p: Coefficients [offset, slope_x, slope_y]; extras are ignored.

    Raises:
        ValueError: If fewer than three coefficients are supplied.
    """
    x, y = xy  # Unpack the coordinate pair
    if len(p) >= 3:
        offset, slope_x, slope_y = p[0], p[1], p[2]
        return offset + x * slope_x + y * slope_y
    raise ValueError("Parameters p[0], p[1], and p[2] must be provided.")


def apod3dcube(cube, apod):
    """
    Detrend and apodize a 3D data cube along its time axis and both spatial axes.

    Parameters:
        cube : Input 3D data cube. Its shape is unpacked as (nt, nx, ny), i.e.
            time is the FIRST axis.
        apod (float): Fractional width of the sin^2 apodization edges, relative
            to the length of each axis (0 means no apodization).

    Returns:
        numpy.ndarray (float32, same shape as ``cube``). For every frame, the
        code does the following in order:
          1. It subtracts a straight-line fit to the time series of frame means.
          2. It multiplies the frame by the temporal window.
          3. It subtracts the fitted "gradient" model.
          4. It multiplies by the spatial window.
          5. It adds back the mean of the frame as it stood after step 2.

    NOTE(review): the coordinate grids used for the spatial "gradient" fit
    are all ones, so only a constant is removed. See the comment at that step.
    """
    # Get cube dimensions (time is the first axis)
    nt, nx, ny = cube.shape
    apocube = np.zeros_like(cube, dtype=np.float32)

    # Define temporal apodization: sin^2 rise over the first apodrimt samples
    apodt = np.ones(nt, dtype=np.float32)
    if apod != 0:
        apodrimt = nt * apod
        apodrimt = int(apodrimt)  # Ensure apodrimt is an integer
        apodt[:apodrimt] = (np.sin(np.pi / 2. * np.arange(apodrimt) / apodrimt)) ** 2
        # Mirror the rising edge onto the trailing edge. flip + roll(1) maps
        # index k to (nt - k) % nt, so the product satisfies
        # apodt[k] == apodt[nt - k] for 1 <= k < nt.
        apodt = apodt * np.roll(np.flip(apodt), 1)  
        # Apply symmetrical apodization

    # Temporal detrending (mean-image trend, not per pixel)
    ttrend = np.zeros(nt, dtype=np.float32)
    tf = np.arange(nt) + 1.0  # 1-based time index used as the fit abscissa
    for it in range(nt):
        img = cube[it, :, :]
        # ttrend[it], _ = avgstd(img)
        ttrend[it] = np.mean(img)

    # Fit the trend with a linear model.
    # The initial guess [1000, 0] is presumably tuned for typical data levels (TODO confirm).
    fitp, _ = curve_fit(linear, tf, ttrend, p0=[1000.0, 0.0])
    # fit = fitp[0] + tf * fitp[1]
    fit = linear(tf, *fitp)

    # Remove the fitted trend value from each frame and apply the temporal window
    for it in range(nt):
        img = cube[it, :, :]
        apocube[it, :, :] = (img - fit[it]) * apodt[it]

    # Define spatial apodization (same sin^2 edge construction as in time);
    # apodxy has shape (nx, ny), matching each frame
    apodx = np.ones(nx, dtype=np.float32)
    apody = np.ones(ny, dtype=np.float32)
    if apod != 0:
        apodrimx = apod * nx
        apodrimy = apod * ny
        apodrimx = int(apodrimx)  # Ensure apodrimx is an integer
        apodrimy = int(apodrimy)  # Ensure apodrimy is an integer
        apodx[:apodrimx] = (np.sin(np.pi / 2. * np.arange(int(apodrimx)) / apodrimx)) ** 2
        apody[:apodrimy] = (np.sin(np.pi / 2. * np.arange(int(apodrimy)) / apodrimy)) ** 2
        apodx = apodx * np.roll(np.flip(apodx), 1)
        apody = apody * np.roll(np.flip(apody), 1)
        apodxy = np.outer(apodx, apody)
    else:
        apodxy = np.outer(apodx, apody)

    # Spatial gradient removal and apodizing per image.
    # NOTE(review): xf and yf are both all-ones, so the model
    # p0 + x*p1 + y*p2 is constant over the frame. The fit can then only
    # recover a constant (p0 + p1 + p2) and does NOT remove a real spatial
    # gradient. curve_fit is also likely to warn that the covariance cannot
    # be estimated for this degenerate design. The commented-out meshgrid
    # line below would supply real coordinates; confirm which behaviour is
    # intended (e.g. against the original IDL routine).
    # xf, yf = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
    xf = np.ones((nx, ny), dtype=np.float32)
    yf = np.copy(xf)
    for it in range(nt):
        img = apocube[it, :, :]
        # avg, _ = avgstd(img)
        avg = np.mean(img)  # mean of the temporally detrended+apodized frame, added back below

        # Ensure xf, yf, and img are properly defined
        # (NOTE: assert statements are stripped when Python runs with -O)
        assert xf.shape == yf.shape == img.shape, "xf, yf, and img must have matching shapes."

        # Flatten the inputs for curve_fit
        x_flat = xf.ravel()
        y_flat = yf.ravel()
        img_flat = img.ravel()

        # Ensure lengths match
        assert len(x_flat) == len(y_flat) == len(img_flat), "Flattened inputs must have the same length."

        # Call curve_fit with initial parameters
        fitp, _ = curve_fit(gradient, (x_flat, y_flat), img_flat, p0=[1000.0, 0.0, 0.0])

        # Apply the fitted parameters, window spatially, and restore the frame mean
        fit = gradient((xf, yf), *fitp)
        apocube[it, :, :] = (img - fit) * apodxy + avg

    return apocube


def ko_dist(sx, sy, double=False):
    """
    Set up the Pythagorean distance array from the origin, in unshifted FFT order.

    Element [i, j] is the distance of FFT index (i, j) from the zero-frequency
    element [0, 0]. It uses the periodic index distance min(i, sx - i) along x
    and min(j, sy - j) along y, so index sx - k is treated like index k, as in
    a numpy FFT without fftshift. Each axis is normalised by half its length
    and the result is scaled by (min(sx, sy) / 2 + 1). For indices
    0 <= i <= sx/2, 0 <= j <= sy/2 this reproduces the previous first-quadrant
    values exactly. The other three quadrants are now their proper periodic
    mirror images. Previously they were rolled/flipped along the wrong axes,
    and odd sizes raised a shape error.

    Parameters:
        sx (int): Size of the first (x) axis.
        sy (int): Size of the second (y) axis.
        double (bool): If False (default), round to the nearest integer and
            return an integer array; if True, return float64 distances.

    Returns:
        numpy.ndarray of shape (sx, sy).
    """
    ix = np.arange(sx)
    iy = np.arange(sy)
    # Periodic distance to index 0, normalised by half the axis length
    dx = np.minimum(ix, sx - ix) / (sx / 2)
    dy = np.minimum(iy, sy - iy) / (sy / 2)

    afstanden = np.sqrt(dx[:, np.newaxis] ** 2 + dy[np.newaxis, :] ** 2) * (min(sx, sy) / 2 + 1)

    # Convert to integers if 'double' is False
    if not double:
        afstanden = np.round(afstanden).astype(int)

    return afstanden


def averpower(cube):
    """
    Compute the 2D (k_h, f) power array by circular averaging over (k_x, k_y).

    The cube is Fourier transformed along all three axes, keeping only the
    non-negative temporal frequencies (first nt // 2 + 1 bins). For every
    integer radial distance i, as given by ``ko_dist``, the squared FFT
    magnitudes of all (k_x, k_y) pixels at that distance are averaged.

    Previously the sum was divided by ``len(np.where(...))``. That is always 2
    for a 2D mask, so the result was sum/2 instead of a mean. Distances with
    no pixels now stay 0 instead of being 0 / 2.

    Parameters:
        cube : Input 3D data cube; its shape is unpacked as (nt, nx, ny)
            (time first).

    Returns:
        avpow : float64 array of shape (maxdist + 1, nt // 2 + 1), where
            maxdist = min(nx, ny) // 2 + 1. avpow[i, j] is the mean power at
            radial distance i and temporal-frequency bin j.
    """
    nt, nx, ny = cube.shape
    nfreq = nt // 2 + 1

    # FFT in time first (keeping non-negative frequencies), then in x and y
    fftcube = np.fft.fft(np.fft.fft(np.fft.fft(cube, axis=0)[:nfreq, :, :], axis=1), axis=2)
    power = np.abs(fftcube) ** 2

    # Integer-rounded Pythagoras distances in unshifted FFT layout, shape (nx, ny)
    afstanden = ko_dist(nx, ny)

    maxdist = min(nx, ny) // 2 + 1  # Largest quarter circle

    avpow = np.zeros((maxdist + 1, nfreq), dtype=np.float64)

    # Average power over each k_h ring, for all temporal frequencies at once
    for i in range(maxdist + 1):
        ring = afstanden == i
        count = np.count_nonzero(ring)
        if count == 0:
            continue  # no pixels at this rounded distance; leave the row at 0
        avpow[i, :] = power[:, ring].sum(axis=1) / count

    return avpow


def walsa_kopower_funct(datacube, **kwargs):
    """
    Calculate a k-omega diagram: Fourier power at temporal frequency f against
    horizontal spatial wavenumber k_h.

    Originally written in IDL by Rob Rutten (RR), as an assembly of Alfred de
    Wijn's routines (2010). Translated into Python by Shahin Jafarzadeh (2024).

    Parameters:
        datacube: Input data cube [t, x, y].
        **kwargs: Optional settings:
            apod (float): fractional extent of the apodization edges (default 0.1).
            kmax, fmax (float): accepted for compatibility (default 1.0), not used here.
            minpower, maxpower: accepted for compatibility (default None), not used here.

    Returns:
        The radially averaged power array produced by ``averpower`` on the
        apodized cube.
    """
    settings = dict(apod=0.1, kmax=1.0, fmax=1.0, minpower=None, maxpower=None)
    settings.update(kwargs)

    tapered = apod3dcube(datacube, settings['apod'])
    return averpower(tapered)

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_k_omega(signal, time=None, **kwargs):
    """
    NAME: WaLSA_k_omega
        part of -- WaLSAtools --

    * Main function to calculate and plot k-omega diagram.

    ORIGINAL CODE: QUEEns Fourier Filtering (QUEEFF) code
    WRITTEN, ANNOTATED, TESTED AND UPDATED in IDL BY:
    (1) Dr. David B. Jess
    (2) Dr. Samuel D. T. Grant
    The original code along with its manual can be downloaded at: https://bit.ly/37mx9ic

    WaLSA_k_omega (in IDL): A lightly modified version of the original code (i.e., a few additional keywords added) by Dr. Shahin Jafarzadeh
    - Translated into Python by Shahin Jafarzadeh (2024)

    Parameters:
        signal (array): Input datacube, [t, x, y] (format='txy', default) or [x, y, t] (format='xyt'). The x and y dimensions must be identical; if not, both are cropped to the smaller one.
        time (array): Time array corresponding to the input datacube (required; the cadence is the median time step).
        format (str): 'txy' (default) or 'xyt'. 'xyt' input is transposed to [t, x, y].
        pixelsize (float): Spatial sampling of the input datacube. If not given, it is plotted in units of 'pixel'.
        filtering (bool): If True, filtering is applied, and the filtered datacube (filtered_cube) is returned. Otherwise, None is returned. Default: False.
        f1 (float): Optional lower (temporal) frequency to filter, in Hz.
        f2 (float): Optional upper (temporal) frequency to filter, in Hz.
        k1 (float): Optional lower (spatial) wavenumber to filter, in units of pixelsize^-1 (k = (2 * π) / wavelength).
        k2 (float): Optional upper (spatial) wavenumber to filter, in units of pixelsize^-1.
        spatial_torus (bool): If True, makes the annulus used for spatial filtering have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
        temporal_torus (bool): If True, makes the temporal filter have a Gaussian-shaped profile, useful for preventing aliasing. Default: True.
        no_spatial_filt (bool): If True, ensures no spatial filtering is performed on the dataset (i.e., only temporal filtering is applied).
        no_temporal_filt (bool): If True, ensures no temporal filtering is performed on the dataset (i.e., only spatial filtering is applied).
        silent (bool): If True, suppresses the initial processing message.
        smooth (bool): If True (default), the power (excluding the zeroth frequency column) is smoothed with gaussian_function(sigma=[0.65, 0.65], width=3).
        mode (int): Output power mode: 0 = log10(power) (default), 1 = linear power, 2 = sqrt(power) = amplitude.
        processing_maps (bool): If True (and filtering is True), the function returns the processing maps (spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies). Otherwise, they are all returned as None. Default: False.

    OUTPUTS:
        power : 2D array of power (see mode for the scale).
        frequencies : 1D array of frequencies (in Hz).
        wavenumber : 1D array of wavenumber (in pixelsize^-1).
        filtered_cube : 3D array of filtered datacube (if filtering is set).
        processing_maps (if set to True). The 'data' of spatial_fft_map and
        spatial_fft_filtered_map is log10 of the magnitude of the summed
        spatial FFT (floored at 1e-15).

    Raises:
        ValueError: if ``time`` is missing, ``format`` is unsupported, or the
        requested wavenumber band is too narrow for the Gaussian (torus) filter.

    IF YOU USE THIS CODE, THEN PLEASE CITE THE ORIGINAL PUBLICATION WHERE IT WAS USED:
    Jess et al. 2017, ApJ, 842, 59 (http://adsabs.harvard.edu/abs/2017ApJ...842...59J)
    """

    # Set default parameter values
    defaults = {
        'pixelsize': 1,
        'format': 'txy',
        'filtering': False,
        'f1': None,
        'f2': None,
        'k1': None,
        'k2': None,
        'spatial_torus': True,
        'temporal_torus': True,
        'no_spatial_filt': False,
        'no_temporal_filt': False,
        'silent': False,
        'xlog': False,
        'ylog': False,
        'xrange': None,
        'yrange': None,
        'nox2': False,
        'noy2': False,
        'smooth': True,
        'mode': 0,
        'xtitle': 'Wavenumber',
        'xtitle_units': '(pixel⁻¹)',
        'ytitle': 'Frequency',
        'yttitle_units': '(Hz)',
        'x2ndaxistitle': 'Spatial size',
        'y2ndaxistitle': 'Period',
        'x2ndaxistitle_units': '(pixel)',
        'y2ndaxistitle_units': '(s)',
        'processing_maps': False
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    if time is None:
        raise ValueError("A time array is required to determine the cadence.")
    cadence = np.median(np.diff(time))

    pixelsize = params['pixelsize']

    # Outputs that remain None unless filtering (and processing_maps) is requested.
    # temporal_filter must be initialised too: it is returned even without filtering.
    filtered_cube = None
    spatial_fft_map = None
    torus_map = None
    spatial_fft_filtered_map = None
    temporal_fft = None
    temporal_filter = None
    temporal_frequencies = None
    spatial_frequencies = None

    signal = np.asarray(signal)

    # Bring the cube into [t, x, y] order
    if params['format'] == 'xyt':
        signal = np.transpose(signal, (2, 0, 1))  # Convert 'xyt' to 'txy'
    elif params['format'] != 'txy':
        raise ValueError("Unsupported format. Choose 'txy' or 'xyt'.")
    if not params['silent']:
        print(f"Processing k-ω analysis for a 3D cube with format '{params['format']}' and shape {signal.shape}.")

    # The spatial plane must be square: crop x and y (never t) to the smaller size
    nt, nx, ny = signal.shape
    if nx != ny:
        min_dim = min(nx, ny)
        signal = signal[:, :min_dim, :min_dim]
    nt, nx, ny = signal.shape

    # Calculating the Nyquist frequencies
    spatial_nyquist = (2 * np.pi) / (pixelsize * 2)
    temporal_nyquist = 1 / (cadence * 2)

    print("")
    print(f"Input datacube size (t,x,y): {signal.shape}")
    print("")
    print("Spatially, the important values are:")
    print(f"    2-pixel size = {pixelsize * 2:.2f} {params['x2ndaxistitle_units']}")
    print(f"    Nyquist wavenumber = {spatial_nyquist:.2f} {params['xtitle_units']}")
    if params['no_spatial_filt']:
        print("*** NO SPATIAL FILTERING WILL BE PERFORMED ***")
    print("")
    print("Temporally, the important values are:")
    print(f"    2-element duration (Nyquist period) = {cadence * 2:.2f} {params['y2ndaxistitle_units']}")
    print(f"    Time series duration = {cadence * nt:.2f} {params['y2ndaxistitle_units']}")
    print(f"    Nyquist frequency = {temporal_nyquist:.2f} {params['yttitle_units']}")
    if params['no_temporal_filt']:
        print("***NO TEMPORAL FILTERING WILL BE PERFORMED***")
    print("")

    # Generate k-omega power map
    print("Constructing a k-ω diagram of the input datacube..........")
    print("")
    # Make the k-omega diagram using the proven method of Rob Rutten
    kopower = walsa_kopower_funct(signal)

    # Axis scales: wavenumber (pixel⁻¹) along axis 0, frequency (Hz) along axis 1
    xsize_kopower = kopower.shape[0]
    dxsize_kopower = spatial_nyquist / float(xsize_kopower - 1)
    kopower_xscale = np.arange(xsize_kopower) * dxsize_kopower  # in pixel⁻¹

    ysize_kopower = kopower.shape[1]
    dysize_kopower = temporal_nyquist / float(ysize_kopower - 1)
    kopower_yscale = np.arange(ysize_kopower) * dysize_kopower  # in Hz

    kopower_plot = kopower.copy()
    if params['smooth']:
        # Smooth with a normalised Gaussian kernel, skipping the zeroth frequency column
        Gaussian_kernel = gaussian_function(sigma=[0.65, 0.65], width=3)
        Gaussian_kernel_norm = np.nansum(Gaussian_kernel)
        kopower_plot[:, 1:] = convolve(
            kopower[:, 1:],
            Gaussian_kernel,
            mode='same'
        ) / Gaussian_kernel_norm

    # Normalise to the frequency resolution (in Hz)
    freq = kopower_yscale[1:]
    freq0 = freq[1] if freq[0] == 0 else freq[0]
    kopower_plot /= freq0

    # Apply logarithmic or square root transformation based on mode (mode 1: linear)
    if params['mode'] == 0:
        kopower_plot = np.log10(kopower_plot)
    elif params['mode'] == 2:
        kopower_plot = np.sqrt(kopower_plot)

    # Drop the zero-wavenumber row and zero-frequency column for the displayed map
    komegamap = np.clip(
        kopower_plot[1:, 1:],
        np.nanmin(kopower_plot[1:, 1:]),
        np.nanmax(kopower_plot[1:, 1:])
    )

    # Rotate counterclockwise by 90 degrees, then flip vertically, so frequency runs along rows
    komegamap = np.rot90(komegamap, k=1)
    komegamap = np.flip(komegamap, axis=0)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Filtering implementation
    if params['filtering']:
        k1 = params['k1']
        k2 = params['k2']
        f1 = params['f1']
        f2 = params['f2']

        if params['no_spatial_filt']:
            k1 = kopower_xscale[1]  # Default lower wavenumber
            k2 = np.nanmax(kopower_xscale)  # Default upper wavenumber
        # Ensure k1 and k2 are within valid bounds
        if k1 is None or k1 <= 0.0:
            k1 = kopower_xscale[1]
        if k2 is None or k2 > np.nanmax(kopower_xscale):
            k2 = np.nanmax(kopower_xscale)

        if params['no_temporal_filt']:
            f1 = kopower_yscale[1]
            f2 = np.nanmax(kopower_yscale)
        # Ensure f1 and f2 are within valid bounds
        if f1 is None or f1 <= 0.0:
            f1 = kopower_yscale[1]
        if f2 is None or f2 > np.nanmax(kopower_yscale):
            f2 = np.nanmax(kopower_yscale)

        print("Start filtering (in k-ω space) ......")
        print("")
        print(f"The preserved wavenumbers are [{k1:.3f}, {k2:.3f}] {params['xtitle_units']}")
        print(f"The preserved spatial sizes are [{(2 * np.pi) / k2:.3f}, {(2 * np.pi) / k1:.3f}] {params['x2ndaxistitle_units']}")
        print("")
        print(f"The preserved frequencies are [{f1:.3f}, {f2:.3f}] {params['yttitle_units']}")
        print(f"The preserved periods are [{int(1 / (f2))}, {int(1 / (f1))}] {params['y2ndaxistitle_units']}")
        print("")

        # Perform the 3D Fourier transform
        print("Making a 3D Fourier transform of the input datacube ..........")
        threedft = np.fft.fftshift(np.fft.fftn(signal))

        # Frequency axes in unshifted FFT order (IDL convention: +Nyquist for even sizes)
        spatial_frequencies_orig = _walsa_komega_fft_axis(nx, pixelsize) * (2.0 * np.pi)
        temporal_frequencies_orig = _walsa_komega_fft_axis(nt, cadence)

        # Frequency axes matching the fftshift-ed (and re-aligned below) cube
        spatial_frequencies = _walsa_komega_centred_axis(spatial_frequencies_orig)
        temporal_frequencies = _walsa_komega_centred_axis(temporal_frequencies_orig)

        # For even lengths, shift the cube by one element so it aligns with the axes above
        spatial_even = len(spatial_frequencies_orig) % 2 == 0
        temporal_even = len(temporal_frequencies_orig) % 2 == 0
        if temporal_even:
            threedft = np.roll(threedft, -1, axis=0)
        if spatial_even:
            threedft = np.roll(threedft, (-1, -1), axis=(1, 2))

        # Convert frequencies and wavenumbers of interest into FFT datacube pixels.
        # Spatial pixels index the unshifted axis, i.e. they are radii in pixels.
        pixel_k1_positive = np.argmin(np.abs(spatial_frequencies_orig - k1))
        pixel_k2_positive = np.argmin(np.abs(spatial_frequencies_orig - k2))
        pixel_f1_positive = np.argmin(np.abs(temporal_frequencies - f1))
        pixel_f2_positive = np.argmin(np.abs(temporal_frequencies - f2))
        pixel_f1_negative = np.argmin(np.abs(temporal_frequencies + f1))
        pixel_f2_negative = np.argmin(np.abs(temporal_frequencies + f2))

        torus_depth = int((pixel_k2_positive - pixel_k1_positive) / 2) * 2
        torus_center = int(((pixel_k2_positive - pixel_k1_positive) / 2) + pixel_k1_positive)

        radial = walsa_radial_distances((nx, ny))

        # ---- Spatial filter ----
        if params['no_spatial_filt']:
            spatial_ring_filter = np.ones((nx, ny))
        elif params['spatial_torus']:
            if torus_depth < 2:
                raise ValueError(
                    "The wavenumber range [k1, k2] is too narrow for the Gaussian (torus) "
                    "spatial filter; widen it or set spatial_torus=False."
                )
            # Stack concentric rings of increasing width; integrating through the
            # stack gives a Gaussian-like radial profile across the annulus
            spatial_torus = np.zeros((torus_depth, nx, ny))
            for i in range(torus_depth // 2 + 1):
                spatial_ring = np.logical_xor(
                    radial <= (torus_center - i),
                    radial <= (torus_center + i + 1)
                ).astype(float)
                spatial_torus[i, :, :] = spatial_ring
                spatial_torus[torus_depth - i - 1, :, :] = spatial_ring

            spatial_ring_filter = np.nansum(spatial_torus, axis=0) / float(torus_depth)
            spatial_ring_filter = spatial_ring_filter / np.nanmax(spatial_ring_filter)  # peaks at 1.0
        else:
            # Sharp annulus: 1 where inner < r <= outer, 0 elsewhere.
            # (Previously inner - outer, which gives 0/-1 and then a NaN/zero filter.)
            outer = radial <= (torus_center + torus_depth // 2 + 1)
            inner = radial <= (torus_center - torus_depth // 2)
            spatial_ring_filter = (outer & ~inner).astype(float)

        # ---- Temporal filter ----
        if params['no_temporal_filt']:
            temporal_filter = np.ones(nt, dtype=float)
        elif params['temporal_torus']:
            # Gaussian temporal filter to prevent aliasing
            temporal_filter = np.zeros(nt, dtype=float)
            filter_width = pixel_f2_positive - pixel_f1_positive
            sigma = _walsa_komega_temporal_sigma(filter_width)

            temporal_gaussian = gaussian_function(sigma=sigma, width=filter_width)

            temporal_filter[pixel_f1_positive:pixel_f2_positive] = temporal_gaussian
            temporal_filter[pixel_f2_negative:pixel_f1_negative] = temporal_gaussian
            temporal_filter /= np.nanmax(temporal_filter)  # peaks at 1.0
        else:
            temporal_filter = np.zeros(nt, dtype=float)
            temporal_filter[pixel_f1_positive:pixel_f2_positive + 1] = 1.0
            temporal_filter[pixel_f2_negative:pixel_f1_negative + 1] = 1.0

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Diagnostic maps of the (unfiltered) transform, demonstrating the filtering process
        if params['processing_maps']:
            spatial_dx = spatial_frequencies[1] - spatial_frequencies[0]
            spatial_dy = spatial_dx

            torus_map = {
                'data': spatial_ring_filter,
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            # Total spatial FFT (summed over time); shown as log10 of its magnitude
            spatial_fft = np.nansum(threedft, axis=0)

            spatial_fft_map = {
                'data': np.log10(np.maximum(np.abs(spatial_fft), 1e-15)),
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            spatial_fft_filtered = spatial_fft * spatial_ring_filter
            spatial_fft_filtered_map = {
                'data': np.log10(np.maximum(np.abs(spatial_fft_filtered), 1e-15)),
                'dx': spatial_dx,
                'dy': spatial_dy,
                'xc': 0,
                'yc': 0,
                'time': '',
                'units': 'pixels'
            }

            # Total temporal FFT (summed over x and y)
            temporal_fft = np.nansum(np.nansum(threedft, axis=2), axis=1)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Apply the filters (broadcast over the cube)
        threedft *= spatial_ring_filter[np.newaxis, :, :]
        threedft *= temporal_filter[:, np.newaxis, np.newaxis]

        # Undo the one-element alignment shifts so the cube matches fftshift order again
        if temporal_even:
            threedft = np.roll(threedft, 1, axis=0)
        if spatial_even:
            threedft = np.roll(threedft, (1, 1), axis=(1, 2))

        # Inverse FFT to get the filtered cube
        filtered_cube = np.real(np.fft.ifftn(np.fft.ifftshift(threedft)))

        print("Filtered datacube generated.")

        if not params['processing_maps']:
            # Documented behaviour: processing outputs are None unless requested
            temporal_filter, temporal_frequencies, spatial_frequencies = None, None, None

    return komegamap, kopower_xscale[1:], kopower_yscale[1:], filtered_cube, spatial_fft_map, torus_map, spatial_fft_filtered_map, temporal_fft, temporal_filter, temporal_frequencies, spatial_frequencies


def _walsa_komega_fft_axis(n, spacing):
    """
    Return the frequency axis of an unshifted length-``n`` FFT divided by ``n * spacing``.

    The ordering is [0, 1, ..., positives, negatives]. For even ``n`` the
    Nyquist element is +n/2, as in the IDL original. For odd ``n`` the
    negative half is -(n // 2), ..., -1.
    """
    positive = np.arange((n - 1) // 2) + 1
    if n % 2 == 0:
        index = np.concatenate([[0.0], positive, [n / 2], -n / 2 + positive])
    else:
        # Integer n // 2: float n / 2 would offset every negative bin by half a step
        index = np.concatenate([[0.0], positive, -(n // 2 + 1) + positive])
    return index / (n * spacing)


def _walsa_komega_centred_axis(freqs):
    """Roll an unshifted FFT frequency axis to match the fftshift-ed, re-aligned cube."""
    n_nonnegative = np.count_nonzero(freqs >= 0)
    offset = 2 if len(freqs) % 2 == 0 else 1
    return np.roll(freqs, n_nonnegative - offset)


def _walsa_komega_temporal_sigma(filter_width):
    """Return the sigma passed to gaussian_function for a temporal filter ``filter_width`` bins wide."""
    # Lower edges of the width bands; each band crossed raises sigma by one, starting at 3
    edges = (25, 30, 40, 45, 50, 55, 60, 65, 70, 80, 90, 100, 110, 130)
    return 3 + sum(1 for edge in edges if filter_width >= edge)

WaLSA_pod.py

This module implements Proper Orthogonal Decomposition (POD), as well as Spectral POD (SPOD), for analysing multi-dimensional data and extracting dominant spatial patterns.

WaLSA_pod.py
# --------------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# --------------------------------------------------------------------------------------------------------
# The following codes are based on those originally written by Jonathan E. Higham & Luiz A. C. A. Schiavo
# --------------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.signal import welch, find_peaks # type: ignore
from scipy.linalg import svd # type: ignore
from scipy.optimize import curve_fit # type: ignore
import math

def print_pod_results(results):
    """
    Print a summary of the results from WaLSA_pod, including parameter descriptions, types, and shapes.

    For each key of ``results`` one line is printed:
    "<key> (<type name>, Shape: <shape>): <description>". Keys without a known
    description get 'No description available'. None values show shape 'None'.
    Ragged sequences, which have no single NumPy shape, are reported as
    'ragged (len=N)' instead of raising.

    Parameters:
    -----------
    results : dict
        The dictionary containing all POD results and relevant outputs.
    """
    descriptions = {
        'input_data': 'Original input data, mean subtracted (Shape: (Nt, Ny, Nx))',
        'spatial_mode': 'Reshaped spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))',
        'temporal_coefficient': 'Temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))',
        'eigenvalue': 'Eigenvalues corresponding to singular values squared (Shape: (Nmodes))',
        'eigenvalue_contribution': 'Eigenvalue contribution of each mode (Shape: (Nmodes))',
        'cumulative_eigenvalues': 'Cumulative percentage of eigenvalues for the first "num_cumulative_modes" modes (Shape: (num_cumulative_modes))',
        'combined_welch_psd': 'Combined Welch power spectral density for the temporal coefficients of the first "num_modes" modes (Shape: (Nf))',
        'frequencies': 'Frequencies identified in the Welch spectrum (Shape: (Nf))',
        'combined_welch_significance': 'Significance threshold of the combined Welch spectrum (Shape: (Nf,))',
        'reconstructed': 'Reconstructed frame at the specified timestep (or for the entire time series) using the top "num_modes" modes (Shape: (Ny, Nx))',
        'sorted_frequencies': 'Frequencies identified in the Welch combined power spectrum (Shape: (Nfrequencies))',
        'frequency_filtered_modes': 'Frequency-filtered spatial POD modes for the first "num_top_frequencies" frequencies (Shape: (Nt, Ny, Nx, num_top_frequencies))',
        'frequency_filtered_modes_frequencies': 'Frequencies corresponding to the frequency-filtered modes (Shape: (num_top_frequencies))',
        'SPOD_spatial_modes': 'SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))',
        'SPOD_temporal_coefficients': 'SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))',
        'p': 'Left singular vectors (spatial modes) from SVD (Shape: (Nx, Nmodes))',
        's': 'Singular values from SVD (Shape: (Nmodes))',
        'a': 'Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))'
    }

    print("\n---- POD/SPOD Results Summary ----\n")
    for key, value in results.items():
        desc = descriptions.get(key, 'No description available')
        dtype = type(value).__name__
        if value is None:
            shape = 'None'
        else:
            try:
                shape = np.shape(value)
            except ValueError:
                # Recent NumPy refuses to build an array from inhomogeneous
                # sequences (e.g. a list of differently sized arrays)
                shape = f'ragged (len={len(value)})'
        print(f"{key} ({dtype}, Shape: {shape}): {desc}")
    print("\n----------------------------------")

def spod(data, **kwargs):
    """
    Perform Spectral Proper Orthogonal Decomposition (SPOD) analysis on input data.

    Steps (snapshot method with a diagonal Gaussian filter):
    1. Subtract the time average from every snapshot to obtain the fluctuation field.
    2. Build the snapshot correlation matrix
       R[i, j] = sum(u_i * u_j) / (Nt * Ny * Nx).
    3. Filter R along its diagonals with a Gaussian kernel g of half-width
       ``filter_size``:
       S[i, j] = sum_k g[k] * R[i + k, j + k] / sum_k g[k],  k = -filter_size..filter_size.
       Out-of-range indices wrap around when ``periodic_matrix`` is True and
       contribute zero otherwise. The filter always reads the *unfiltered* R,
       so S stays symmetric positive-semidefinite.
    4. Decompose S with an SVD.
    5. Build the temporal coefficients and spatial modes from the singular vectors.

    Parameters:
    -----------
    data : np.ndarray
        3D data array with shape (Nt, Ny, Nx).
    **kwargs : dict, optional
        silent (bool): Suppress progress messages. Default False.
        num_modes (int): Number of spatial modes to return. Default Nt.
        filter_size (int): Half-width of the Gaussian filter; 0 disables the
            filter (plain snapshot POD). Default ``num_modes``.
        periodic_matrix (bool): Use periodic boundary conditions when
            filtering. Default True.

    Returns:
    --------
    SPOD_spatial_modes : np.ndarray
        Shape (num_modes, Ny, Nx).
    SPOD_temporal_coefficients : np.ndarray
        Shape (Nt, Nt); column k holds the temporal coefficients of mode k.

    Raises:
    -------
    ValueError
        If ``filter_size`` is negative.
    """
    # Set default parameter values
    defaults = {
        'silent': False,
        'num_modes': None,
        'filter_size': None,
        'periodic_matrix': True
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    data = np.asarray(data)
    nsnapshots, ny, nx = data.shape

    if params['num_modes'] is None:
        params['num_modes'] = nsnapshots
    if params['filter_size'] is None:
        params['filter_size'] = params['num_modes']
    num_modes = int(params['num_modes'])
    filter_size = int(params['filter_size'])
    if filter_size < 0:
        raise ValueError(f"filter_size must be non-negative, got {filter_size}")

    if not params['silent']:
        print("Starting SPOD analysis ....")

    # Subtract the time average. Work in float64 so integer input is not truncated.
    fluctuation_field = data.astype(np.float64) - data.mean(axis=0)

    # Snapshot correlation matrix: one flattened snapshot per row.
    snapshots = fluctuation_field.reshape(nsnapshots, ny * nx)
    correlation_matrix = (snapshots @ snapshots.T) / (nsnapshots * nx * ny)

    # Gaussian kernel over the diagonal offsets k = -filter_size..filter_size
    esp = 8.0  # Exponential parameter controlling the spread of the filter
    offsets = np.arange(-filter_size, filter_size + 1)
    if filter_size == 0:
        gk = np.ones(1)  # Zero-width filter: S == R (avoids 0/0 in the formula)
    else:
        gk = np.exp(-esp * offsets**2.0 / filter_size**2.0)

    # Extend R so that R[i + k, j + k] is defined for every offset k:
    # periodic extension ('wrap') or zero padding ('constant').
    pad_mode = 'wrap' if params['periodic_matrix'] else 'constant'
    extended = np.pad(correlation_matrix, filter_size, mode=pad_mode)

    # Accumulate the filtered matrix from the *unmodified* extended matrix.
    # (Filtering in place would feed already-filtered values into later entries.)
    spectral_matrix = np.zeros_like(correlation_matrix)
    for weight, k in zip(gk, offsets):
        start = filter_size + k
        spectral_matrix += weight * extended[start:start + nsnapshots, start:start + nsnapshots]
    spectral_matrix /= gk.sum()

    # Perform SVD to compute SPOD
    UU, SS, _ = svd(spectral_matrix, full_matrices=True, compute_uv=True)

    # Temporal coefficients: column k of UU scaled by sqrt(SS[k] * Nt)
    SPOD_temporal_coefficients = UU * np.sqrt(SS * nsnapshots)

    # Spatial modes: projection of the fluctuations onto the temporal coefficients,
    # mode m = sum_t u_t * coef[t, m] / (SS[m] * Nt)
    weights = SPOD_temporal_coefficients[:, :num_modes] / (SS[:num_modes] * nsnapshots)
    SPOD_spatial_modes = np.einsum('tm,tyx->myx', weights, fluctuation_field)

    if not params['silent']:
        print("SPOD analysis completed.")

    return SPOD_spatial_modes, SPOD_temporal_coefficients


# -------------------------------------- Main Function ----------------------------------------
def WaLSA_pod(signal, time, **kwargs):
    """
    Perform Proper Orthogonal Decomposition (POD) analysis on input data.

    Parameters:
        signal (array): 3D data cube with shape (Nt, Ny, Nx).
        time (array): 1D array with the time of each frame (length Nt).
        silent (bool, optional): If True, suppress progress messages (also forwarded to SPOD). Default is False.
        num_modes (int, optional): Number of modes combined in the Welch spectrum (all SVD modes are always
            computed). Default is None (Nt).
        num_top_frequencies (int, optional): Number of top frequencies to consider. Default is None
            (min(num_modes, 10)). If fewer spectral peaks are detected, all detected peaks are used.
        top_frequencies (list, optional): Explicit frequencies for the frequency-filtered modes; overrides
            num_top_frequencies. Default is None.
        num_cumulative_modes (int, optional): Number of cumulative eigenvalue percentages to return
            (entry m is the variance captured by modes 1..m+1). Default is None (min(num_modes, 10)).
        welch_nperseg (int, optional): Number of samples per segment for Welch's method. Default is 150.
        welch_noverlap (int, optional): Number of overlapping samples for Welch's method. Default is 25.
        welch_nfft (int, optional): Number of points for the FFT. Default is 2^14.
        welch_fs (int, optional): Sampling frequency for Welch's method. Default is 2. It should equal
            1/cadence of ``time`` so that the detected frequencies match the sinusoid fits.
        nperm (int, optional): Number of permutations for significance testing. Default is 1000.
        siglevel (float, optional): Significance level for the Welch spectrum; the threshold is the
            siglevel-quantile of white-noise spectra (e.g. the 95th percentile). Default is 0.95.
        timestep_to_reconstruct (int, optional): Timestep of the datacube to reconstruct using the top modes. Default is 0.
        num_modes_reconstruct (int, optional): Number of modes to use for reconstruction and for the
            frequency-filtered modes. Default is None (min(num_modes, 10)).
        reconstruct_all (bool, optional): If True, reconstruct the entire time series using the top modes. Default is False.
        spod (bool, optional): If True, perform Spectral Proper Orthogonal Decomposition (SPOD) analysis. Default is False.
        spod_filter_size (int, optional): Filter size for SPOD analysis. Default is None.
        spod_num_modes (int, optional): Number of SPOD modes to compute. Default is None.
        print_results (bool, optional): If True, print a summary of results. Default is True.

    **kwargs : Additional keyword arguments to configure the analysis.

    Returns:
    results : dict
        A dictionary containing all computed POD results and relevant outputs.
        See 'descriptions' in the 'print_pod_results' function on top of this page.
    """
    # Set default parameter values
    defaults = {
        'silent': False,
        'num_modes': None,
        'num_top_frequencies': None,
        'top_frequencies': None,
        'num_cumulative_modes': None,
        'welch_nperseg': 150,
        'welch_noverlap': 25,
        'welch_nfft': 2**14,
        'welch_fs': 2,
        'nperm': 1000,
        'siglevel': 0.95,
        'timestep_to_reconstruct': 0,
        'reconstruct_all': False,
        'num_modes_reconstruct': None,
        'spod': False,
        'spod_filter_size': None,
        'spod_num_modes': None,
        'print_results': True  # Print summary of results by default
    }

    # Update defaults with user-provided values
    params = {**defaults, **kwargs}

    data = signal
    nt, ny, nx = data.shape
    time = np.asarray(time)

    if params['num_modes'] is None:
        params['num_modes'] = nt

    if params['num_top_frequencies'] is None:
        params['num_top_frequencies'] = min(params['num_modes'], 10)

    if params['num_cumulative_modes'] is None:
        params['num_cumulative_modes'] = min(params['num_modes'], 10)

    if params['num_modes_reconstruct'] is None:
        params['num_modes_reconstruct'] = min(params['num_modes'], 10)

    if not params['silent']:
        print("Starting POD analysis ....")
        print(f"Processing a 3D cube with shape {data.shape}.")

    # Reshape into an (Npix, Nt) matrix: each column is one flattened frame.
    inp = np.reshape(data, (nt, ny * nx)).astype(np.float32).T

    # Remove each pixel's temporal mean so that only the variance is decomposed
    # and mode 1 is not contaminated by the mean image.
    inp_centered = inp - np.nanmean(inp, axis=1, keepdims=True)

    # Mean-subtracted input back in cube form (Nt, Ny, Nx)
    input_dat = inp_centered.T.reshape(nt, ny, nx).astype(np.float64)

    # Perform SVD to compute POD
    p, s, a = svd(inp_centered, full_matrices=False)
    sorg = s.copy()  # Store original singular values
    eigenvalue = s**2  # Convert singular values to eigenvalues

    # Spatial modes in image form: mode m is column m of p
    num_svd_modes = p.shape[1]
    spatial_mode = p.T.reshape(num_svd_modes, ny, nx).astype(np.float64)

    # Extract temporal coefficients
    temporal_coefficient = a[:num_svd_modes, :]

    # Calculate eigenvalue contributions
    eigenvalue_contribution = eigenvalue / np.sum(eigenvalue)

    # Cumulative percentage of variance captured by modes 1..m (inclusive)
    cumulative_eigenvalues = list(
        100 * np.cumsum(eigenvalue[:params['num_cumulative_modes']]) / np.sum(eigenvalue)
    )

    # Reconstruction with the top 'num_modes_reconstruct' modes: sum_i p_i * s_i * a_i(t)
    nrec = params['num_modes_reconstruct']
    weighted_modes = p[:, :nrec].astype(np.float64) * sorg[:nrec]  # column i scaled by s_i
    if params['reconstruct_all']:
        reconstructed = (weighted_modes @ a[:nrec, :]).T.reshape(nt, ny, nx)
    else:
        t_index = params['timestep_to_reconstruct']
        reconstructed = (weighted_modes @ a[:nrec, t_index]).reshape(ny, nx)

    #---------------------------------------------------------------------------------
    # Combined Welch power spectrum and its significance
    #---------------------------------------------------------------------------------
    welch_kwargs = {
        'nperseg': params['welch_nperseg'],
        'noverlap': params['welch_noverlap'],
        'nfft': params['welch_nfft'],
        'fs': params['welch_fs'],
    }

    # Compute Welch power spectrum for each mode and combine them for 'num_modes' modes
    combined_welch_psd = []
    for m in range(params['num_modes']):
        frequencies, px = welch(a[m, :] - np.mean(a[m, :]), **welch_kwargs)
        if m == 0:
            combined_welch_psd = np.zeros((len(frequencies),))
        combined_welch_psd += eigenvalue_contribution[m] * (px / np.sum(px))

    # Generate resampled (white-noise) spectra to compute the significance threshold
    resampled_peaks = np.zeros((params['nperm'], len(frequencies)))
    for i in range(params['nperm']):
        resampled_data = np.random.randn(*a.shape)
        resampled_peak = np.zeros((len(frequencies),))
        for m in range(params['num_modes']):
            _, px = welch(resampled_data[m, :] - np.mean(resampled_data[m, :]), **welch_kwargs)
            resampled_peak += eigenvalue_contribution[m] * (px / np.sum(px))
        resampled_peak /= (np.max(resampled_peak) + 1e-30)  # a small epsilon added to avoid division by zero
        resampled_peaks[i, :] = resampled_peak
    # Threshold = siglevel-quantile of the noise spectra (95th percentile for siglevel=0.95)
    combined_welch_significance = np.percentile(resampled_peaks, params['siglevel'] * 100, axis=0)
    #---------------------------------------------------------------------------------

    # Find peaks in the combined spectrum and sort them in descending order
    normalized_peak = combined_welch_psd / (np.max(combined_welch_psd) + 1e-30)
    peak_indices, _ = find_peaks(normalized_peak)
    sorted_indices = np.argsort(normalized_peak[peak_indices])[::-1]
    sorted_frequencies = frequencies[peak_indices][sorted_indices]

    # Table of the top N frequencies and their normalized power
    sorted_powers = normalized_peak[peak_indices][sorted_indices]
    table_data = [
        [float(np.round(freq, 2)), float(np.round(pwr, 2))]
        for freq, pwr in zip(sorted_frequencies[:params['num_top_frequencies']],
                             sorted_powers[:params['num_top_frequencies']])
    ]

    # The frequencies that the POD is able to capture have now been identified.
    # Each is fitted (as a pure sinusoid) to the temporal coefficients of the top
    # 'num_modes_reconstruct' modes, and the data are rebuilt from these fits,
    # giving frequency-filtered spatial POD modes.
    if params['top_frequencies'] is None:
        top_frequencies = sorted_frequencies[:params['num_top_frequencies']]
    else:
        top_frequencies = params['top_frequencies']
    # Size everything by the frequencies actually available: user lists may hold
    # more than 10 entries, and fewer peaks than requested may have been detected.
    params['num_top_frequencies'] = len(top_frequencies)
    num_top = params['num_top_frequencies']

    basis = p[:, :nrec] @ np.diag(sorg[:nrec])
    clean_data = np.zeros((inp.shape[0], inp.shape[1], num_top))
    for i, freq in enumerate(top_frequencies):
        def model_fun(t, amplitude, phase):
            """Sinusoid at the current frequency `freq` with free amplitude and phase."""
            return amplitude * np.sin(2 * np.pi * freq * t + phase)

        clean_signal = np.zeros((nrec, a.shape[1]))
        for j in range(nrec):
            # Nonlinear fit with initial guesses [amplitude, phase] = [1, 0]
            fit_params, _ = curve_fit(model_fun, time, a[j, :], p0=[1, 0])
            clean_signal[j, :] = model_fun(time, *fit_params)
        # Reconstructed data at the fitted frequency (frequency-filtered reconstruction)
        clean_data[:, :, i] = basis @ clean_signal

    # (Npix, Nt, K) -> (Nt, Ny, Nx, K)
    frequency_filtered_modes = clean_data.transpose(1, 0, 2).reshape(nt, ny, nx, num_top)

    if params['spod']:
        SPOD_spatial_modes, SPOD_temporal_coefficients = spod(
            data,
            silent=params['silent'],
            num_modes=params['spod_num_modes'],
            filter_size=params['spod_filter_size'],
        )
    else:
        SPOD_spatial_modes = None
        SPOD_temporal_coefficients = None

    results = {
        'input_data': input_dat, # Original input data, mean subtracted (Shape: (Nt, Ny, Nx))
        'spatial_mode': spatial_mode,  # POD spatial modes matching the dimensions of the input data (Shape: (Nmodes, Ny, Nx))
        'temporal_coefficient': temporal_coefficient,  # POD temporal coefficients associated with each spatial mode (Shape: (Nmodes, Nt))
        'eigenvalue': eigenvalue,  # Eigenvalues corresponding to singular values squared (Shape: (Nmodes))
        'eigenvalue_contribution': eigenvalue_contribution,  # Eigenvalue contribution of each mode (Shape: (Nmodes))
        'cumulative_eigenvalues': cumulative_eigenvalues,  # Cumulative percentage of eigenvalues for the first 'num_cumulative_modes' modes (Shape: (Ncumulative_modes))
        'combined_welch_psd': combined_welch_psd,  # Combined Welch power spectral density for the temporal coefficients (Shape: (Nf))
        'frequencies': frequencies,  # Frequencies identified in the Welch spectrum (Shape: (Nf))
        'combined_welch_significance': combined_welch_significance,  # Significance threshold of the combined Welch spectrum (Shape: (Nf))
        'reconstructed': reconstructed,  # Reconstructed frame at the specified timestep (or for the entire time series) using the top modes (Shape: (Ny, Nx))
        'sorted_frequencies': sorted_frequencies,  # Frequencies identified in the Welch spectrum (Shape: (Nfrequencies,))
        'frequency_filtered_modes': frequency_filtered_modes,  # Frequency-filtered spatial POD modes (Shape: (Nt, Ny, Nx, Ntop_frequencies))
        'frequency_filtered_modes_frequencies': top_frequencies,  # Frequencies corresponding to the frequency-filtered modes (Shape: (Ntop_frequencies))
        'SPOD_spatial_modes': SPOD_spatial_modes,  # SPOD spatial modes if SPOD is used (Shape: (Nspod_modes, Ny, Nx))
        'SPOD_temporal_coefficients': SPOD_temporal_coefficients,  # SPOD temporal coefficients if SPOD is used (Shape: (Nspod_modes, Nt))
        'p': p,  # Left singular vectors (spatial modes) from SVD (Shape: (Nx, Nmodes))
        's': sorg,  # Singular values from SVD (Shape: (Nmodes,))
        'a': a  # Right singular vectors (temporal coefficients) from SVD (Shape: (Nmodes, Nt))
    }

    if not params['silent']:
        print("POD analysis completed.")
        print(f"Top {params['num_top_frequencies']} frequencies and normalized power values:\n{table_data}")
        print(f"Total variance contribution of the first {params['num_modes']} modes: {np.sum(100 * eigenvalue[:params['num_modes']] / np.sum(eigenvalue)):.2f}%")

    if params['print_results']:
        print_pod_results(results)

    return results

WaLSA_cross_spectra.py

This module implements cross-correlation analysis techniques for investigating correlations between two time series, yielding their cross-spectrum, coherence, and phase relationships.

WaLSA_cross_spectra.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.signal import coherence, csd # type: ignore
from .WaLSA_speclizer import WaLSA_speclizer # type: ignore
from .WaLSA_wavelet import xwt, wct # type: ignore
from .WaLSA_detrend_apod import WaLSA_detrend_apod # type: ignore
from .WaLSA_wavelet_confidence import WaLSA_wavelet_confidence # type: ignore

# -------------------------------------- Main Function ----------------------------------------
def WaLSA_cross_spectra(signal=None, time=None, method=None, **kwargs):
    """
    Dispatch a cross-spectral analysis of two time series to the chosen backend.

    The two time series themselves travel through ``**kwargs`` (for the Welch
    backend these are the keys ``data1`` and ``data2``); ``signal`` is passed
    through unchanged as the backend's first positional argument.

    Parameters:
        signal (array): Forwarded unchanged to the backend routine.
        time (array): The time array of the signals.
        method (str): 'welch', 'wavelet', or 'fft'. 'fft' is treated as an
            alias for 'welch' and prints a note explaining why.
        **kwargs: Forwarded to the backend, e.g. data1 (first 1D series) and
            data2 (second 1D series).

    Returns:
        Whatever the selected backend routine returns.

    Raises:
        ValueError: If ``method`` is not one of the supported options.
    """
    if method == 'fft':
        # FFT cross-spectra are computed with the (segment-averaged) Welch backend.
        print("Note: FFT method selected. Cross-spectra calculations will use the Welch method instead, "
              "which segments the signal into multiple parts to reduce noise sensitivity. "
              "You can control frequency resolution vs. noise reduction using the 'nperseg' parameter.")
        method = 'welch'

    if method == 'welch':
        return getcross_spectrum_Welch(signal, time=time, **kwargs)
    if method == 'wavelet':
        return getcross_spectrum_Wavelet(signal, time=time, **kwargs)

    raise ValueError(f"Unknown method: {method}")


# -------------------------------------- Welch ----------------------------------------
def getcross_spectrum_Welch(signal, time, **kwargs):
    """
    Calculate cross-spectral relationships of two time series using Welch's method.

    The amplitude spectra of the two series are computed with the
    WaLSA_speclizer routine ('welch' method). The complex cross-spectrum is
    computed with ``scipy.signal.csd``: its magnitude is returned as the
    co-spectrum and its argument (in degrees) as the phase lag between the two
    series. The magnitude-squared coherence is computed with
    ``scipy.signal.coherence``.

    Parameters:
        signal: Unused; kept for interface compatibility with WaLSA_cross_spectra.
        time (array): The time array corresponding to the signals; its median
            spacing sets the sampling frequency.
        data1 (array): The first 1D time series signal (keyword, required).
        data2 (array): The second 1D time series signal (keyword, required).
        nperseg (int, optional): Length of each segment for the cross-spectrum
            and coherence. Default: 256.
        noverlap (int, optional): Number of points to overlap between segments
            for the cross-spectrum and coherence. Default: None (nperseg // 2).
        window (str, optional): Window used for the cross-spectrum and
            coherence. Default: 'hann'.
        silent (bool, optional): If True, suppress messages from
            WaLSA_speclizer for both power spectra. Default: False.
        **kwargs: All other keyword arguments are forwarded to WaLSA_speclizer.

    Returns:
        frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh
            frequencies: frequencies of the power spectra (from WaLSA_speclizer)
            cospectrum: magnitude of the cross-spectrum
            phase_angle: phase of the cross-spectrum in degrees
            power_data1, power_data2: spectra of data1 and data2 (from WaLSA_speclizer)
            freq_coh, coh: frequencies and magnitude-squared coherence

    Raises:
        ValueError: If data1 or data2 is not provided.
    """
    # Define default values for the optional parameters
    defaults = {
        'data1': None,        # First data array
        'data2': None,        # Second data array
        'nperseg': 256,      # Number of points per segments to average
        'noverlap': None,    # Overlap for csd/coherence (None -> nperseg // 2)
        'window': 'hann',    # Window for csd/coherence
        'apod': 0.1,       # Apodization function
        'nodetrendapod': None,  # No detrending ot apodization applied
        'pxdetrend': 2,  # Detrend parameter
        'meandetrend': None,  # Detrend parameter
        'polyfit': None,    # Detrend parameter
        'meantemporal': None,  # Detrend parameter
        'recon': None      # Detrend parameter
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    data1 = params['data1']
    data2 = params['data2']
    if data1 is None or data2 is None:
        raise ValueError("Both 'data1' and 'data2' must be provided for the Welch cross-spectrum.")

    # One 'silent' setting for both spectra; popped so it is not passed twice via **kwargs.
    silent = kwargs.pop('silent', False)

    cadence = np.median(np.diff(time))
    fs = 1.0 / cadence

    # Power spectrum for data1
    power_data1, frequencies, _ = WaLSA_speclizer(signal=data1, time=time, method='welch',
                                                  amplitude=True, nosignificance=True, silent=silent, **kwargs)

    # Power spectrum for data2
    power_data2, _, _ = WaLSA_speclizer(signal=data2, time=time, method='welch',
                                        amplitude=True, nosignificance=True, silent=silent, **kwargs)

    # Calculate cross-spectrum
    _, crosspower = csd(data1, data2, fs=fs, window=params['window'],
                        nperseg=params['nperseg'], noverlap=params['noverlap'])

    cospectrum = np.abs(crosspower)

    # Calculate phase lag
    phase_angle = np.angle(crosspower, deg=True)

    # Calculate coherence
    freq_coh, coh = coherence(data1, data2, fs=fs, window=params['window'],
                              nperseg=params['nperseg'], noverlap=params['noverlap'])

    return frequencies, cospectrum, phase_angle, power_data1, power_data2, freq_coh, coh


# -------------------------------------- Wavelet ----------------------------------------
def getcross_spectrum_Wavelet(signal, time, **kwargs):
    """
    Calculate the cross-power spectrum, coherence spectrum, and phase angles
    between two signals using the wavelet transform.

    The two series are supplied through the ``data1`` and ``data2`` keywords.
    ``signal`` itself is not used by this routine (it is part of the shared
    call signature of the cross-spectrum functions).

    Both series are detrended/apodized with ``WaLSA_detrend_apod`` before the
    cross-wavelet transform (``xwt``) and the wavelet coherence (``wct``) are
    computed. Significance is estimated with a Monte Carlo test: ``nperm``
    random permutations of both original series are prepared the same way,
    and ``WaLSA_wavelet_confidence`` gives the ``siglevel`` percentile of the
    permuted spectra at every point.

    Parameters:
        signal (array): Not used (see above).
        time (array): The time array corresponding to the signals. The cadence
            is taken as the median of its differences.
        data1 (array): The first 1D time series signal. Required.
        data2 (array): The second 1D time series signal. Required.
        siglevel (float): Confidence level for the significance test. Default: 0.95.
        nperm (int): Number of permutations for significance testing. Default: 1000.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        dj (float): Scale spacing. Smaller values result in better scale resolution but slower calculations. Default: 0.025.
        s0 (float): Initial (smallest) scale of the wavelet. Default: 2 * dt.
        J (int): Number of scales minus one. Scales range from s0 up to s0 * 2**(J * dj), giving a total of (J + 1) scales. Default: (log2(N * dt / s0)) / dj.
        cache (bool): Passed to ``wct``. Default: False.
        apod (float): Extent of apodization edges (of a Tukey window). Default: 0.1.
        pxdetrend (int): Subtract linear trend with time per pixel. Options: 1 (simple) or 2 (advanced). Default: 2.
        polyfit (int): Degree of polynomial fit for detrending the data. If set, a polynomial fit (instead of linear) is applied. Default: None.
        meantemporal (bool): If True, apply simple temporal detrending by subtracting the mean signal from the data, skipping fitting procedures. Default: False.
        meandetrend (bool): If True, subtract the linear trend with time for the image means (spatial detrending). Default: False.
        recon (bool): If True, perform Fourier reconstruction of the input time series. This does not preserve amplitudes but is useful for examining frequencies far from the low-frequency range. Default: False.
        resample_original (bool): If True, and if recon set True, approximate values close to the original are returned for comparison. Default: False.
        silent (bool): If True, suppress print statements (including the progress output). Default: False.

    Returns:
        A tuple of 9 elements:
        cross_power (np.ndarray): |W12|**2, the cross-wavelet power of the prepared series.
        cross_periods (np.ndarray): Periods (1 / frequency) of the cross-power spectrum.
        cross_sig (np.ndarray): cross_power divided by its permutation significance level (ratio > 1 means significant).
        cross_coi (np.ndarray): Cone of influence returned by ``xwt``.
        coherence (np.ndarray): Wavelet coherence (``WCT`` returned by ``wct``).
        coh_periods (np.ndarray): Periods (1 / frequency) of the coherence spectrum.
        coh_sig (np.ndarray): coherence divided by its permutation significance level (ratio > 1 means significant).
        coh_coi (np.ndarray): Cone of influence returned by ``wct``.
        phase_angle (np.ndarray): Phase angles returned by ``wct`` (``aWCT``).

    Raises:
        ValueError: If ``data1`` or ``data2`` is not provided.

    Notes:
        Phase arrows drawn from ``phase_angle`` follow the convention:
        downwards = anti-phase, upwards = in-phase, right = ``data1`` leading
        ``data2``, left = ``data2`` leading ``data1``.

    Examples:
        >>> results = getcross_spectrum_Wavelet(signal, time, data1=signal1, data2=signal2)
        >>> (cross_power, cross_periods, cross_sig, cross_coi, coherence,
        ...  coh_periods, coh_sig, coh_coi, phase_angle) = results
    """

    # Define default values for optional parameters
    defaults = {
        'data1': None,               # First data array
        'data2': None,                # Second data array
        'mother': 'morlet',          # Type of mother wavelet
        'siglevel': 0.95,  # Significance level for cross-spectral analysis
        'dj': 0.025,                 # Spacing between scales
        's0': -1,  # Initial scale
        'J': -1,  # Number of scales
        'cache': False,               # Cache results to avoid recomputation
        'apod': 0.1,
        'silent': False,
        'pxdetrend': 2,
        'meandetrend': False,
        'polyfit': None,
        'meantemporal': False,
        'recon': False,
        'resample_original': False,
        'nperm': 1000               # Number of permutations for significance testing
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    data1 = params['data1']
    data2 = params['data2']
    if data1 is None or data2 is None:
        raise ValueError("getcross_spectrum_Wavelet requires both 'data1' and 'data2'.")
    verbose = not params['silent']

    cadence = np.median(np.diff(time))

    # Keep untouched copies: the permutation test shuffles the original samples.
    data1orig = np.asarray(data1).copy()
    data2orig = np.asarray(data2).copy()

    # Detrending/apodization settings shared by the data and every permutation.
    detrend_kwargs = {
        'apod': params['apod'],
        'meandetrend': params['meandetrend'],
        'pxdetrend': params['pxdetrend'],
        'polyfit': params['polyfit'],
        'meantemporal': params['meantemporal'],
        'recon': params['recon'],
        'cadence': cadence,
        'resample_original': params['resample_original'],
        'silent': True,
    }

    def _prepare(series):
        """Detrend and apodize one series with the shared settings."""
        return WaLSA_detrend_apod(series, **detrend_kwargs)

    data1 = _prepare(data1orig)
    data2 = _prepare(data2orig)

    # Calculate signal properties
    nt = len(data1orig)  # Number of time points

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * cadence

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(nt) * cadence / params['s0']) / np.log(2)) / params['dj'])

    # Scale settings shared by every xwt/wct call.
    wavelet_kwargs = {
        'dj': params['dj'],
        's0': params['s0'],
        'J': params['J'],
        'wavelet': params['mother'],
        'normalize': False,
    }
    nperm = params['nperm']

    def _permuted_pair():
        """Randomly permute both original series and prepare them like the data."""
        perm1 = _prepare(np.random.permutation(data1orig))
        perm2 = _prepare(np.random.permutation(data2orig))
        return perm1, perm2

    def _report_progress(ip):
        """Print the permutation progress on a single line unless silent."""
        if verbose:
            progress = ((ip + 1) / nperm) * 100
            print(f"\rProgress: {progress:.2f}%", end="")

    # ----------- CROSS-WAVELET TRANSFORM (XWT) -----------
    W12, cross_coi, freq, _ = xwt(data1, data2, cadence, no_default_signif=True, **wavelet_kwargs)
    if verbose:
        print('Wavelet cross-power spectrum calculated.')
    cross_power = np.abs(W12) ** 2
    cross_periods = 1 / freq  # Periods corresponding to the frequency axis

    # Monte Carlo significance of the cross-power (randomization method)
    cross_perm = np.zeros(cross_power.shape + (nperm,))
    if verbose:
        print('\nCalculating wavelet cross-power significance:')
    for ip in range(nperm):
        _report_progress(ip)
        perm1, perm2 = _permuted_pair()
        W12s, _, _, _ = xwt(perm1, perm2, cadence, no_default_signif=True, **wavelet_kwargs)
        cross_perm[:, :, ip] = np.abs(W12s) ** 2

    signifn = WaLSA_wavelet_confidence(cross_perm, siglevel=params['siglevel'])
    cross_sig = cross_power / signifn

    # ----------- WAVELET COHERENCE TRANSFORM (WCT) -----------
    WCT, aWCT, coh_coi, freq, _ = wct(
        data1, data2, cadence, sig=False, cache=params['cache'], **wavelet_kwargs
    )
    if verbose:
        print('Wavelet coherence calculated.')
    coh_periods = 1 / freq  # Periods corresponding to the frequency axis
    coherence = WCT

    # Monte Carlo significance of the coherence (randomization method)
    coh_perm = np.zeros(coherence.shape + (nperm,))
    if verbose:
        print('\nCalculating wavelet coherence significance:')
    for ip in range(nperm):
        _report_progress(ip)
        perm1, perm2 = _permuted_pair()
        WCTs, _, _, _, _ = wct(
            perm1, perm2, cadence, sig=False, cache=params['cache'], **wavelet_kwargs
        )
        coh_perm[:, :, ip] = WCTs

    sig_coh = WaLSA_wavelet_confidence(coh_perm, siglevel=params['siglevel'])
    coh_sig = coherence / sig_coh  # Ratio > 1 means coherence is significant

    # --------------- PHASE ANGLES ---------------
    phase_angle = aWCT

    return (
        cross_power, cross_periods, cross_sig, cross_coi,
        coherence, coh_periods, coh_sig, coh_coi,
        phase_angle
    )

WaLSA_detrend_apod.py

This module provides functions for detrending and apodizing time series data to mitigate trends and edge effects.

WaLSA_detrend_apod.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from scipy.optimize import curve_fit # type: ignore
from .WaLSA_wavelet import cwt, significance # type: ignore

# -------------------------------------- Wavelet ----------------------------------------
def getWavelet(signal, time, **kwargs):
    """
    Compute the continuous wavelet transform and wavelet power of a 1D signal.

    The signal is divided by its standard deviation and then transformed with
    ``cwt`` (imported from ``.WaLSA_wavelet``). No detrending, apodization or
    significance testing is performed here.

    Parameters:
        signal (array): The input signal (1D).
        time (array): The time array corresponding to the signal. The cadence
            (dt) is the median of its differences.
        mother (str): The mother wavelet function to use. Default: 'morlet'.
        dj (float): Scale spacing. Smaller values result in better scale
            resolution but slower calculations. Default: 1/32.
        s0 (float): Initial (smallest) scale of the wavelet. Default (-1): 2 * dt.
        J (int): Number of scales minus one. Default (-1):
            int(log2(N * dt / s0) / dj).
        **kwargs: Other keywords (e.g. siglevel, lag1, nperm, silent, method,
            apod) are accepted for interface compatibility but are not used
            by this function.

    Returns:
        power (np.ndarray): Wavelet power ``abs(W)**2`` of the standardized signal.
        periods (np.ndarray): Periods (1 / frequencies) of the wavelet scales.
        coi (np.ndarray): The cone of influence returned by ``cwt``.
        scales (np.ndarray): The wavelet scales returned by ``cwt``.
        W (np.ndarray): The wavelet transform coefficients returned by ``cwt``.
    """
    # Define default values for the optional parameters similar to IDL
    defaults = {
        'siglevel': 0.95,
        'mother': 'morlet',  # Morlet wavelet as the mother function
        'dj': 1/32. ,  # Scale spacing
        's0': -1,  # Initial scale
        'J': -1,  # Number of scales
        'lag1': 0.0,  # Lag-1 autocorrelation
        'silent': False,
        'nperm': 1000,   # Number of permutations for significance calculation
    }

    # Update defaults with any user-provided keyword arguments
    params = {**defaults, **kwargs}

    signal = np.asarray(signal)
    dt = np.median(np.diff(time))
    n = len(signal)

    # Standardize the signal (divide by its standard deviation) before the transform
    norm_signal = signal / signal.std()

    # Determine the initial scale s0 if not provided
    if params['s0'] == -1:
        params['s0'] = 2 * dt

    # Determine the number of scales J if not provided
    if params['J'] == -1:
        params['J'] = int((np.log(float(n) * dt / params['s0']) / np.log(2)) / params['dj'])

    # Perform wavelet transform
    W, scales, frequencies, coi, _, _ = cwt(
        norm_signal,
        dt,
        dj=params['dj'],
        s0=params['s0'],
        J=params['J'],
        wavelet=params['mother']
    )

    power = np.abs(W) ** 2  # Wavelet power spectrum
    periods = 1 / frequencies  # Convert frequencies to periods

    return power, periods, coi, scales, W

# -----------------------------------------------------------------------------------------------------

# Linear detrending function for curve fitting
def linear(x, a, b):
    """Return the straight line ``a + b * x`` (intercept ``a``, slope ``b``); used as a curve_fit model."""
    intercept, slope = a, b
    return intercept + slope * x

# Custom Tukey window implementation
def custom_tukey(nt, apod=0.1):
    """
    Build a Tukey-like window of length ``nt`` with sine-squared tapers.

    The first ``apodrim = int(apod * nt)`` samples rise as
    ``sin(pi/2 * k / apodrim)**2`` (k = 0 .. apodrim-1), the last ``apodrim``
    samples are the mirror image, and all other samples are 1.

    Parameters:
        nt (int): Window length.
        apod (float): Fraction of ``nt`` tapered at each end. Default: 0.1.

    Returns:
        np.ndarray: The window (float, length ``nt``).

    Notes:
        If ``int(apod * nt)`` is 0 or negative (e.g. ``apod=0`` or a short
        series), no taper fits and a window of ones is returned. The rim is
        capped at ``nt``. For ``apod > 0.5`` the two tapers overlap and the end
        taper overwrites the start taper.
    """
    apodrim = min(int(apod * nt), nt)
    apodt = np.ones(nt)  # Initialize with ones

    if apodrim <= 0:
        # Without this guard apodt[-0:] selected the whole array and the
        # assignment of the empty taper raised a broadcast error.
        return apodt

    # Sine-squared taper for the first 'apodrim' points
    taper = np.sin(np.pi / 2. * np.arange(apodrim) / apodrim) ** 2

    # Apply taper symmetrically at both ends
    apodt[:apodrim] = taper
    apodt[nt - apodrim:] = taper[::-1]  # Reverse taper for the last points

    return apodt

# Main detrending and apodization function
# apod=0: The Tukey window becomes a rectangular window (no tapering).
# apod=1: The Tukey window becomes a Hann window (fully tapered with a cosine function).
def WaLSA_detrend_apod(cube, apod=0.1, meandetrend=False, pxdetrend=2, polyfit=None, meantemporal=False,
                       recon=False, cadence=None, resample_original=False, min_resample=None, 
                       max_resample=None, silent=False, dj=32, lo_cutoff=None, hi_cutoff=None, upper=False):
    """
    Detrend and apodize a 1D time series.

    The steps are applied in this order:

    1. The mean of the input is subtracted.
    2. If ``apod > 0``, the series is multiplied by ``custom_tukey(nt, apod)``
       (apod=0 means no tapering).
    3. If ``meandetrend``, a straight line in time is fitted and subtracted.
    4. If ``recon`` and ``cadence`` are set, the series is replaced by
       ``WaLSA_wave_recon(series, cadence, dj, lo_cutoff, hi_cutoff, upper)``.
    5. If ``pxdetrend > 0``: with ``meantemporal`` the mean is subtracted;
       otherwise a polynomial of degree ``polyfit`` (if given) or a straight
       line (fitted with ``curve_fit``) is subtracted.
    6. If ``resample_original``, the series is linearly rescaled onto
       ``[min_resample, max_resample]``.

    Parameters:
        cube (array_like): The 1D input time series (not modified).
        apod (float): Fraction of the series tapered at each end. Default: 0.1.
        meandetrend (bool): Apply step 3. Default: False.
        pxdetrend (int): If > 0, apply step 5 (1 and 2 are treated the same here). Default: 2.
        polyfit (int): Polynomial degree for step 5; None means a linear fit. Default: None.
        meantemporal (bool): In step 5, only subtract the mean. Default: False.
        recon (bool): Apply step 4. Default: False.
        cadence (float): Sampling interval, required for ``recon``. Default: None.
        resample_original (bool): Apply step 6. Default: False.
        min_resample, max_resample (float): Target range for step 6; default to
            the min/max of the processed series.
        silent (bool): If False, print a completion message. Default: False.
        dj, lo_cutoff, hi_cutoff, upper: Passed to ``WaLSA_wave_recon``.

    Returns:
        np.ndarray: The processed series, same length as ``cube``.
    """
    cube = np.asarray(cube)  # Accept lists and other array-likes
    nt = len(cube)  # Assume input is a 1D signal
    t = np.arange(nt)  # Time (sample index) array

    # Remove the mean; this creates a new array, so the caller's input is never modified
    apocube = cube - np.mean(cube)

    # Apply Tukey window (apodization)
    if apod > 0:
        apocube = apocube * custom_tukey(nt, apod)

    # Mean detrend (optional)
    if meandetrend:
        # For a 1D series the per-time-step "image mean" is the sample itself, so
        # the linear trend in time is fitted to the series. (Fitting the scalar
        # np.mean(apocube) made curve_fit raise: 2 parameters, 1 data point.)
        mean_fit_params, _ = curve_fit(linear, t, apocube)
        apocube -= linear(t, *mean_fit_params)

    # Wavelet-based Fourier reconstruction (optional)
    if recon and cadence:
        apocube = WaLSA_wave_recon(apocube, cadence, dj=dj, lo_cutoff=lo_cutoff, hi_cutoff=hi_cutoff, upper=upper)

    # Pixel-based detrending (temporal detrend)
    if pxdetrend > 0:
        mean_val = np.mean(apocube)
        if meantemporal:
            # Simple temporal detrend by subtracting the mean
            apocube -= mean_val
        else:
            # Advanced detrend (linear or polynomial fit)
            if polyfit is not None:
                trend = np.polyval(np.polyfit(t, apocube, polyfit), t)
            else:
                popt, _ = curve_fit(linear, t, apocube, p0=[mean_val, 0])
                trend = linear(t, *popt)
            apocube -= trend

    # Resampling to preserve amplitudes (optional)
    if resample_original:
        # NOTE(review): when min_resample/max_resample are not given they default to
        # the min/max of the already-processed series, which makes this rescaling an
        # identity mapping; confirm whether the original input range was intended.
        if min_resample is None:
            min_resample = np.min(apocube)
        if max_resample is None:
            max_resample = np.max(apocube)
        apocube = np.interp(apocube, (np.min(apocube), np.max(apocube)), (min_resample, max_resample))

    if not silent:
        print("Detrending and apodization complete.")

    return apocube

# Wavelet-based reconstruction function (optional)
def WaLSA_wave_recon(ts, delt, dj=32, lo_cutoff=None, hi_cutoff=None, upper=False):
    """
    Reconstruct a band-limited version of a time series from its wavelet transform.

    The series is transformed with ``getWavelet`` (Morlet mother wavelet).
    Coefficients whose period is not smaller than the cone of influence at
    that time are set to zero (they are subject to edge effects). The real
    parts of the remaining coefficients are divided by sqrt(scale), summed
    over the selected scales, and normalized.

    Parameters:
        ts (array): 1D time series.
        delt (float): Sampling interval (cadence) of ``ts``.
        dj (float): Factor multiplying the reconstruction sum. Default: 32.
        lo_cutoff (float): Lower period cutoff. Default: 0.
        hi_cutoff (float): Upper period cutoff. Default: dur / (3 * sqrt(2)),
            where dur = (len(ts) - 1) * delt.
        upper (bool): If True, sum the scales from the first period above
            ``hi_cutoff`` to the last scale. Otherwise sum from the first period
            above ``lo_cutoff`` up to (excluding) the first period above
            ``hi_cutoff``.

    Returns:
        np.ndarray: The reconstructed series (same length as ``ts``).

    If no period exceeds a cutoff, the corresponding index is the number of
    scales (an empty or open-ended selection) instead of raising IndexError.
    """
    ts = np.asarray(ts)
    num_points = len(ts)

    # Define duration based on the time series length
    dur = (num_points - 1) * delt

    # Assign default values if lo_cutoff or hi_cutoff is None
    if lo_cutoff is None:
        lo_cutoff = 0.0  # Default to 0 as in IDL
    if hi_cutoff is None:
        hi_cutoff = dur / (3.0 * np.sqrt(2))

    mother = 'morlet'
    time_array = np.linspace(0, (num_points - 1) * delt, num_points)

    _, period, coi, scales, wave = getWavelet(signal=ts, time=time_array, method='wavelet', siglevel=0.99, apod=0.1, mother=mother)

    def _first_index_above(threshold):
        """Index of the first period > threshold, or len(period) if there is none."""
        above = np.flatnonzero(period > threshold)
        return above[0] if above.size else len(period)

    # Select the range of scale indices to sum over
    if upper:
        good_per_idx = _first_index_above(hi_cutoff)
        bad_per_idx = len(period)
    else:
        good_per_idx = _first_index_above(lo_cutoff)
        bad_per_idx = _first_index_above(hi_cutoff)

    # Keep the real part of each coefficient whose period is below the CoI at that
    # time and zero the rest (edge-affected). iampl is indexed (time, scale), i.e.
    # the transpose of wave's (scale, time) layout; vectorized form of a per-time loop.
    coef = np.real(wave[:len(period), :num_points]).T
    inside = period[np.newaxis, :] < np.asarray(coi)[:num_points, np.newaxis]
    iampl = np.where(inside, coef, 0.0)

    # Summation over valid period indices
    recon_sum = np.zeros(num_points)
    for i in range(good_per_idx, bad_per_idx):
        recon_sum += iampl[:, i] / np.sqrt(scales[i])

    # Apply normalization factor.
    # NOTE(review): Torrence & Compo (1998, eq. 11) use the scale spacing
    # (getWavelet uses 1/32) in the numerator and C_delta = 0.776 for the Morlet
    # wavelet; here dj=32 and 0.766 are used. Confirm against the IDL original
    # before changing, as downstream results depend on this scaling.
    recon_all = dj * np.sqrt(delt) * recon_sum / (0.766 * (np.pi ** -0.25))

    return recon_all

WaLSA_confidence.py

This module implements statistical significance testing for the spectral analysis results using various methods.

WaLSA_confidence.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore

def WaLSA_confidence(ps_perm, siglevel=0.05, nf=None):
    """
    Find the confidence levels (significance levels) for the given power spectrum permutations.

    For each frequency the permutation values are sorted, and the value with
    ``nsig = ROUND(siglevel * nperm)`` values at or above it is returned,
    i.e. the nsig-th largest value. ``nsig`` is rounded half away from zero,
    as in IDL's ROUND, and clamped to 1..nperm.

    Parameters:
    -----------
    ps_perm : np.ndarray
        2D array where each row represents a frequency and each column a permutation.
    siglevel : float
        Fraction of permutations above the level (default 0.05 for 95% confidence).
    nf : int
        Number of frequencies (rows) to evaluate. If None, inferred from the shape of ps_perm.

    Returns:
    --------
    signif : np.ndarray
        Significance levels for each frequency (float, length nf).

    Raises:
    -------
    IndexError
        If ``nf`` exceeds the number of rows of ``ps_perm``.
    """
    ps_perm = np.asarray(ps_perm)
    if nf is None:
        nf = ps_perm.shape[0]  # Number of frequencies inferred from the shape of ps_perm
    if nf > ps_perm.shape[0]:
        raise IndexError(f"nf={nf} exceeds the number of rows in ps_perm ({ps_perm.shape[0]}).")

    nperm = ps_perm.shape[1]

    # IDL-style ROUND (half away from zero); Python's round() rounds half to even.
    nsig = int(np.floor(siglevel * nperm + 0.5))
    # Keep the index valid: nsig == 0 used to give tmp[-0] == tmp[0], the minimum.
    nsig = min(max(nsig, 1), nperm)

    sorted_perm = np.sort(ps_perm[:nf, :], axis=1)
    return sorted_perm[:, nperm - nsig].astype(float)

WaLSA_wavelet_confidence.py

This module implements statistical significance testing for the wavelet analysis results.

WaLSA_wavelet_confidence.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore

def WaLSA_wavelet_confidence(ps_perm, siglevel=None):
    """
    Find the confidence levels (significance levels) for the given wavelet power spectrum permutations.

    Parameters:
    -----------
    ps_perm : np.ndarray
        3D array of shape (nf, nt, nperm): one wavelet spectrum (frequency/period x time)
        per permutation along the last axis.
    siglevel : float
        Confidence level, e.g. 0.95 for a 95% confidence level (significance at 5%). Required.

    Returns:
    --------
    signif : np.ndarray
        2D float array of shape (nf, nt) holding the ``100 * siglevel`` percentile of the
        permutations at each point (NumPy's default linear interpolation).

    Raises:
    -------
    ValueError
        If ``siglevel`` is None or ``ps_perm`` is not 3D.
    """
    if siglevel is None:
        raise ValueError("siglevel must be given, e.g. 0.95 for a 95% confidence level.")
    ps_perm = np.asarray(ps_perm)
    if ps_perm.ndim != 3:
        raise ValueError(f"ps_perm must be a 3D array (nf, nt, nperm); got shape {ps_perm.shape}.")

    # One vectorized call instead of nf * nt separate np.percentile calls.
    return np.percentile(ps_perm, 100 * siglevel, axis=2).astype(float)

WaLSA_io.py

This module provides functions for input/output operations, such as saving figures as PDF files (in either RGB or CMYK format), and for enhancing image contrast.

WaLSA_io.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import subprocess
import shutil
import os
import stat
import numpy as np # type: ignore
from matplotlib.backends.backend_pdf import PdfPages # type: ignore

def _print_ghostscript_help():
    """Print the warning shown when CMYK output is requested but Ghostscript is missing."""
    print("Warning: Ghostscript is not installed, so the PDF remains in RGB format.")
    print("To enable CMYK saving, please install Ghostscript:")
    print("- On macOS, use: brew install ghostscript")
    print("- On Ubuntu, use: sudo apt install ghostscript")
    print("- On Windows, download from: https://www.ghostscript.com/download.html")


def WaLSA_save_pdf(fig, pdf_path, color_mode='RGB', dpi=300, bbox_inches=None, pad_inches=0):
    """
    Save a PDF from a Matplotlib figure with an option to convert to CMYK if Ghostscript is available.

    Parameters:
    - fig: Matplotlib figure object to save.
    - pdf_path: Path (str or path-like) of the output PDF file.
    - color_mode: 'RGB' (default) to save in RGB, or 'CMYK' to convert to CMYK
      with the Ghostscript executable ``gs`` (if found on the PATH).
    - dpi: Resolution in dots per inch. Default is 300. Only used when
      bbox_inches is given (the PdfPages path does not receive it).
    - bbox_inches: Set to 'tight' to remove extra white space. Default is None,
      in which case the figure is written with PdfPages without cropping.
    - pad_inches: Padding around the figure when bbox_inches is given. Default is 0.

    If CMYK is requested but Ghostscript is not installed, a warning is printed
    and the RGB PDF is kept at pdf_path. If Ghostscript fails, the partial CMYK
    file is removed, the RGB PDF is kept at pdf_path, and the error is re-raised.
    """
    pdf_path = os.fspath(pdf_path)

    # Derive temporary names from the extension only. str.replace('.pdf', ...) also
    # rewrote '.pdf' elsewhere in the path and, for paths without '.pdf', returned
    # pdf_path itself (Ghostscript would then read and write the same file).
    root, ext = os.path.splitext(pdf_path)
    ext = ext or '.pdf'
    rgb_temp_path = root + '_temp' + ext
    cmyk_temp_path = root + '_temp_cmyk' + ext

    if bbox_inches is None:
        # Save the RGB PDF directly at the requested path
        with PdfPages(pdf_path) as pdf:
            pdf.savefig(fig, transparent=True)
        rgb_path = pdf_path
    else:
        # fig.savefig applies bbox_inches, pad_inches and dpi; write to a temporary file first
        fig.savefig(
            rgb_temp_path,
            format='pdf',
            dpi=dpi,
            bbox_inches=bbox_inches,
            pad_inches=pad_inches,
            transparent=True
        )
        rgb_path = rgb_temp_path

    def _keep_rgb():
        """Move the RGB file to pdf_path if it was written under a temporary name."""
        if rgb_path != pdf_path:
            os.replace(rgb_path, pdf_path)  # os.replace also overwrites on Windows

    if color_mode.upper() != 'CMYK':
        _keep_rgb()
        print(f"PDF saved in RGB format as '{pdf_path}'")
        return

    if shutil.which("gs") is None:
        _print_ghostscript_help()
        # Previously, with bbox_inches set, the RGB file stayed under the temporary name.
        _keep_rgb()
        return

    # Set permissions on the RGB PDF to ensure access for Ghostscript
    os.chmod(rgb_path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH)

    try:
        # Run Ghostscript to create the CMYK PDF as a temporary file
        subprocess.run([
            "gs", "-o", cmyk_temp_path,
            "-sDEVICE=pdfwrite",
            "-dProcessColorModel=/DeviceCMYK",
            "-dColorConversionStrategy=/CMYK",
            "-dNOPAUSE", "-dBATCH", "-dSAFER",
            "-dPDFSETTINGS=/prepress", rgb_path
        ], check=True)
    except (subprocess.CalledProcessError, OSError):
        # Do not leave a half-written CMYK file behind; keep the RGB result.
        if os.path.exists(cmyk_temp_path):
            os.remove(cmyk_temp_path)
        _keep_rgb()
        raise

    # Replace the RGB PDF with the CMYK version
    if rgb_path != pdf_path:
        os.remove(rgb_path)
    os.replace(cmyk_temp_path, pdf_path)
    print(f"PDF saved in CMYK format as '{pdf_path}'")


def WaLSA_histo_opt(image, cutoff=1e-3, top_only=False, bot_only=False):
    """
    Clip image values based on the cutoff percentile to enhance contrast.
    Inspired by IDL's "iris_histo_opt" function (Copyright: P.Suetterlin, V.Hansteen, and M. Carlsson)

    Parameters:
    - image: array-like (typically a 2D image; any shape is accepted). Non-finite
      values (NaN, +/-inf) are ignored when computing the percentile bounds.
    - cutoff: Fraction of values to clip at the top and bottom (default is 0.001).
    - top_only: If True, clip only the highest values (takes precedence over bot_only).
    - bot_only: If True, clip only the lowest values.

    Returns:
    - Clipped copy of the image (NumPy array) for better contrast. NaNs are left
      as NaN; infinities are clipped to the computed bounds. If the image has no
      finite values at all, an unmodified copy is returned.
    """
    # Accept lists/tuples as well as ndarrays (boolean-mask indexing needs an array)
    image = np.asarray(image)

    # Ignore NaNs/infs when computing the bounds
    finite_values = image[np.isfinite(image)]
    if finite_values.size == 0:
        # Percentiles of an empty set are undefined; there is nothing to clip against.
        return image.copy()

    # Calculate lower and upper bounds based on cutoff percentiles (single sort for both)
    lower_bound, upper_bound = np.percentile(
        finite_values, [cutoff * 100, (1 - cutoff) * 100]
    )

    # Clip image according to bounds and options
    if top_only:
        return np.clip(image, None, upper_bound)
    elif bot_only:
        return np.clip(image, lower_bound, None)
    else:
        return np.clip(image, lower_bound, upper_bound)

WaLSA_interactive.py

This module implements the interactive interface of WaLSAtools, guiding users through the analysis process.

WaLSA_interactive.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------

import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output, HTML # type: ignore
import os # type: ignore

# Function to detect if it's running in a notebook or terminal
def is_notebook():
    """Report whether code is executing inside a Jupyter kernel.

    Returns True only when the active IPython shell class is
    ``ZMQInteractiveShell`` (Jupyter notebook or qtconsole). Any other shell
    (e.g. the IPython terminal), or an environment where IPython cannot be
    imported, yields False.
    """
    try:
        from IPython import get_ipython  # type: ignore
        shell_name = get_ipython().__class__.__name__
    except (NameError, ImportError):
        # IPython unavailable: a standard Python interpreter or shell
        return False
    return shell_name == 'ZMQInteractiveShell'

# Function to print logo and credits for both environments
def print_logo_and_credits():
    """Show the WaLSAtools logo and credits in the current front-end.

    In a Jupyter notebook (``is_notebook()`` is True) the PNG logo at
    ``../assets/WaLSAtools_black.png`` (relative to this module's directory, or to
    the working directory when ``__file__`` is undefined) is embedded as a Base64
    data URI and shown together with HTML credits. If that image cannot be read,
    the ASCII-art logo is shown in the notebook instead of raising.

    Outside a notebook, the ASCII-art logo and plain-text credits are printed.
    """
    logo_terminal = r"""
        __          __          _          _____            
        \ \        / /         | |        / ____|     /\    
         \ \  /\  / /  ▄▄▄▄▄   | |       | (___      /  \   
          \ \/  \/ /   ▀▀▀▀██  | |        \___ \    / /\ \  
           \  /\  /   ▄██▀▀██  | |____    ____) |  / ____ \ 
            \/  \/    ▀██▄▄██  |______|  |_____/  /_/    \_\
    """
    credits_terminal = """
        © WaLSA Team (www.WaLSA.team)
        -----------------------------------------------------------------------
        WaLSAtools v1.0.0 - Wave analysis tools
        Documentation: www.WaLSA.tools
        GitHub repository: www.github.com/WaLSAteam/WaLSAtools
        -----------------------------------------------------------------------
        If you use WaLSAtools in your research, please cite:   
        Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21 

        Free access to a view-only version: https://WaLSA.tools/nrmp     
        Supplementary Information: https://WaLSA.tools/nrmp-si     
        -----------------------------------------------------------------------
        Choose a category, data type, and analysis method from the list below,
        to get hints on the calling sequence and parameters:
        """

    credits_notebook = """
        <div style="margin-left: 30px; margin-top: 20px; font-size: 1.1em; line-height: 0.8;">
            <p>© WaLSA Team (<a href="https://www.WaLSA.team" target="_blank">www.WaLSA.team</a>)</p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
            <p><strong>WaLSAtools</strong> v1.0.0 - Wave analysis tools</p>
            <p>Documentation: <a href="https://www.WaLSA.tools" target="_blank">www.WaLSA.tools</a></p>
            <p>GitHub repository: <a href="https://www.github.com/WaLSAteam/WaLSAtools" target="_blank">www.github.com/WaLSAteam/WaLSAtools</a></p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 10px;">
            <p>If you use <strong>WaLSAtools</strong> in your research, please cite:</p>
            <p>Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, <em>Nature Reviews Methods Primers</em>, 5, 21</p>
            <p>Free access to a view-only version: <a href="https://WaLSA.tools/nrmp" target="_blank">WaLSA.tools/nrmp</a></p>
            <p>Supplementary Information: <a href="https://WaLSA.tools/nrmp-si" target="_blank">WaLSA.tools/nrmp-si</a></p>
            <hr style="width: 70%; margin: 0; border: 0.98px solid #888; margin-bottom: 15px;">
            <p>Choose a category, data type, and analysis method from the list below,</p>
            <p>to get hints on the calling sequence and parameters:</p>
        </div>
    """

    if not is_notebook():
        print(logo_terminal)
        print(credits_terminal)
        return

    import base64
    import html

    try:
        # Normal case: this code lives in an imported module
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # __file__ is undefined when the code is executed directly in a notebook cell
        current_dir = os.getcwd()
    img_path = os.path.join(current_dir, '..', 'assets', 'WaLSAtools_black.png')

    # Embed the image as a Base64 data URI instead of referencing img_path in <img src>,
    # since a plain local file path is not reliably rendered by Jupyter front-ends.
    try:
        with open(img_path, "rb") as img_file:
            encoded_img = base64.b64encode(img_file.read()).decode('utf-8')
    except OSError:
        # Missing/unreadable asset must not abort the interactive helper.
        encoded_img = None

    if encoded_img is not None:
        html_code_logo = f"""
        <div style="margin-left: 30px; margin-top: 20px;">
            <img src="data:image/png;base64,{encoded_img}" style="width: 300px; height: auto;">
        </div>
        """
    else:
        html_code_logo = f"""
        <pre style="margin-left: 30px; margin-top: 20px;">{html.escape(logo_terminal)}</pre>
        """
    display(HTML(html_code_logo))
    display(HTML(credits_notebook))

from .parameter_definitions import display_parameters_text, single_series_parameters, cross_correlation_parameters

# Terminal-based interactive function
import textwrap
import shutil
def _print_calling_sequence(command, line_width):
    """Print the 'Calling sequence' header followed by `command` wrapped to `line_width`."""
    print("\n    Calling sequence:\n")
    # TextWrapper rejects widths <= 0, which very narrow terminals could otherwise produce.
    wrapper = textwrap.TextWrapper(
        width=max(line_width - 4, 20), initial_indent='    ', subsequent_indent='        '
    )
    print(wrapper.fill(command))


def _choose(title, choices, prompt, error):
    """Show a menu repeatedly until a valid key is entered; return the mapped value.

    Parameters:
    - title: Menu heading (printed as "\\n    <title>:").
    - choices: Ordered dict {key: (label, value)}; keys must be lowercase.
    - prompt: Input prompt text (printed as "    --- <prompt>: ").
    - error: Message appended to "    Invalid selection. " on bad input.

    Returns:
    - The `value` paired with the chosen key.
    """
    while True:
        print(f"\n    {title}:")
        for key, (label, _) in choices.items():
            print(f"    ({key}) {label}")
        answer = input(f"    --- {prompt}: ").strip().lower()
        if answer in choices:
            return choices[answer][1]
        print(f"    Invalid selection. {error}\n")


def walsatools_terminal():
    """Main interactive function for terminal version of WaLSAtools.

    Prints the logo/credits, then walks the user through category, data type and
    analysis method menus (re-prompting on invalid input). Finally prints the
    matching WaLSAtools(...) calling sequence, wrapped to ~90% of the terminal
    width, and the parameter hints from `display_parameters_text`.
    """
    print_logo_and_credits()

    # Wrap the calling sequence to 90% of the terminal width
    line_width = int(shutil.get_terminal_size().columns * 0.90)

    # Step 1: Select Category ('a' = single series, 'b' = cross-correlation)
    category = _choose(
        "Category",
        {
            'a': ('Single time series analysis', 'a'),
            'b': ('Cross-correlation between two time series', 'b'),
        },
        "Select a category (a/b)",
        "Please enter either 'a' or 'b'.",
    )

    if category == 'a':
        # Step 2: Data Type
        data_type = _choose(
            "Data Type",
            {'1': ('1D signal', '1d'), '2': ('3D datacube', '3d')},
            "Select a data type (1/2)",
            "Please enter either '1' or '2'.",
        )

        # Step 3: Analysis Method
        if data_type == '1d':
            selected_method = _choose(
                "Analysis Method",
                {
                    '1': ('FFT', 'fft'),
                    '2': ('Wavelet', 'wavelet'),
                    '3': ('Lomb-Scargle', 'lombscargle'),
                    '4': ('Welch', 'welch'),
                    '5': ('EMD', 'emd'),
                },
                "Select an analysis method (1-5)",
                "Please select a valid analysis method (1-5).",
            )
            return_values = single_series_parameters[selected_method]['return_values']
            command = f">>> {return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
        else:
            selected_method = _choose(
                "Analysis Method",
                {
                    '1': ('k-omega', 'k-omega'),
                    '2': ('POD', 'pod'),
                    '3': ('Dominant Freq / Mean Power Spectrum', 'dominant'),
                },
                "Select an analysis method (1-3)",
                "Please select a valid analysis method (1-3).",
            )
            if selected_method == 'dominant':
                # Dominant frequency / mean power spectrum needs a spectral sub-method
                selected_method = _choose(
                    "Analysis method for Dominant Freq / Mean Power Spectrum",
                    {
                        '1': ('FFT', 'fft'),
                        '2': ('Wavelet', 'wavelet'),
                        '3': ('Lomb-Scargle', 'lombscargle'),
                        '4': ('Welch', 'welch'),
                    },
                    "Select a method (1-4)",
                    "Please select a valid sub-method (1-4).",
                )
                return_values = 'dominant_frequency, mean_power, frequency, power_map'
                command = f">>> {return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"
            else:
                return_values = single_series_parameters[selected_method]['return_values']
                command = f">>> {return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"
    else:
        # Cross-correlation currently supports 1D signals only
        _choose(
            "Data Type",
            {'1': ('1D signal', '1d')},
            "Select a data type (1)",
            "Please enter '1'.",
        )
        selected_method = _choose(
            "Analysis Method",
            {
                '1': ('FFT', 'fft'),
                '2': ('Wavelet', 'wavelet'),
                '3': ('Welch', 'welch'),
            },
            "Select an analysis method (1-3)",
            "Please select a valid analysis method (1-3).",
        )
        return_values = cross_correlation_parameters[selected_method]['return_values']
        command = f">>> {return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

    # Generate and display the calling sequence, then the parameter hints
    _print_calling_sequence(command, line_width)
    display_parameters_text(selected_method, category=category)


# Jupyter-based interactive function
import ipywidgets as widgets # type: ignore
from IPython.display import display, clear_output, HTML # type: ignore
from .parameter_definitions import display_parameters_html, single_series_parameters, cross_correlation_parameters # type: ignore

# Global flag to prevent multiple observers.
# walsatools_jupyter() declares it `global`, resets it to False on every call,
# and attach_observers() sets it to True after registering the widget observers.
is_observer_attached = False

def walsatools_jupyter():
    """Main interactive function for Jupyter Notebook version of WaLSAtools.

    Clears the cell output, shows the logo/credits, then builds four dropdowns
    (Category, Data Type, Method, Sub-method) and an ``Output`` widget, wires
    value observers between them and displays them. When every required
    dropdown holds a real selection, the matching ``WaLSAtools(...)`` calling
    sequence is rendered as HTML in the output widget, followed by
    ``display_parameters_html`` for the selected method.

    The dropdown widgets are stored in module-level globals (``category``,
    ``method``, ``analysis_type``, ``sub_method``) and are recreated on every call.
    """
    global is_observer_attached, category, method, analysis_type, sub_method  # Declare global variables for reuse

    # Detach any existing observers and reset the flag
    # NOTE(review): `detach_observers` is a nested function defined further down in
    # this same function, so here it is still an unbound local; the call raises
    # UnboundLocalError (a NameError subclass) and this block is always skipped.
    try:
        detach_observers()
    except NameError:
        pass  # `detach_observers` hasn't been defined yet

    is_observer_attached = False  # Reset observer flag

    # Clear any previous output
    clear_output(wait=True)

    print_logo_and_credits()

    # Recreate widgets to reset state
    category = widgets.Dropdown(
        options=['Select Category', 'Single time series analysis', 'Cross-correlation between two time series'],
        value='Select Category',
        description='Category:'
    )
    method = widgets.Dropdown(
        options=['Select Data Type'],
        value='Select Data Type',
        description='Data Type:'
    )
    analysis_type = widgets.Dropdown(
        options=['Select Method'],
        value='Select Method',
        description='Method:'
    )
    sub_method = widgets.Dropdown(
        options=['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch'],
        value='Select Sub-method',
        description='Sub-method:',
        layout=widgets.Layout(display='none')  # Initially hidden
    )

    # Persistent output widget
    output = widgets.Output()

    def clear_output_if_unselected(change=None):
        """Clear the output widget if any dropdown menu is unselected.

        When a required dropdown still shows its placeholder (or the visible
        Sub-method dropdown is unselected for 'Dominant Freq / Mean Power
        Spectrum'), the output area is replaced by a short HTML prompt.
        """
        with output:
            if (
                category.value == 'Select Category'
                or method.value == 'Select Data Type'
                or analysis_type.value == 'Select Method'
                or (
                    analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
                    and sub_method.layout.display == 'block'
                    and sub_method.value == 'Select Sub-method'
                )
            ):
                clear_output(wait=True)
                warn = '<div style="font-size: 1.1em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Please select appropriate options from all dropdown menus.</div>'
                display(HTML(warn))

    def update_method_options(change=None):
        """Update available Method and Sub-method options.

        Hides and resets the Sub-method dropdown, then repopulates the Data Type
        and Method dropdowns according to the current Category / Data Type.
        """
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        sub_method.layout.display = 'none'
        sub_method.value = 'Select Sub-method'
        sub_method.options = ['Select Sub-method']

        if category.value == 'Single time series analysis':
            method.options = ['Select Data Type', '1D signal', '3D datacube']
            if method.value == '1D signal':
                analysis_type.options = ['Select Method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch', 'EMD']
            elif method.value == '3D datacube':
                analysis_type.options = ['Select Method', 'k-omega', 'POD', 'Dominant Freq / Mean Power Spectrum']
            else:
                analysis_type.options = ['Select Method']
        elif category.value == 'Cross-correlation between two time series':
            # Cross-correlation offers only 1D signals
            method.options = ['Select Data Type', '1D signal']
            if method.value == '1D signal':
                analysis_type.options = ['Select Method', 'FFT', 'Wavelet', 'Welch']
            else:
                analysis_type.options = ['Select Method']
        else:
            method.options = ['Select Data Type']
            analysis_type.options = ['Select Method']

    def update_sub_method_visibility(change=None):
        """Show or hide the Sub-method dropdown based on conditions.

        The Sub-method dropdown is shown only for single-series, 3D datacube,
        'Dominant Freq / Mean Power Spectrum'; otherwise it is reset and hidden.
        """
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        if (
            category.value == 'Single time series analysis'
            and method.value == '3D datacube'
            and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
        ):
            sub_method.options = ['Select Sub-method', 'FFT', 'Wavelet', 'Lomb-Scargle', 'Welch']
            sub_method.layout.display = 'block'
        else:
            sub_method.options = ['Select Sub-method']
            sub_method.value = 'Select Sub-method'
            sub_method.layout.display = 'none'

    def update_calling_sequence(change=None):
        """Update the function calling sequence based on user's selection.

        Does nothing until all required dropdowns are selected. Otherwise maps
        the dropdown labels to WaLSAtools method ids, builds the calling
        sequence string and renders it (plus parameter hints) in `output`.
        """
        clear_output_if_unselected()  # Ensure the output clears if any dropdown is unselected.
        if (
            category.value == 'Select Category'
            or method.value == 'Select Data Type'
            or analysis_type.value == 'Select Method'
            or (
                analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
                and sub_method.layout.display == 'block'
                and sub_method.value == 'Select Sub-method'
            )
        ):
            return  # Do nothing until all required fields are properly selected

        with output:
            clear_output(wait=True)

            # Handle Dominant Frequency / Mean Power Spectrum with Sub-method
            if (
                category.value == 'Single time series analysis'
                and method.value == '3D datacube'
                and analysis_type.value == 'Dominant Freq / Mean Power Spectrum'
            ):
                if sub_method.value == 'Select Sub-method':
                    print("Please select a Sub-method.")
                    return

                sub_method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch'}
                selected_method = sub_method_map.get(sub_method.value, 'unknown')
                return_values = 'dominant_frequency, mean_power, frequency, power_map'
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, averagedpower=True, dominantfreq=True, method='{selected_method}', **kwargs)"

            # Handle k-omega and POD
            elif (
                category.value == 'Single time series analysis'
                and method.value == '3D datacube'
                and analysis_type.value in ['k-omega', 'POD']
            ):
                method_map = {'k-omega': 'k-omega', 'POD': 'pod'}
                selected_method = method_map.get(analysis_type.value, 'unknown')
                parameter_definitions = single_series_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            # Handle Cross-correlation
            elif category.value == 'Cross-correlation between two time series':
                cross_correlation_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Welch': 'welch'}
                selected_method = cross_correlation_map.get(analysis_type.value, 'unknown')
                parameter_definitions = cross_correlation_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(data1=INPUT_DATA1, data2=INPUT_DATA2, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            # Handle Single 1D signal analysis
            elif category.value == 'Single time series analysis' and method.value == '1D signal':
                method_map = {'FFT': 'fft', 'Wavelet': 'wavelet', 'Lomb-Scargle': 'lombscargle', 'Welch': 'welch', 'EMD': 'emd'}
                selected_method = method_map.get(analysis_type.value, 'unknown')
                parameter_definitions = single_series_parameters
                return_values = parameter_definitions.get(selected_method, {}).get('return_values', '')
                command = f"{return_values} = WaLSAtools(signal=INPUT_DATA, time=TIME_ARRAY, method='{selected_method}', **kwargs)"

            else:
                print("Invalid configuration.")
                return

            # Generate and display the command in HTML
            # NOTE(review): `command` is interpolated into the HTML unescaped; it is
            # built only from fixed strings above, so no user text reaches it.
            html_code = f"""
            <div style="font-size: 1.2em; margin-left: 30px; margin-top:15px; margin-bottom: 15px;">Calling sequence:</div>
            <div style="display: flex; margin-left: 30px; margin-bottom: 3ch;">
                <span style="color: #222; min-width: 4ch;">>>> </span>
                <pre style="
                    white-space: pre-wrap; 
                    word-wrap: break-word;  
                    color: #01016D; 
                    margin: 0;
                ">{command}</pre>
            </div>
            """
            display(HTML(html_code))
            display_parameters_html(selected_method, category=category.value)

    def attach_observers():
        """Attach observers to the dropdown widgets and ensure no duplicates.

        Registers the callbacks on each widget's 'value' trait. The registration
        order is kept as written; callbacks on the same widget may rely on it
        (NOTE(review): confirm before reordering).
        """
        global is_observer_attached
        if not is_observer_attached:
            detach_observers()
            category.observe(clear_output_if_unselected, names='value')
            method.observe(clear_output_if_unselected, names='value')
            analysis_type.observe(clear_output_if_unselected, names='value')
            category.observe(update_method_options, names='value')
            method.observe(update_method_options, names='value')
            analysis_type.observe(update_sub_method_visibility, names='value')
            sub_method.observe(update_calling_sequence, names='value')
            analysis_type.observe(update_calling_sequence, names='value')
            is_observer_attached = True

    def detach_observers():
        """Detach all observers to prevent multiple triggers.

        NOTE(review): the `clear_output_if_unselected` observers are not removed
        here, and a ValueError from any single `unobserve` call skips the
        remaining calls in the try block.
        """
        try:
            category.unobserve(update_method_options, names='value')
            method.unobserve(update_method_options, names='value')
            analysis_type.unobserve(update_sub_method_visibility, names='value')
            sub_method.unobserve(update_calling_sequence, names='value')
            analysis_type.unobserve(update_calling_sequence, names='value')
        except ValueError:
            pass

    attach_observers()
    display(category, method, analysis_type, sub_method, output)


# Unified interactive function for both terminal and Jupyter environments
def interactive():
    """Start the WaLSAtools helper suited to the current environment.

    Uses the widget-based Jupyter interface when ``is_notebook()`` reports a
    Jupyter kernel; otherwise falls back to the text-menu terminal interface.
    """
    launch = walsatools_jupyter if is_notebook() else walsatools_terminal
    launch()

WaLSA_plot_k_omega.py

This module provides functions for plotting k-ω diagrams and filtered datacubes.

WaLSA_plot_k_omega.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------
# The following code is based on code originally written by David B. Jess and Samuel D. T. Grant
# -----------------------------------------------------------------------------------------------------

import numpy as np # type: ignore
from matplotlib.colors import ListedColormap # type: ignore
from WaLSAtools import WaLSA_histo_opt # type: ignore
from scipy.interpolate import griddata # type: ignore
import matplotlib.pyplot as plt # type: ignore
from matplotlib.ticker import FixedLocator, FixedFormatter # type: ignore
import matplotlib.patches as patches # type: ignore
import matplotlib.patheffects as path_effects # type: ignore
from matplotlib.colors import Normalize # type: ignore

def _walsa_range_indices(scale, lower, upper, axis_name):
    """Return the indices of ``scale`` needed to cover the interval [lower, upper].

    Every sample inside the closed interval is selected.  When an edge of the
    interval falls strictly between two samples, the neighbouring sample just
    outside that edge is added too, so the later interpolation can reach the
    exact edge.  The neighbour logic assumes ``scale`` is sorted ascending.

    Raises ValueError if no sample of ``scale`` lies inside the interval.
    """
    inside = np.where((scale >= lower) & (scale <= upper))[0]
    if inside.size == 0:
        raise ValueError(
            f"The requested {axis_name} range [{lower}, {upper}] contains no samples "
            f"of the {axis_name} scale (which spans {np.min(scale)} to {np.max(scale)})."
        )

    # Lower edge falls between samples: include the sample just below it.
    if scale[inside[0]] > lower and inside[0] > 0:
        inside = np.insert(inside, 0, inside[0] - 1)

    # Upper edge falls between samples: include the sample just above it.
    if scale[inside[-1]] < upper and inside[-1] + 1 < len(scale):
        inside = np.append(inside, inside[-1] + 1)

    return inside


def WaLSA_plot_k_omega(
    kopower, kopower_xscale, kopower_yscale, 
    xtitle='Wavenumber (pixel⁻¹)', ytitle='Frequency (mHz)',
    xlog=False, ylog=False, xrange=None, yrange=None,
    xtick_interval=0.05, ytick_interval=200,
    xtick_minor_interval=0.01, ytick_minor_interval=50,
    colorbar_label='Log₁₀(Power)', cmap='viridis', cbartab=0.12,
    figsize=(9, 4), smooth=True, dypix = 1200, dxpix = 1600, ax=None,
    k1=None, k2=None, f1=None, f2=None, colorbar_location='right'
):
    """
    Plot a k-omega (wavenumber-frequency) power diagram.

    The power array is clipped to ``xrange``/``yrange``. When a range edge
    falls between samples, one extra sample beyond that edge is kept. The
    clipped array is linearly interpolated (``scipy.interpolate.griddata``)
    onto a regular ``dypix`` x ``dxpix`` grid. Points outside the data become 0.
    The result is passed through ``WaLSA_histo_opt`` and drawn with ``imshow``.
    Secondary axes label the period (1000 / frequency) and the spatial size
    (2π / wavenumber). An optional dashed rectangle marks a filtered region.

    Parameters
    ----------
    kopower : 2D array-like
        Power with the y (frequency) scale along axis 0 and the x (wavenumber)
        scale along axis 1; the indexing ``kopower[:, x]`` / ``kopower[y, :]``
        relies on this layout.
    kopower_xscale, kopower_yscale : 1D array-like
        Coordinates of the columns and rows of ``kopower``. Assumed sorted in
        ascending order.
    xtitle, ytitle : str
        Axis labels.
    xlog, ylog : bool
        Use logarithmically spaced interpolation grids and log axis scales.
    xrange, yrange : sequence of two numbers, optional
        Plotted ranges. ``None`` (or anything whose length is not 2) selects
        the full scale. A lower bound of exactly 0 is replaced by the minimum
        of the corresponding scale. Tuples, lists and NumPy arrays are accepted.
    xtick_interval, ytick_interval : float
        Major tick spacing.
    xtick_minor_interval, ytick_minor_interval : float
        Minor tick spacing.
    colorbar_label : str
        Colour-bar label.
    cmap : str
        Currently unused. The colour table is always read from
        ``Color_Tables/idl_colormap_13.txt``. Kept for API compatibility.
    cbartab : float
        Padding between the axes and the colour bar.
    figsize : tuple
        Figure size, used only when ``ax`` is None.
    smooth : bool
        Currently unused; kept for API compatibility.
    dypix, dxpix : int
        Number of rows / columns of the interpolation grid.
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure is created when None.
    k1, k2, f1, f2 : float, optional
        Corners of the filtered region, drawn only if all four are given.
    colorbar_location : str
        ``'top'`` gives a horizontal colour bar above the plot. Anything else
        gives a vertical colour bar.

    Returns
    -------
    matplotlib Axes
        The axes the diagram was drawn on.

    Raises
    ------
    ValueError
        If ``xrange`` or ``yrange`` contains no sample of the corresponding scale.

    Notes
    -----
    The colour table path is relative, so it is resolved against the current
    working directory.

    Example usage:
    WaLSA_plot_k_omega(
        kopower=power,
        kopower_xscale=wavenumber,
        kopower_yscale=frequencies*1000.,
        xlog=False, ylog=False, 
        xrange=(0, 0.3), figsize=(8, 4), cbartab=0.18,
        xtitle='Wavenumber (pixel⁻¹)', ytitle='Frequency (mHz)',
        colorbar_label='Log₁₀(Oscillation Power)',
        f1=0.470*1000, f2=0.530*1000,
        k1=0.047, k2=0.25
    )
    """
    # Accept any array-like input: the comparisons and fancy indexing below
    # need ndarrays (a plain list would fail on ``list >= float``).
    kopower = np.asanyarray(kopower)
    kopower_xscale = np.asarray(kopower_xscale)
    kopower_yscale = np.asarray(kopower_yscale)

    # Default to the full extent of each scale.
    if xrange is None or len(xrange) != 2:
        xrange = [np.min(kopower_xscale), np.max(kopower_xscale)]
    if yrange is None or len(yrange) != 2:
        yrange = [np.min(kopower_yscale), np.max(kopower_yscale)]

    # A lower bound of exactly 0 is replaced by the minimum of the scale.
    if xrange[0] == 0:
        xrange = [np.min(kopower_xscale), xrange[1]]
    if yrange[0] == 0:
        yrange = [np.min(kopower_yscale), yrange[1]]

    # Clip kopower and its scales to the requested ranges. Both ranges are
    # always set by now. They are deliberately not truth-tested, because
    # bool() of a NumPy array range raises.
    x_keep = _walsa_range_indices(kopower_xscale, xrange[0], xrange[1], 'x')
    kopower = kopower[:, x_keep]
    kopower_xscale = kopower_xscale[x_keep]

    y_keep = _walsa_range_indices(kopower_yscale, yrange[0], yrange[1], 'y')
    kopower = kopower[y_keep, :]
    kopower_yscale = kopower_yscale[y_keep]

    # Pixel coordinates in data space
    if xlog:
        nxpix = np.logspace(np.log10(xrange[0]), np.log10(xrange[1]), dxpix)
    else:
        nxpix = np.linspace(xrange[0], xrange[1], dxpix)

    if ylog:
        nypix = np.logspace(np.log10(yrange[0]), np.log10(yrange[1]), dypix)
    else:
        nypix = np.linspace(yrange[0], yrange[1], dypix)

    # Interpolation over data. kopower is raveled row-major, so each row's
    # y value is repeated and the x scale is tiled once per row.
    interpolator = griddata(
        points=(np.repeat(kopower_yscale, len(kopower_xscale)), np.tile(kopower_xscale, len(kopower_yscale))),
        values=kopower.ravel(),
        xi=(nypix[:, None], nxpix[None, :]),
        method='linear'
    )
    # Grid points outside the data hull come back as NaN; they are set to 0.
    newimage = np.nan_to_num(interpolator)

    # Clip the interpolated image to the exact xrange and yrange
    # (guards against floating-point overshoot of logspace/linspace).
    x_mask = (nxpix >= xrange[0]) & (nxpix <= xrange[1])
    y_mask = (nypix >= yrange[0]) & (nypix <= yrange[1])
    newimage = newimage[np.ix_(y_mask, x_mask)]
    nxpix = nxpix[x_mask]
    nypix = nypix[y_mask]

    # Extent for plotting
    extent = [nxpix[0], nxpix[-1], nypix[0], nypix[-1]]

    # Set up the plot
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Load custom colormap (relative path: resolved against the working directory)
    rgb_values = np.loadtxt('Color_Tables/idl_colormap_13.txt') / 255.0
    idl_colormap_13 = ListedColormap(rgb_values)

    # Colour limits come from the clipped (not interpolated) power array
    vmin = np.nanmin(kopower)
    vmax = np.nanmax(kopower)

    img = ax.imshow(
        WaLSA_histo_opt(newimage), extent=extent, origin='lower', aspect='auto', cmap=idl_colormap_13, norm=Normalize(vmin=vmin, vmax=vmax)
    )

    # Configure axis labels and scales
    ax.set_xlabel(xtitle)
    ax.set_ylabel(ytitle)
    if xlog:
        ax.set_xscale('log')
    if ylog:
        ax.set_yscale('log')

    # Configure ticks
    ax.xaxis.set_major_locator(plt.MultipleLocator(xtick_interval))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(xtick_minor_interval))
    ax.yaxis.set_major_locator(plt.MultipleLocator(ytick_interval))
    ax.yaxis.set_minor_locator(plt.MultipleLocator(ytick_minor_interval))

    ax.tick_params(axis='both', which='both', direction='out', top=True, right=True)
    ax.tick_params(axis='both', which='major', length=7, width=1.1)  # Major ticks
    ax.tick_params(axis='both', which='minor', length=4, width=1.1)  # Minor ticks

    # Add secondary x and y axes for spatial size and period,
    # labelled at the primary axes' major tick positions.
    major_xticks = ax.get_xticks()
    major_yticks = ax.get_yticks()

    # Period = 1000 / frequency, i.e. a y axis in mHz is labelled in seconds.
    def frequency_to_period(frequency):
        return 1000. / frequency if frequency > 0 else np.inf
    period_labels = [f"{frequency_to_period(tick):.1f}" if tick > 0 else "" for tick in major_yticks]

    ax_right = ax.secondary_yaxis('right')
    ax_right.set_yticks(major_yticks)
    ax_right.set_yticklabels(period_labels)
    ax_right.set_ylabel('Period (s)', labelpad=12)

    # Spatial size = 2π / wavenumber.
    # NOTE(review): 2π/k assumes an angular wavenumber; if the x scale is in
    # cycles per pixel the spatial size would be 1/k — confirm against the
    # k-omega calculation.
    def wavenumber_to_spatial_size(wavenumber):
        return 2 * np.pi / wavenumber if wavenumber > 0 else np.inf
    spatial_size_labels = [f"{wavenumber_to_spatial_size(tick):.1f}" if tick > 0 else "" for tick in major_xticks]

    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(major_xticks)
    ax_top.set_xticklabels(spatial_size_labels)
    ax_top.set_xlabel('Spatial Size (pixel)', labelpad=12)

    # Mark the filtered area when all four corners are given
    if k1 is not None and k2 is not None and f1 is not None and f2 is not None:
        width = k2 - k1
        height = f2 - f1
        rectangle = patches.Rectangle(
            (k1, f1), width, height, zorder=10,
            linewidth=1.5, edgecolor='white', facecolor='none', linestyle='--'
        )
        # Black outline under the white dashes keeps the box visible on any background
        rectangle.set_path_effects([
            path_effects.Stroke(linewidth=2.5, foreground='black'),
            path_effects.Normal()
        ])
        ax.add_patch(rectangle)

    # Colorbar with four evenly spread, explicitly fixed ticks
    tick_values = [
        vmin,
        vmin + (vmax - vmin) * 0.33,
        vmin + (vmax - vmin) * 0.67,
        vmax
    ]
    tick_labels = [f"{v:.1f}" for v in tick_values]

    if colorbar_location == 'top':
        cbar = plt.colorbar(img, ax=ax, orientation='horizontal', pad=cbartab, location='top', aspect=30)
        cbar_axis = cbar.ax.xaxis
    else:
        cbar = plt.colorbar(img, ax=ax, orientation='vertical', pad=cbartab, aspect=22)
        cbar_axis = cbar.ax.yaxis

    cbar.set_label(colorbar_label)
    # Override the ticks and tick labels
    cbar.set_ticks(tick_values)
    cbar.set_ticklabels(tick_labels)
    # Pin locator/formatter so later redraws cannot replace the custom ticks
    cbar_axis.set_major_locator(FixedLocator(tick_values))
    cbar_axis.set_major_formatter(FixedFormatter(tick_labels))
    # Suppress minor ticks completely
    cbar_axis.set_minor_locator(FixedLocator([]))

    return ax

WaLSA_plot_wavelet_spectrum.py

This module provides functions for plotting wavelet power spectra and related visualizations.

WaLSA_plot_wavelet_spectrum.py
# -----------------------------------------------------------------------------------------------------
# WaLSAtools - Wave analysis tools
# Copyright (C) 2025 WaLSA Team - Shahin Jafarzadeh et al.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
# http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# 
# Note: If you use WaLSAtools for research, please consider citing:
# Jafarzadeh, S., Jess, D. B., Stangalini, M. et al. 2025, Nature Reviews Methods Primers, 5, 21
# -----------------------------------------------------------------------------------------------------


import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable

def WaLSA_plot_wavelet_spectrum(t, power, periods, sig_slevel, coi, dt, normalize_power=False,
                                title='Wavelet Power Spectrum', ylabel='Period [s]', xlabel='Time [s]', 
                                colorbar_label='Power (%)', ax=None, colormap='custom', removespace=False):
    """Plot the wavelet power spectrum of a signal.

    The plot contains the following elements:
    - Filled contours (100 levels) of ``power`` against time and period.
    - A black contour where ``sig_slevel`` equals 1.
    - The cone of influence, drawn as a line plus a hatched shaded region.
    - A log-scaled, inverted period axis.
    - A secondary frequency axis (1 / period) on the right.
    - A horizontal colour bar above the plot.

    Parameters:
    - t (array-like): Time values for the signal.
    - power (2D array-like): Wavelet power spectrum, shaped (len(periods), len(t))
      as required by ``contourf(t, periods, power)``.
    - periods (array-like): Period values corresponding to the power.
    - sig_slevel (2D array-like): Significance levels of the power spectrum; the
      contour is drawn at level 1. Presumably this is power divided by the 95%
      significance threshold. TODO: confirm against the caller.
    - coi (array-like): Cone of influence (one period per time value), showing where edge effects might be present.
    - dt (float): Time interval between successive time values.
    - normalize_power (bool): If True, normalize power to 0-100%. Default is False.
    - title (str): Title of the plot. Default is 'Wavelet Power Spectrum'.
    - ylabel (str): Y-axis label. Default is 'Period [s]'.
    - xlabel (str): X-axis label. Default is 'Time [s]'.
    - colorbar_label (str): Label for the color bar. Default is 'Power (%)'.
    - ax (matplotlib axis, optional): Axis to plot on. Default is None, which creates a new figure and axis.
    - colormap (str or colormap, optional): Colormap to be used. Default is 'custom'
      (white -> blue -> green -> yellow -> red).
    - removespace (bool, optional): If True, limits the maximum y-range to the peak of the cone of influence.

    Returns:
    - matplotlib axis: The axis the spectrum was drawn on. The figure is neither
      laid out nor shown here. Call ``plt.show()`` or ``ax.figure.savefig(...)``
      as needed.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    # Accept plain sequences. Without conversion, ``100 * power`` would repeat
    # a list and ``t[-1:] + dt`` / ``t.min()`` would fail.
    t = np.asarray(t)
    periods = np.asarray(periods)
    coi = np.asarray(coi)
    power = np.asanyarray(power)

    # Custom colormap with white at the lower end
    if isinstance(colormap, str) and colormap == 'custom':
        cmap = LinearSegmentedColormap.from_list('custom_white', ['white', 'blue', 'green', 'yellow', 'red'], N=256)
    else:
        cmap = plt.get_cmap(colormap)

    # Normalize power if requested
    if normalize_power:
        power = 100 * power / np.nanmax(power)

    # Many levels make the filled contours look like a continuous colour bar
    levels = np.linspace(np.nanmin(power), np.nanmax(power), 100)

    # Plot the wavelet power spectrum ('neither' gives a colour bar without extension arrows)
    CS = ax.contourf(t, periods, power, levels=levels, cmap=cmap, extend='neither')

    # Significance contour at sig_slevel == 1 (labelled 95% in this module)
    ax.contour(t, periods, sig_slevel, levels=[1], colors='k', linewidths=[1.0])

    if removespace:
        max_period = np.max(coi)
    else:
        max_period = np.max(periods)
    min_period = np.min(periods)

    # Cone-of-influence: the line itself, plus a hatched region between it and
    # max_period (the polygon is closed just outside both ends of t).
    ax.plot(t, coi, '-k', lw=2)
    ax.fill(np.concatenate([t, t[-1:] + dt, t[-1:] + dt, t[:1] - dt, t[:1] - dt]),
            np.concatenate([coi, [1e-9], [max_period], [max_period], [1e-9]]),
            'k', alpha=0.3, hatch='x')

    # Log scale for periods, short periods at the top
    ax.set_ylim([min_period, max_period])
    ax.set_yscale('log', base=10)
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
    ax.ticklabel_format(axis='y', style='plain')
    ax.invert_yaxis()

    # Set axis limits and labels
    ax.set_xlim([t.min(), t.max()])
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_title(title, fontsize=16)

    # Secondary y-axis for frequency in Hz (1 / period), aligned with the
    # period range actually shown (which depends on `max_period`)
    ax_freq = ax.twinx()
    min_frequency = 1 / max_period
    max_frequency = 1 / min_period
    ax_freq.set_yscale('log', base=10)
    ax_freq.set_ylim([max_frequency, min_frequency])
    ax_freq.yaxis.set_major_formatter(ticker.ScalarFormatter())
    ax_freq.ticklabel_format(axis='y', style='plain')
    ax_freq.invert_yaxis()
    ax_freq.set_ylabel('Frequency (Hz)', fontsize=14)
    ax_freq.tick_params(axis='both', which='major', labelsize=12)

    # Color bar on top of the plot with minimal distance
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('top', size='5%', pad=0.01)
    cbar = plt.colorbar(CS, cax=cax, orientation='horizontal')

    # Put the colour-bar ticks and label above the bar
    cbar.set_label(colorbar_label, fontsize=12, labelpad=5)
    cbar.ax.tick_params(labelsize=10, direction='out', top=True, labeltop=True, bottom=False, labelbottom=False)
    cbar.ax.xaxis.set_label_position('top')

    # Layout and display are left to the caller, so the figure can still be
    # saved or modified after this call.
    return ax