Source code for hybra.utils

from typing import Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

####################################################################################################
##################### Cool routines to study decimated filterbanks #################################
####################################################################################################


[docs] def frame_bounds( w: torch.Tensor, d: int, Ls: Union[int, None] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute frame bounds of a filterbank using polyphase representation. Frame bounds characterize the numerical stability and invertibility of the filterbank transform. Tight frames (A ≈ B) provide optimal stability. Args: w (torch.Tensor): Impulse responses of shape (num_channels, length) d (int): Decimation (stride) factor Ls (int, optional): Signal length. If None, computed automatically. Default: None Returns: Tuple[torch.Tensor, torch.Tensor]: Lower and upper frame bounds (A, B) Note: For d=1, reduces to computing min/max of power spectral density. For d>1, uses polyphase analysis to compute worst-case eigenvalues. Example: >>> w = torch.randn(40, 128) >>> A, B = frame_bounds(w, d=4) >>> condition_number = B / A """ if Ls is None: Ls = int(torch.ceil(torch.tensor(w.shape[-1] * 2 / d)) * d) w_full = torch.cat([w, torch.conj(w)], dim=0) w_hat = torch.fft.fft(w_full, Ls, dim=-1).T if d == 1: psd = torch.sum(w_hat.abs() ** 2, dim=-1) A = torch.min(psd) B = torch.max(psd) return A, B else: N = w_hat.shape[0] M = w_hat.shape[1] assert N % d == 0, "Oh no! Decimation factor must divide signal length!" if w_hat.device.type == "mps": temp_device = torch.device("cpu") else: temp_device = w_hat.device w_hat_cpu = w_hat.to(temp_device) A = torch.tensor([torch.inf]).to(temp_device) B = torch.tensor([0]).to(temp_device) Ha = torch.zeros((d, M)).to(temp_device) Hb = torch.zeros((d, M)).to(temp_device) for j in range(N // d): idx_a = (j - torch.arange(d) * (N // d)) % N idx_b = (torch.arange(d) * (N // d) - j) % N Ha = w_hat_cpu[idx_a, :] Hb = torch.conj(w_hat_cpu[idx_b, :]) lam = torch.linalg.eigvalsh(Ha @ Ha.H + Hb @ Hb.H).real A = torch.min(A, torch.min(lam)) B = torch.max(B, torch.max(lam)) return (A / d).to(w_hat.device), (B / d).to(w_hat.device)
[docs] def condition_number( w: torch.Tensor, d: int, Ls: Union[int, None] = None ) -> torch.Tensor: """Compute condition number of a filterbank frame operator. The condition number κ = B/A quantifies numerical stability, where A and B are the lower and upper frame bounds. Lower values indicate better stability. Args: w (torch.Tensor): Impulse responses of shape (num_channels, signal_length) d (int): Decimation factor (stride) Ls (int, optional): Signal length. If None, computed automatically. Default: None Returns: torch.Tensor: Condition number κ = B/A Note: κ = 1 indicates a tight frame (optimal stability). κ >> 1 suggests potential numerical instability. Example: >>> w = torch.randn(40, 128) >>> kappa = condition_number(w, d=4) >>> print(f"Condition number: {kappa.item():.2f}") """ A, B = frame_bounds(w, d, Ls) A = torch.max( A, torch.tensor(1e-6, dtype=A.dtype, device=A.device) ) # Avoid division by zero return B / A
[docs] def frequency_correlation( w: torch.Tensor, d: int, Ls: Union[int, None] = None, diag_only: bool = False ) -> torch.Tensor: """ Computes the frequency correlation functions (vectorized version). Parameters: w: (J, K) - Impulse responses d: Decimation factor Ls: FFT length (default: nearest multiple of d ≥ 2K-1) diag_only: If True, only return diagonal (i.e., PSD) Returns: G: (d, Ls) complex tensor with frequency correlations """ K = w.shape[-1] if Ls is None: Ls = int(torch.ceil(torch.tensor((2 * K - 1) / d)) * d) w_full = torch.cat([w, torch.conj(w)], dim=0) w_hat = torch.fft.fft(w_full, Ls, dim=-1) # shape: [J, Ls] N = Ls assert N % d == 0, "Decimation factor must divide FFT length" # Diagonal: sum_j |w_hat_j|^2 diag = torch.sum(w_hat.abs() ** 2, dim=0) # shape: [Ls] if diag_only: return torch.real(diag) G = [diag] # G[0] = diagonal for j in range(1, d): rolled = torch.roll(w_hat, shifts=j * (N // d), dims=-1) val = torch.sum(w_hat * torch.conj(rolled), dim=0) G.append(val) G = torch.stack(G, dim=0) # shape: [d, Ls] return G
[docs] def alias( w: torch.Tensor, d: int, Ls: Union[int, None] = None, diag_only: bool = False ) -> torch.Tensor: """ Computes the norm of the aliasing terms. Parameters: w: Impulse responses of the filterbank as 2-D Tensor torch.tensor[num_channels, sig_length] d: Decimation factor, must divide filter length! Output: A: Energy of the aliasing terms """ G = frequency_correlation(w=w, d=d, Ls=Ls, diag_only=diag_only) if diag_only: return torch.max(G).div(torch.min(G)) else: # return torch.max(torch.real(G[0,:])).div(torch.min(torch.real(G[0,:]))) + torch.sum(torch.norm(G[1::,:], p=2, dim=-1), dim=-1) - 1 return torch.norm( torch.real(G[0, :]) - torch.ones_like(G[0, :]), p=2 ) + torch.sum(torch.norm(G[1::, :], p=2, dim=-1), dim=-1)
[docs] def can_tight(w: torch.Tensor, d: int, Ls: int) -> torch.Tensor: """ Computes the canonical tight filterbank of w (time domain) using the polyphase representation. Parameters: w: Impulse responses of the filterbank as 2-d Tensor torch.tensor[num_channels, signal_length] d: Decimation factor, must divide signal_length! Returns: W: Canonical tight filterbank of W (torch.tensor[num_channels, signal_length]) """ w_hat = torch.fft.fft(w.T, Ls, dim=0) if d == 1: lp = torch.sum(w_hat.abs() ** 2, dim=1).reshape(-1, 1) w_hat_tight = w_hat * (lp ** (-0.5)) return torch.fft.ifft(w_hat_tight.T, dim=1) else: N = w_hat.shape[0] J = w_hat.shape[1] assert N % d == 0, "Oh no! Decimation factor must divide signal length!" w_hat_tight = torch.zeros(J, N, dtype=torch.complex64) for j in range(N // d): idx = (j - torch.arange(d) * (N // d)) % N H = w_hat[idx, :] U, _, V = torch.linalg.svd(H, full_matrices=False) H = U @ V w_hat_tight[:, idx] = H.T.to(torch.complex64) return torch.fft.ifft(torch.fft.ifft(w_hat_tight.T, dim=1) * d**0.5, dim=0).T
[docs] def fir_tightener3000( w: torch.Tensor, supp: int, d: int, eps: float = 1.01, Ls: Union[int, None] = None ): """ Iterative tightening procedure with fixed support for a given filterbank w Parameters: w: Impulse responses of the filterbank as 2-D Tensor torch.tensor[num_channels, signal_length]. supp: Desired support of the resulting filterbank d: Decimation factor, must divide filter length! eps: Desired precision for the condition number Ls: System length (if not already given by w). If set, the resulting filterbank is padded with zeros to length Ls. Returns: Filterbank with condition number *eps* and support length *supp*. If length=supp then the resulting filterbank is the canonical tight filterbank of w. """ print("Hold on, the kernels are tightening") if Ls is not None: w = torch.cat([w, torch.zeros(w.shape[0], Ls - w.shape[1])], dim=1) w_tight = w.clone() kappa = condition_number(w, d).item() while kappa > eps: w_tight = can_tight(w_tight, d) w_tight[:, supp:] = 0 kappa = condition_number(w_tight, d).item() if Ls is None: return w_tight else: return w_tight[:, :supp]
[docs] def upsample(x: torch.Tensor, d: int) -> torch.Tensor: N = x.shape[-1] * d x_up = F.pad(torch.zeros_like(x), (0, N - x.shape[-1])) x_up[:, :, ::d] = x return x_up
[docs] def circ_conv(x: torch.Tensor, kernels: torch.Tensor, d: int = 1) -> torch.Tensor: """Circular convolution with optional downsampling. Performs efficient circular convolution using FFT, followed by downsampling. The kernels are automatically centered for proper phase alignment. Args: x (torch.Tensor): Input signal of shape (..., signal_length) kernels (torch.Tensor): Filter kernels of shape (num_channels, 1, kernel_length) or (num_channels, kernel_length) d (int): Downsampling factor (stride). Default: 1 Returns: torch.Tensor: Convolved and downsampled output of shape (..., num_channels, output_length) Note: Uses circular convolution which assumes periodic boundary conditions. Kernels are automatically zero-padded and centered. Example: >>> x = torch.randn(1, 1000) >>> kernels = torch.randn(40, 128) >>> y = circ_conv(x, kernels, d=4) """ L = x.shape[-1] x = x.to(kernels.dtype) kernels_long = F.pad(kernels, (0, L - kernels.shape[-1]), mode="constant", value=0) kernels_centered = torch.roll(kernels_long, shifts=-kernels.shape[-1] // 2, dims=-1) x_fft = torch.fft.fft(x, n=L, dim=-1) k_fft = torch.fft.fft(kernels_centered, n=L, dim=-1) y_fft = x_fft * k_fft y = torch.fft.ifft(y_fft) return y[:, :, ::d]
[docs] def circ_conv_transpose( y: torch.Tensor, kernels: torch.Tensor, d: int = 1 ) -> torch.Tensor: """Transpose (adjoint) of circular convolution with upsampling. Implements the adjoint operation of circ_conv for signal reconstruction. Used in synthesis/decoder operations of filterbanks. Args: y (torch.Tensor): Input coefficients of shape (..., num_channels, num_frames) kernels (torch.Tensor): Filter kernels of shape (num_channels, 1, kernel_length) or (num_channels, kernel_length) d (int): Upsampling factor (stride). Default: 1 Returns: torch.Tensor: Reconstructed signal of shape (..., 1, signal_length) Note: This is the mathematical adjoint, not the true inverse. For perfect reconstruction, appropriate dual frame filters should be used. Example: >>> coeffs = torch.randn(1, 40, 250) >>> kernels = torch.randn(40, 128) >>> x_recon = circ_conv_transpose(coeffs, kernels, d=4) """ L = y.shape[-1] * d y_up = upsample(y, d) kernels_long = F.pad(kernels, (0, L - kernels.shape[-1]), mode="constant", value=0) kernels_centered = torch.roll(kernels_long, shifts=-kernels.shape[-1] // 2, dims=-1) kernels_synth = torch.flip(torch.conj(kernels_centered), dims=(1,)) y_fft = torch.fft.fft(y_up, n=L, dim=-1) k_fft = torch.fft.fft(kernels_synth, n=L, dim=-1) x_fft = y_fft * k_fft x = torch.fft.ifft(x_fft, dim=-1) x = torch.sum(x, dim=-2, keepdim=True) return torch.roll(x, 1, -1)
#################################################################################################### ################### Routines for constructing auditory filterbanks ################################# ####################################################################################################
[docs] def freqtoaud( freq: Union[float, int, torch.Tensor], scale: str = "erb", fs: Union[int, None] = None, ) -> torch.Tensor: """Convert frequencies from Hz to auditory scale units. Transforms linear frequency values to perceptually-motivated auditory scales that better reflect human frequency discrimination. Args: freq (Union[float, int, torch.Tensor]): Frequency value(s) in Hz scale (str): Auditory scale type. One of {'erb', 'mel', 'log10', 'elelog'}. Default: 'erb' fs (int, optional): Sampling frequency (required for 'elelog' scale). Default: None Returns: torch.Tensor: Corresponding auditory scale units Raises: ValueError: If unsupported scale is specified or fs is missing for 'elelog' Note: - ERB: Equivalent Rectangular Bandwidth (Glasberg & Moore) - MEL: Mel scale (perceptually uniform pitch) - Bark: Bark scale (critical band rate) - elelog: Logarithmic scale adapted for elephant hearing Example: >>> freq_hz = torch.tensor([100, 1000, 8000]) >>> mel_units = freqtoaud(freq_hz, scale='mel') """ scale = scale.lower() if isinstance(freq, (int, float)): freq = torch.tensor(freq) if scale == "erb": # Glasberg and Moore's ERB scale return 9.2645 * torch.sign(freq) * torch.log(1 + torch.abs(freq) * 0.00437) elif scale == "mel": # MEL scale return ( 1000 / torch.log(torch.tensor(17 / 7)) * torch.sign(freq) * torch.log(1 + torch.abs(freq) / 700) ) # elif scale == "bark": # # Bark scale from Traunmuller (1990) # return torch.sign(freq) * ((26.81 / (1 + 1960 / torch.abs(freq))) - 0.53) elif scale == "log10": # Logarithmic scale return torch.log10(torch.maximum(torch.ones(1), freq)) elif scale == "elelog": if fs is None: raise ValueError( "Sampling frequency fs must be provided for 'elelog' scale." ) fmin = 1 fmax = fs // 2 k = 0.88 A = fmin / (1 - k) alpha = np.log10(fmax / A + k) return np.log((freq / A + k) / alpha) # - np.log((fmin / A + k) / alpha) else: raise ValueError( f"Unsupported scale: '{scale}'. Available options are: 'mel', 'erb', 'log10', 'elelog'." )
[docs] def audtofreq( aud: Union[float, int, torch.Tensor], scale: str = "erb", fs: Union[int, None] = None, ) -> torch.Tensor: """Convert auditory scale units back to frequencies in Hz. Args: aud (Union[float, int, torch.Tensor]): Auditory scale values scale (str): Auditory scale type. One of {'erb', 'mel', 'log10', 'elelog'}. Default: 'erb' fs (int, optional): Sampling frequency (required for 'elelog' scale). Default: None Returns: torch.Tensor: Corresponding frequencies in Hz Example: >>> mel_units = torch.tensor([100, 1000, 2000]) >>> freq_hz = audtofreq(mel_units, scale='mel') """ if scale == "erb": return (1 / 0.00437) * (torch.exp(aud / 9.2645) - 1) elif scale == "mel": return ( 700 * torch.sign(aud) * (torch.exp(torch.abs(aud) * torch.log(torch.tensor(17 / 7)) / 1000) - 1) ) # elif scale == "bark": # return torch.sign(aud) * 1960 / (26.81 / (torch.abs(aud) + 0.53) - 1) elif scale == "log10": return 10**aud elif scale == "elelog": if fs is None: raise ValueError( "Sampling frequency fs must be provided for 'elelog' scale." ) fmin = 1 fmax = fs // 2 k = 0.88 A = fmin / (1 - k) alpha = np.log10(fmax / A + k) return A * (np.exp(aud) * alpha - k) else: raise ValueError( f"Unsupported scale: '{scale}'. Available options are: 'mel', 'erb', 'log10', 'elelog'." )
[docs] def audspace( fmin: Union[float, int, torch.Tensor], fmax: Union[float, int, torch.Tensor], num_channels: int, scale: str = "erb", ): """ Computes a vector of values equidistantly spaced on the selected auditory scale. Parameters: fmin (float): Minimum frequency in Hz. fmax (float): Maximum frequency in Hz. num_channels (int): Number of points in the output vector. audscale (str): Auditory scale (default is 'erb'). Returns: tuple: y (ndarray): Array of frequencies equidistantly scaled on the auditory scale. """ if num_channels <= 0: raise ValueError("n must be a positive integer scalar.") if fmin > fmax: raise ValueError("fmin must be less than or equal to fmax.") # Convert [fmin, fmax] to auditory scale if scale == "log10" or scale == "elelog": fmin = torch.maximum(torch.tensor(fmin), torch.ones(1)) audlimits = freqtoaud(torch.tensor([fmin, fmax]), scale) # Generate frequencies spaced evenly on the auditory scale aud_space = torch.linspace(audlimits[0], audlimits[1], num_channels) y = audtofreq(aud_space, scale) # Ensure exact endpoints y[0] = fmin y[-1] = fmax return y
[docs] def freqtoaud_mod( freq: Union[float, int, torch.Tensor], fc_low: Union[float, int, torch.Tensor], fc_high: Union[float, int, torch.Tensor], scale="erb", fs=None, ): """ Modified auditory scale function with linear region below fc_crit. Parameters: freq (ndarray): Frequency values in Hz. fc_low (float): Lower transition frequency in Hz. fc_high (float): Upper transition frequency in Hz. Returns: ndarray: Values on the modified auditory scale. """ aud_crit_low = freqtoaud(fc_low, scale, fs) aud_crit_high = freqtoaud(fc_high, scale, fs) slope_low = (freqtoaud(fc_low * 1.01, scale, fs) - aud_crit_low) / (fc_low * 0.01) slope_high = (freqtoaud(fc_high * 1.01, scale, fs) - aud_crit_high) / ( fc_high * 0.01 ) linear_low = freq < fc_low linear_high = freq > fc_high auditory = [not x for x in (linear_low + linear_high)] aud = torch.zeros_like(freq, dtype=torch.float32) aud[linear_low] = slope_low * (freq[linear_low] - fc_low) + aud_crit_low aud[auditory] = freqtoaud(freq[auditory], scale, fs) aud[linear_high] = slope_high * (freq[linear_high] - fc_high) + aud_crit_high return aud
[docs] def audtofreq_mod( aud: Union[float, int, torch.Tensor], fc_low: Union[float, int, torch.Tensor], fc_high: Union[float, int, torch.Tensor], scale="erb", fs=None, ): """ Inverse of freqtoaud_mod to map auditory scale back to frequency. Parameters: aud (ndarray): Auditory scale values. fc_low (float): Lower transition frequency in Hz. fc_high (float): Upper transition frequency in Hz. Returns: ndarray: Frequency values in Hz """ aud_crit_low = freqtoaud(fc_low, scale, fs) aud_crit_high = freqtoaud(fc_high, scale, fs) slope_low = (freqtoaud(fc_low * 1.01, scale, fs) - aud_crit_low) / (fc_low * 0.01) slope_high = (freqtoaud(fc_high * 1.01, scale, fs) - aud_crit_high) / ( fc_high * 0.01 ) linear_low = aud < aud_crit_low linear_high = aud > aud_crit_high auditory_part = [not x for x in (linear_low + linear_high)] freq = torch.zeros_like(aud, dtype=torch.float32) freq[linear_low] = (aud[linear_low] - aud_crit_low) / slope_low + fc_low freq[auditory_part] = audtofreq(aud[auditory_part], scale, fs) freq[linear_high] = (aud[linear_high] - aud_crit_high) / slope_high + fc_high return freq
[docs] def audspace_mod( fc_low: Union[float, int, torch.Tensor], fc_high: Union[float, int, torch.Tensor], fs: int, num_channels: int, scale: str = "erb", ): """Generate M frequency samples that are equidistant in the modified auditory scale. Parameters: fc_crit (float): Critical frequency in Hz. fs (int): Sampling rate in Hz. M (int): Number of filters/channels. Returns: ndarray: Frequency values in Hz and in the auditory scale. """ if fc_low > fc_high: raise ValueError("fc_low must be less than fc_high.") elif fc_low == fc_high: # equidistant samples form 0 to fs/2 fc = torch.linspace(0, fs // 2, num_channels) return fc, freqtoaud_mod(fc, fs // 2, fs // 2, scale, fs) elif fc_low < fc_high: # Convert [0, fs//2] to modified auditory scale aud_min = freqtoaud_mod(torch.tensor([0]), fc_low, fc_high, scale, fs)[0] aud_max = freqtoaud_mod(torch.tensor([fs // 2]), fc_low, fc_high, scale, fs)[0] # Generate frequencies spaced evenly on the modified auditory scale fc_aud = torch.linspace(aud_min, aud_max, num_channels) # Convert back to frequency scale fc = audtofreq_mod(fc_aud, fc_low, fc_high, scale, fs) # Ensure exact endpoints fc[0] = 0 fc[-1] = fs // 2 return fc, fc_aud else: raise ValueError("There is something wrong with fc_low and fc_high.")
[docs] def fctobw(fc: Union[float, int, torch.Tensor], scale="mel"): """ Computes the critical bandwidth of a filter at a given center frequency. Parameters: fc (float or ndarray): Center frequency in Hz. Must be non-negative. audscale (str): Auditory scale. Supported values are: - 'mel': Mel scale (default) - 'erb': Equivalent Rectangular Bandwidth - 'log10': Logarithmic scale Returns: ndarray or float: Critical bandwidth at each center frequency. """ if isinstance(fc, (list, tuple, int, float)): fc = torch.tensor(fc) if not (isinstance(fc, (float, int, torch.Tensor)) and torch.all(fc >= 0)): raise ValueError("fc must be a non-negative scalar or array.") # Compute bandwidth based on the auditory scale if scale == "erb": bw = 24.7 + fc / 9.265 # elif scale == "bark": # bw = 25 + 75 * (1 + 1.4e-6 * fc**2) ** 0.69 elif scale == "mel": bw = torch.log10(torch.tensor(17 / 7)) * (700 + fc) / 1000 elif scale == "log10": bw = fc else: raise ValueError(f"Unsupported auditory scale: {scale}") return bw
[docs] def bwtofc(bw: Union[float, int, torch.Tensor], scale="mel"): """ Computes the center frequency corresponding to a given critical bandwidth. Parameters: bw (float or ndarray): Critical bandwidth. Must be non-negative. scale (str): Auditory scale. Supported values are: - 'mel': Mel scale - 'erb': Equivalent Rectangular Bandwidth - 'log10': Logarithmic scale Returns: ndarray or float: Center frequency corresponding to the given bandwidth. """ if isinstance(bw, (list, tuple)): bw = torch.tensor(bw) if not (isinstance(bw, (float, int, torch.Tensor)) and torch.all(bw >= 0)): raise ValueError("bw must be a non-negative scalar or array.") # Compute center frequency based on the auditory scale if scale == "erb": fc = (bw - 24.7) * 9.265 # elif scale == "bark": # fc = torch.sqrt(((bw - 25) / 75) ** (1 / 0.69) / 1.4e-6) elif scale == "mel": fc = 1000 * (bw / torch.log10(torch.tensor(17 / 7))) - 700 elif scale == "log10": fc = bw else: raise ValueError(f"Unsupported auditory scale: {scale}") return fc
[docs] def firwin(kernel_size: int, padto: int = None): """ FIR window generation in Python. Parameters: kernel_size (int): Length of the window. padto (int): Length to which it should be padded. name (str): Name of the window. Returns: g (ndarray): FIR window. """ g = torch.hann_window(kernel_size, periodic=False) g /= torch.sum(torch.abs(g)) if padto is None or padto == kernel_size: return g elif padto > kernel_size: g_padded = torch.concatenate([g, torch.zeros(padto - len(g))]) g_centered = torch.roll(g_padded, int((padto - len(g)) // 2)) return g_centered else: raise ValueError("padto must be larger than kernel_size.")
[docs] def modulate(g: torch.Tensor, fc: Union[float, int, torch.Tensor], fs: int): """Modulate a filters. Args: g (list of torch.Tensor): Filters. fc (list): Center frequencies. fs (int): Sampling rate. Returns: g_mod (list of torch.Tensor): Modulated filters. """ Lg = len(g) g_mod = g * torch.exp(2 * torch.pi * 1j * fc * torch.arange(Lg) / fs) return g_mod
#################################################################################################### ########################################### ISAC ################################################### ####################################################################################################
[docs] def audfilters( fs: int, kernel_size: Union[int, None] = None, num_channels: int = 96, fc_max: Union[float, int, None] = None, L: Union[int, None] = None, supp_mult: float = 1, scale: str = "mel", ) -> Tuple[ torch.Tensor, int, torch.Tensor, Union[int, float], Union[int, float], int, int, int, torch.Tensor, ]: """Generate auditory-inspired FIR filterbank kernels. Creates a bank of bandpass filters with center frequencies distributed according to perceptual auditory scales (mel, erb, etc.). Filters are designed with variable bandwidths matching critical bands of human auditory perception. Args: fs (int): Sampling frequency in Hz. (required) kernel_size (int, optional): Maximum filter kernel size. If None, computed automatically. Default: None num_channels (int): Number of frequency channels. Default: 96 fc_max (float, optional): Maximum center frequency in Hz. If None, uses fs//2. Default: None L (int): Signal length in samples. If None, uses fs. Default: None supp_mult (float): Support multiplier for kernel sizing. Default: 1.0 scale (str): Auditory scale. One of {'mel', 'erb', 'log10', 'elelog'}. Default: 'mel' Returns: Tuple containing: - kernels (torch.Tensor): Filter kernels of shape (num_channels, kernel_size) - d (int): Recommended stride for 50% overlap - fc (torch.Tensor): Center frequencies in Hz - fc_min (Union[int, float]): Minimum center frequency - fc_max (Union[int, float]): Maximum center frequency - kernel_min (int): Minimum kernel size used - kernel_size (int): Maximum kernel size used - Ls (int): Adjusted signal length - tsupp (torch.Tensor): Time support for each filter Raises: ValueError: If parameters are invalid (negative values, unsupported scale, etc.) Note: The filterbank construction follows auditory modeling principles where: - Low frequencies use longer filters (better frequency resolution) - High frequencies use shorter filters (better time resolution) - Bandwidth scales according to critical band theory Example: >>> kernels, stride, fc, _, _, _, _, Ls, _ = audfilters( ... kernel_size=128, num_channels=40, fs=16000, scale='mel' ... ) >>> print(f"Generated {kernels.shape[0]} filters with stride {stride}") """ # check if all inputs are valid if kernel_size is not None and kernel_size <= 0: raise ValueError("kernel_size must be a positive integer.") if num_channels <= 0: raise ValueError("num_channels must be a positive integer.") # check if fs is a positive integer if fs is None: raise ValueError("sampling rate must be set.") if not isinstance(fs, int) or fs <= 0: raise ValueError("fs must be a positive integer.") if L is None: L = fs if not isinstance(L, int) or L <= 0: raise ValueError("L must be a positive integer.") if supp_mult < 0: raise ValueError("supp_mult must be a non-negative float.") if scale not in ["mel", "erb", "log10", "elelog"]: raise ValueError("scale must be one of 'mel', 'erb', 'log10', or 'elelog'.") if fc_max is not None and (fc_max <= 0 or fc_max >= fs // 2): raise ValueError("fc_max must be a positive integer less than fs/2.") #################################################################################################### # Bandwidth conversion #################################################################################################### probeLs = 10000 probeLg = 1000 g_probe = firwin(probeLg, probeLs) # peak normalize gf_probe = torch.real( torch.fft.fft(g_probe) / torch.max(torch.abs(torch.fft.fft(g_probe))) ) bw_probe = torch.norm(gf_probe) ** 2 * probeLg / probeLs / 2 # preset bandwidth factors to get a good condition number if scale == "erb": bw_factor = 0.608 elif scale == "mel": bw_factor = 111.33 elif scale == "log10": bw_factor = 0.2 # elif scale == "bark": # bw_factor = 0.5 elif scale == "elelog": bw_factor = 1 bw_conversion = bw_probe / bw_factor # * num_channels / 40 #################################################################################################### # Center frequencies #################################################################################################### # checking the maximum kernel size if scale == "elelog": cycles = 10 kernel_max = fs // 10 * cycles # capture frequencies of 10Hz for 10 cycles if kernel_size is None: kernel_size = kernel_max fc_min = 10 if fc_max is None: fc_max = fs // 2 kernel_min = int(fs / fc_max * cycles) else: fsupp_min = fctobw(0, scale) # if not specified, set the kernel size equal to the sampling frequency fs if kernel_size is None: kernel_size = int( torch.minimum( torch.round(bw_conversion / fsupp_min * fs), torch.tensor(fs) ) ) # get the bandwidth for the kernel size and the associated center frequency fsupp_low = bw_conversion / kernel_size * fs fc_min = bwtofc(fsupp_low, scale) if fc_max is None: fc_max = fs // 2 # get the bandwidth for the maximum center frequency and the associated kernel size fsupp_high = fctobw(fc_max, scale) kernel_min = int(torch.round(bw_conversion / fsupp_high * fs)) if fc_min >= fc_max: fc_max = fc_min kernel_min = kernel_size Warning( f"fc_max was increased to {fc_min} to enable the kernel size of {kernel_size}." ) # get center frequencies [fc, _] = audspace_mod(fc_min, fc_max, fs, num_channels, scale) num_low = torch.where(fc < fc_min)[0].shape[0] num_high = torch.where(fc > fc_max)[0].shape[0] num_aud = num_channels - num_low - num_high #################################################################################################### # Frequency and time supports #################################################################################################### # get time supports tsupp_low = (torch.ones(num_low) * kernel_size).int() tsupp_high = (torch.ones(num_high) * kernel_min).int() if scale == "elelog": tsupp_aud = ( torch.minimum( torch.tensor(kernel_size), torch.round(fs / fc[num_low : num_low + num_aud] * cycles), ) ).int() tsupp = torch.concatenate([tsupp_low, tsupp_aud, tsupp_high]).int() else: if num_low + num_high == num_channels: fsupp = fctobw(fc_max, scale) tsupp = tsupp_low else: fsupp = fctobw(fc[num_low : num_low + num_aud], scale) tsupp_aud = torch.round(bw_conversion / fsupp * fs) tsupp = torch.concatenate([tsupp_low, tsupp_aud, tsupp_high]).int() if supp_mult < 1: tsupp = torch.max( torch.round(tsupp * supp_mult), torch.ones_like(tsupp) * 8 ).int() else: tsupp = torch.min( torch.round(tsupp * supp_mult), torch.ones_like(tsupp) * L ).int() kernel_min = tsupp.min() kernel_size = tsupp.max() # Decimation factor (stride) for 50% overlap d = torch.maximum(kernel_min // 2, torch.tensor(1)) Ls = int(torch.ceil(L / d) * d) #################################################################################################### # Generate filters #################################################################################################### g = torch.zeros((num_channels, kernel_size), dtype=torch.cfloat) g[0, :] = torch.sqrt(d) * firwin(kernel_size) / torch.sqrt(torch.tensor(2)) g[-1, :] = ( torch.sqrt(d) * modulate(firwin(tsupp[-1], kernel_size), fs // 2, fs) / torch.sqrt(torch.tensor(2)) ) for m in range(1, num_channels - 1): g[m, :] = torch.sqrt(d) * modulate(firwin(tsupp[m], kernel_size), fc[m], fs) # _, B = frame_bounds(g, d, Ls) # g = g / B**0.5 return g, int(d), fc, fc_min, fc_max, kernel_min, kernel_size, Ls, tsupp
#################################################################################################### #################################################################################################### ####################################################################################################
[docs] def response(g: np.ndarray, fs: int) -> np.ndarray: """Compute frequency responses of filter kernels. Args: g (np.ndarray): Filter kernels of shape (num_channels, kernel_size) fs (int): Sampling frequency for frequency axis scaling Returns: np.ndarray: Magnitude-squared frequency responses of shape (2*num_channels, fs//2) Note: Computes responses for both analysis and conjugate filters. """ g_full = np.concatenate([g, np.conj(g)], axis=0) G = np.abs(np.fft.fft(g_full, fs, axis=1)[:, : fs // 2]) ** 2 return G
[docs] def plot_response( g: np.ndarray, fs: int, scale: str = "mel", plot_scale: bool = False, fc_min: Union[float, None] = None, fc_max: Union[float, None] = None, decoder: bool = False, ) -> None: """Plot frequency responses and auditory scale visualization of filters. Creates comprehensive visualization showing individual filter responses, total power spectral density, and optional auditory scale mapping. Args: g (np.ndarray): Filter kernels of shape (num_channels, kernel_size) fs (int): Sampling frequency in Hz for frequency axis scaling scale (str): Auditory scale name for scale plotting. Default: 'mel' plot_scale (bool): Whether to plot the auditory scale mapping. Default: False fc_min (float, optional): Lower transition frequency for scale visualization. Default: None fc_max (float, optional): Upper transition frequency for scale visualization. Default: None decoder (bool): Whether filters are for synthesis (affects plot titles). Default: False Note: This function displays plots and does not return values. Creates 2-3 subplots depending on plot_scale parameter. Example: >>> filters = np.random.randn(40, 128) >>> plot_response(filters, fs=16000, scale='mel', plot_scale=True) """ num_channels = g.shape[0] g_hat = response(g, fs) g_hat_pos = g_hat[:num_channels, :] g_hat_pos[np.isnan(g_hat_pos)] = 0 psd = np.sum(g_hat, axis=0) psd[np.isnan(psd)] = 0 if plot_scale: plt.figure(figsize=(8, 2)) freq_samples, _ = audspace_mod(fc_min, fc_max, fs, num_channels, scale) freqs = torch.linspace(0, fs // 2, fs // 2) auds = freqtoaud_mod(freqs, fc_min, fc_max, scale, fs).numpy() auds_orig = freqtoaud(freqs, scale, fs).numpy() plt.scatter( freq_samples.numpy(), freqtoaud_mod(freq_samples, fc_min, fc_max, scale, fs).numpy(), color="black", label="Center frequencies", linewidths=0.04, ) plt.plot(freqs, auds, color="black", label=f"ISAC {scale}-scale") plt.plot( freqs, auds_orig, color="black", linestyle="--", alpha=0.5, label=f"Original {scale}-scale", ) if fc_min is not None: plt.axvline(fc_min, color="black", alpha=0.25) plt.fill_betweenx( y=[auds[0] - 1, auds[-1] * 1.1], x1=0, x2=fc_min, color="gray", alpha=0.25, ) plt.fill_betweenx( y=[auds[0] - 1, auds[-1] * 1.1], x1=fc_min, x2=fs // 2, color="gray", alpha=0.1, ) if fc_max is not None: plt.axvline(fc_max, color="black", alpha=0.25) plt.fill_betweenx( y=[auds[0] - 1, auds[-1] * 1.1], x1=0, x2=fc_max, color="gray", alpha=0.25, ) plt.fill_betweenx( y=[auds[0] - 1, auds[-1] * 1.1], x1=fc_max, x2=fs // 2, color="gray", alpha=0.1, ) plt.xlim([0, fs // 2]) plt.ylim([auds[0] - 1, auds[-1] * 1.1]) plt.xlabel("Frequency (Hz)") # text_x = fc_min / 2 # text_y = auds[-1] # plt.text(text_x, text_y, 'linear', color='black', ha='center', va='center', fontsize=12, alpha=0.75) # plt.text(text_x + fc_min - 1, text_y, 'ERB', color='black', ha='center', va='center', fontsize=12, alpha=0.75) # plt.title(f"ISAC {scale}-scale") plt.ylabel("Auditory Units") plt.legend(loc="lower right") plt.tight_layout() plt.show() fig, ax = plt.subplots(2, 1, figsize=(6, 3), sharex=True) fr_id = 0 psd_id = 1 f_range = np.linspace(0, fs // 2, fs // 2) ax[fr_id].set_xlim([0, fs // 2]) ax[fr_id].set_ylim([0, np.max(g_hat_pos) * 1.1]) ax[fr_id].plot(f_range, g_hat_pos.T) if decoder: ax[fr_id].set_title("PSDs of the synthesis filters") if not decoder: ax[fr_id].set_title("PSDs of the analysis filters") # ax[fr_id].set_xlabel('Frequency [Hz]') ax[fr_id].set_ylabel("Magnitude") ax[psd_id].plot(f_range, psd) ax[psd_id].set_xlim([0, fs // 2]) ax[psd_id].set_ylim([0, np.max(psd) * 1.1]) ax[psd_id].set_title("Total PSD") ax[psd_id].set_xlabel("Frequency [Hz]") ax[psd_id].set_ylabel("Magnitude") if fc_min is not None: ax[fr_id].fill_betweenx( y=[0, np.max(g_hat) * 1.1], x1=0, x2=fc_min, color="gray", alpha=0.25 ) ax[fr_id].fill_betweenx( y=[0, np.max(g_hat) * 1.1], x1=fc_min, x2=fs // 2, color="gray", alpha=0.1 ) ax[psd_id].fill_betweenx( y=[0, np.max(psd) * 1.1], x1=0, x2=fc_min, color="gray", alpha=0.25 ) ax[psd_id].fill_betweenx( y=[0, np.max(psd) * 1.1], x1=fc_min, x2=fs // 2, color="gray", alpha=0.1 ) if fc_max is not None: ax[fr_id].fill_betweenx( y=[0, np.max(g_hat) * 1.1], x1=0, x2=fc_max, color="gray", alpha=0.25 ) ax[fr_id].fill_betweenx( y=[0, np.max(g_hat) * 1.1], x1=fc_max, x2=fs // 2, color="gray", alpha=0.1 ) ax[psd_id].fill_betweenx( y=[0, np.max(psd) * 1.1], x1=0, x2=fc_max, color="gray", alpha=0.25 ) ax[psd_id].fill_betweenx( y=[0, np.max(psd) * 1.1], x1=fc_max, x2=fs // 2, color="gray", alpha=0.1 ) plt.tight_layout() plt.show()
[docs] def ISACgram( c: torch.Tensor, fc: Union[torch.Tensor, None] = None, L: Union[int, None] = None, fs: Union[int, None] = None, fmax: Union[float, None] = None, log_scale: bool = False, vmin: Union[float, None] = None, cmap: str = "inferno", ) -> None: """Plot time-frequency representation of filterbank coefficients. Creates a spectrogram-like visualization with frequency on y-axis and time on x-axis. Supports logarithmic scaling and frequency range limitation for better visualization. Args: c (torch.Tensor): Filterbank coefficients of shape (batch_size, num_channels, num_frames) fc (torch.Tensor, optional): Center frequencies in Hz for y-axis labeling. Default: None L (int, optional): Original signal length for time axis scaling. Default: None fs (int, optional): Sampling frequency for time axis scaling. Default: None fmax (float, optional): Maximum frequency to display in Hz. Default: None log_scale (bool): Whether to apply log10 scaling to coefficients. Default: False vmin (float, optional): Minimum value for dynamic range clipping. Default: None cmap (str): Matplotlib colormap name. Default: 'inferno' Note: This function displays a plot and does not return values. Only processes the first batch element if batch_size > 1. Example: >>> coeffs = torch.randn(1, 40, 250) >>> fc = torch.linspace(100, 8000, 40) >>> ISACgram(coeffs, fc=fc, L=16000, fs=16000, log_scale=True) """ plt.figure(figsize=(10, 6)) ax = plt.gca() c = c[0].detach().cpu().numpy() if log_scale: c = np.log10(np.abs(c) + 1e-10) if fc is not None and fmax is not None: c = c[: np.argmax(fc > fmax), :] if vmin is not None: mesh = ax.pcolor(c, cmap=cmap, vmin=np.min(c) * vmin) else: mesh = ax.pcolor(c, cmap=cmap) # Add colorbar plt.colorbar(mesh, ax=ax) # Axis labeling if fc is not None: locs = np.linspace(0, c.shape[0] - 1, min(len(fc), 10)).astype(int) ax.set_yticks(locs) ax.set_yticklabels([int(np.round(fc[i])) for i in locs]) # X-axis: time num_time_labels = 10 xticks = np.linspace(0, c.shape[1] - 1, num_time_labels) ax.set_xticks(xticks) ax.set_xticklabels( [np.round(x, 1) for x in np.linspace(0, L // fs, num_time_labels)] ) ax.set_ylabel("Frequency [Hz]") ax.set_xlabel("Time [s]") else: ax.set_ylabel("Frequency index") ax.set_xlabel("Time samples") plt.tight_layout() # plt.savefig('/Users/dani/Library/Mobile Documents/com~apple~CloudDocs/Documents/PhD/ELECOM/IBAC/rumble_avg.png', dpi=600) plt.show()