import warnings
from typing import Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
####################################################################################################
##################### Cool routines to study decimated filterbanks #################################
####################################################################################################
def frame_bounds(
w: torch.Tensor, d: int, Ls: Union[int, None] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute frame bounds of a filterbank using polyphase representation.
Frame bounds characterize the numerical stability and invertibility of the
filterbank transform. Tight frames (A ≈ B) provide optimal stability.
Args:
w (torch.Tensor): Impulse responses of shape (num_channels, length)
d (int): Decimation (stride) factor
Ls (int, optional): Signal length. If None, computed automatically. Default: None
Returns:
Tuple[torch.Tensor, torch.Tensor]: Lower and upper frame bounds (A, B)
Note:
For d=1, reduces to computing min/max of power spectral density.
For d>1, uses polyphase analysis to compute worst-case eigenvalues.
Example:
>>> w = torch.randn(40, 128)
>>> A, B = frame_bounds(w, d=4)
>>> condition_number = B / A
"""
if Ls is None:
Ls = int(torch.ceil(torch.tensor(w.shape[-1] * 2 / d)) * d)
w_full = torch.cat([w, torch.conj(w)], dim=0)
w_hat = torch.fft.fft(w_full, Ls, dim=-1).T
if d == 1:
psd = torch.sum(w_hat.abs() ** 2, dim=-1)
A = torch.min(psd)
B = torch.max(psd)
return A, B
else:
N = w_hat.shape[0]
M = w_hat.shape[1]
assert N % d == 0, "Oh no! Decimation factor must divide signal length!"
if w_hat.device.type == "mps":
temp_device = torch.device("cpu")
else:
temp_device = w_hat.device
w_hat_cpu = w_hat.to(temp_device)
A = torch.tensor([torch.inf]).to(temp_device)
B = torch.tensor([0.0]).to(temp_device)
for j in range(N // d):
idx_a = (j - torch.arange(d) * (N // d)) % N
idx_b = (torch.arange(d) * (N // d) - j) % N
Ha = w_hat_cpu[idx_a, :]
Hb = torch.conj(w_hat_cpu[idx_b, :])
lam = torch.linalg.eigvalsh(Ha @ Ha.H + Hb @ Hb.H).real
A = torch.min(A, torch.min(lam))
B = torch.max(B, torch.max(lam))
return (A / d).to(w_hat.device), (B / d).to(w_hat.device)
def condition_number(
w: torch.Tensor, d: int, Ls: Union[int, None] = None
) -> torch.Tensor:
"""Compute condition number of a filterbank frame operator.
The condition number κ = B/A quantifies numerical stability, where A and B
are the lower and upper frame bounds. Lower values indicate better stability.
Args:
w (torch.Tensor): Impulse responses of shape (num_channels, signal_length)
d (int): Decimation factor (stride)
Ls (int, optional): Signal length. If None, computed automatically. Default: None
Returns:
torch.Tensor: Condition number κ = B/A
Note:
κ = 1 indicates a tight frame (optimal stability).
κ >> 1 suggests potential numerical instability.
Example:
>>> w = torch.randn(40, 128)
>>> kappa = condition_number(w, d=4)
>>> print(f"Condition number: {kappa.item():.2f}")
"""
A, B = frame_bounds(w, d, Ls)
A = torch.max(
A, torch.tensor(1e-6, dtype=A.dtype, device=A.device)
) # Avoid division by zero
return B / A
def frequency_correlation(
w: torch.Tensor, d: int, Ls: Union[int, None] = None, diag_only: bool = False
) -> torch.Tensor:
"""
Computes the frequency correlation functions (vectorized version).
Parameters:
w: (J, K) - Impulse responses
d: Decimation factor
Ls: FFT length (default: smallest multiple of d that is >= 2K-1)
diag_only: If True, only return the diagonal (i.e., the PSD)
Returns:
G: (d, Ls) complex tensor with frequency correlations, or the real (Ls,) PSD if diag_only=True
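Example (illustrative; uses a random filterbank):
>>> w = torch.randn(40, 128)
>>> G = frequency_correlation(w, d=4)
>>> G.shape
torch.Size([4, 256])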
"""
K = w.shape[-1]
if Ls is None:
Ls = int(torch.ceil(torch.tensor((2 * K - 1) / d)) * d)
w_full = torch.cat([w, torch.conj(w)], dim=0)
w_hat = torch.fft.fft(w_full, Ls, dim=-1) # shape: [J, Ls]
N = Ls
assert N % d == 0, "Decimation factor must divide FFT length"
# Diagonal: sum_j |w_hat_j|^2
diag = torch.sum(w_hat.abs() ** 2, dim=0) # shape: [Ls]
if diag_only:
return torch.real(diag)
G = [diag] # G[0] = diagonal
for j in range(1, d):
rolled = torch.roll(w_hat, shifts=j * (N // d), dims=-1)
val = torch.sum(w_hat * torch.conj(rolled), dim=0)
G.append(val)
G = torch.stack(G, dim=0) # shape: [d, Ls]
return G
def alias(
w: torch.Tensor, d: int, Ls: Union[int, None] = None, diag_only: bool = False
) -> torch.Tensor:
"""
Computes the norm of the aliasing terms.
Parameters:
w: Impulse responses of the filterbank as 2-D Tensor torch.tensor[num_channels, sig_length]
d: Decimation factor, must divide the FFT length!
Ls: FFT length (optional; see frequency_correlation)
diag_only: If True, return the ratio max(PSD)/min(PSD) instead of the aliasing energy
Output:
A: Energy of the aliasing terms
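Example (illustrative; random filters, so the value varies):
>>> w = torch.randn(40, 128)
>>> a = alias(w, d=4)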
"""
G = frequency_correlation(w=w, d=d, Ls=Ls, diag_only=diag_only)
if diag_only:
return torch.max(G).div(torch.min(G))
else:
# return torch.max(torch.real(G[0,:])).div(torch.min(torch.real(G[0,:]))) + torch.sum(torch.norm(G[1::,:], p=2, dim=-1), dim=-1) - 1
return torch.norm(
torch.real(G[0, :]) - torch.ones_like(G[0, :]), p=2
) + torch.sum(torch.norm(G[1::, :], p=2, dim=-1), dim=-1)
def can_tight(w: torch.Tensor, d: int, Ls: int) -> torch.Tensor:
"""
Computes the canonical tight filterbank of w (time domain) using the polyphase representation.
Parameters:
w: Impulse responses of the filterbank as 2-d Tensor torch.tensor[num_channels, signal_length]
d: Decimation factor, must divide signal_length!
Returns:
W: Canonical tight filterbank of w (torch.tensor[num_channels, signal_length])
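Example (illustrative sketch with random filters):
>>> w = torch.randn(40, 128, dtype=torch.cfloat)
>>> w_tight = can_tight(w, d=4, Ls=256)
>>> w_tight.shape
torch.Size([40, 256])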
"""
w_hat = torch.fft.fft(w.T, Ls, dim=0)
if d == 1:
lp = torch.sum(w_hat.abs() ** 2, dim=1).reshape(-1, 1)
w_hat_tight = w_hat * (lp ** (-0.5))
return torch.fft.ifft(w_hat_tight.T, dim=1)
else:
N = w_hat.shape[0]
J = w_hat.shape[1]
assert N % d == 0, "Oh no! Decimation factor must divide signal length!"
w_hat_tight = torch.zeros(J, N, dtype=torch.complex64)
for j in range(N // d):
idx = (j - torch.arange(d) * (N // d)) % N
H = w_hat[idx, :]
U, _, V = torch.linalg.svd(H, full_matrices=False)
H = U @ V
w_hat_tight[:, idx] = H.T.to(torch.complex64)
# transform back to the time domain with sqrt(d) normalization (matches the d == 1 branch)
return torch.fft.ifft(w_hat_tight * d**0.5, dim=-1)
def fir_tightener3000(
w: torch.Tensor, supp: int, d: int, eps: float = 1.01, Ls: Union[int, None] = None
):
"""
Iterative tightening procedure with fixed support for a given filterbank w
Parameters:
w: Impulse responses of the filterbank as 2-D Tensor torch.tensor[num_channels, signal_length].
supp: Desired support of the resulting filterbank
d: Decimation factor, must divide filter length!
eps: Desired precision for the condition number
Ls: System length (if not already given by w). If set, w is zero-padded to length Ls for the computation and the result is truncated to length supp.
Returns:
Filterbank with condition number at most *eps* and support length *supp*. If the filter length equals *supp*, the result is the canonical tight filterbank of w.
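Example (hedged sketch; convergence speed depends on w, supp, and eps):
>>> w = torch.randn(40, 128)
>>> w_tight = fir_tightener3000(w, supp=128, d=4, eps=1.01)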
"""
print("Hold on, the kernels are tightening")
if Ls is not None:
w = torch.cat([w, torch.zeros(w.shape[0], Ls - w.shape[1])], dim=1)
w_tight = w.clone()
kappa = condition_number(w, d).item()
while kappa > eps:
w_tight = can_tight(w_tight, d, w_tight.shape[-1])
w_tight[:, supp:] = 0
kappa = condition_number(w_tight, d).item()
if Ls is None:
return w_tight
else:
return w_tight[:, :supp]
def upsample(x: torch.Tensor, d: int) -> torch.Tensor:
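"""Upsample by inserting d - 1 zeros between consecutive samples along the last axis.
Args:
x (torch.Tensor): Input of shape (..., num_frames)
d (int): Upsampling factor
Returns:
torch.Tensor: Zero-stuffed tensor of shape (..., num_frames * d)
"""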
N = x.shape[-1] * d
x_up = torch.zeros(*x.shape[:-1], N, dtype=x.dtype, device=x.device)
x_up[..., ::d] = x
return x_up
def circ_conv(x: torch.Tensor, kernels: torch.Tensor, d: int = 1) -> torch.Tensor:
"""Circular convolution with optional downsampling.
Performs efficient circular convolution using FFT, followed by downsampling.
The kernels are automatically centered for proper phase alignment.
Args:
x (torch.Tensor): Input signal of shape (..., signal_length)
kernels (torch.Tensor): Filter kernels of shape (num_channels, 1, kernel_length)
or (num_channels, kernel_length)
d (int): Downsampling factor (stride). Default: 1
Returns:
torch.Tensor: Convolved and downsampled output of shape (..., num_channels, output_length)
Note:
Uses circular convolution which assumes periodic boundary conditions.
Kernels are automatically zero-padded and centered.
Example:
>>> x = torch.randn(1, 1000)
>>> kernels = torch.randn(40, 128)
>>> y = circ_conv(x, kernels, d=4)
"""
L = x.shape[-1]
x = x.to(kernels.dtype)
kernels_long = F.pad(kernels, (0, L - kernels.shape[-1]), mode="constant", value=0)
kernels_centered = torch.roll(kernels_long, shifts=-kernels.shape[-1] // 2, dims=-1)
x_fft = torch.fft.fft(x, n=L, dim=-1)
k_fft = torch.fft.fft(kernels_centered, n=L, dim=-1)
y_fft = x_fft * k_fft
y = torch.fft.ifft(y_fft)
return y[..., ::d]
def circ_conv_transpose(
y: torch.Tensor, kernels: torch.Tensor, d: int = 1
) -> torch.Tensor:
"""Transpose (adjoint) of circular convolution with upsampling.
Implements the adjoint operation of circ_conv for signal reconstruction.
Used in synthesis/decoder operations of filterbanks.
Args:
y (torch.Tensor): Input coefficients of shape (..., num_channels, num_frames)
kernels (torch.Tensor): Filter kernels of shape (num_channels, 1, kernel_length)
or (num_channels, kernel_length)
d (int): Upsampling factor (stride). Default: 1
Returns:
torch.Tensor: Reconstructed signal of shape (..., 1, signal_length)
Note:
This is the mathematical adjoint, not the true inverse. For perfect reconstruction,
appropriate dual frame filters should be used.
Example:
>>> coeffs = torch.randn(1, 40, 250)
>>> kernels = torch.randn(40, 128)
>>> x_recon = circ_conv_transpose(coeffs, kernels, d=4)
"""
L = y.shape[-1] * d
y_up = upsample(y, d)
kernels_long = F.pad(kernels, (0, L - kernels.shape[-1]), mode="constant", value=0)
kernels_centered = torch.roll(kernels_long, shifts=-kernels.shape[-1] // 2, dims=-1)
kernels_synth = torch.flip(torch.conj(kernels_centered), dims=(-1,))
y_fft = torch.fft.fft(y_up, n=L, dim=-1)
k_fft = torch.fft.fft(kernels_synth, n=L, dim=-1)
x_fft = y_fft * k_fft
x = torch.fft.ifft(x_fft, dim=-1)
x = torch.sum(x, dim=-2, keepdim=True)
return torch.roll(x, 1, -1)
####################################################################################################
################### Routines for constructing auditory filterbanks #################################
####################################################################################################
def freqtoaud(
freq: Union[float, int, torch.Tensor],
scale: str = "erb",
fs: Union[int, None] = None,
) -> torch.Tensor:
"""Convert frequencies from Hz to auditory scale units.
Transforms linear frequency values to perceptually-motivated auditory scales
that better reflect human frequency discrimination.
Args:
freq (Union[float, int, torch.Tensor]): Frequency value(s) in Hz
scale (str): Auditory scale type. One of {'erb', 'mel', 'bark', 'log10', 'elelog'}. Default: 'erb'
fs (int, optional): Sampling frequency (required for 'elelog' scale). Default: None
Returns:
torch.Tensor: Corresponding auditory scale units
Raises:
ValueError: If unsupported scale is specified or fs is missing for 'elelog'
Note:
- ERB: Equivalent Rectangular Bandwidth (Glasberg & Moore)
- MEL: Mel scale (perceptually uniform pitch)
- Bark: Bark scale (critical band rate)
- elelog: Logarithmic scale adapted for elephant hearing
Example:
>>> freq_hz = torch.tensor([100, 1000, 8000])
>>> mel_units = freqtoaud(freq_hz, scale='mel')
"""
scale = scale.lower()
if isinstance(freq, (int, float)):
freq = torch.tensor(freq)
if scale == "erb":
# Glasberg and Moore's ERB scale
return 9.2645 * torch.sign(freq) * torch.log(1 + torch.abs(freq) * 0.00437)
elif scale == "mel":
# MEL scale
return (
1000
/ torch.log(torch.tensor(17 / 7))
* torch.sign(freq)
* torch.log(1 + torch.abs(freq) / 700)
)
elif scale == "bark":
# Bark scale from Traunmuller (1990)
return torch.sign(freq) * ((26.81 / (1 + 1960 / torch.abs(freq))) - 0.53)
elif scale == "log10":
# Logarithmic scale
return torch.log10(torch.maximum(torch.ones(1), freq))
elif scale == "elelog":
if fs is None:
raise ValueError(
"Sampling frequency fs must be provided for 'elelog' scale."
)
fmin = 1
fmax = fs // 2
k = 0.88
A = fmin / (1 - k)
alpha = np.log10(fmax / A + k)
return torch.log((freq / A + k) / alpha)  # - torch.log((fmin / A + k) / alpha)
else:
raise ValueError(
f"Unsupported scale: '{scale}'. Available options are: 'mel', 'erb', 'bark', 'log10', 'elelog'."
)
def audtofreq(
aud: Union[float, int, torch.Tensor],
scale: str = "erb",
fs: Union[int, None] = None,
) -> torch.Tensor:
"""Convert auditory scale units back to frequencies in Hz.
Args:
aud (Union[float, int, torch.Tensor]): Auditory scale values
scale (str): Auditory scale type. One of {'erb', 'mel', 'bark', 'log10', 'elelog'}. Default: 'erb'
fs (int, optional): Sampling frequency (required for 'elelog' scale). Default: None
Returns:
torch.Tensor: Corresponding frequencies in Hz
Example:
>>> mel_units = torch.tensor([100, 1000, 2000])
>>> freq_hz = audtofreq(mel_units, scale='mel')
"""
scale = scale.lower()
if isinstance(aud, (int, float)):
aud = torch.tensor(aud)
if scale == "erb":
return (1 / 0.00437) * (torch.exp(aud / 9.2645) - 1)
elif scale == "mel":
return (
700
* torch.sign(aud)
* (torch.exp(torch.abs(aud) * torch.log(torch.tensor(17 / 7)) / 1000) - 1)
)
elif scale == "bark":
return torch.sign(aud) * 1960 / (26.81 / (torch.abs(aud) + 0.53) - 1)
elif scale == "log10":
return 10**aud
elif scale == "elelog":
if fs is None:
raise ValueError(
"Sampling frequency fs must be provided for 'elelog' scale."
)
fmin = 1
fmax = fs // 2
k = 0.88
A = fmin / (1 - k)
alpha = np.log10(fmax / A + k)
return A * (torch.exp(aud) * alpha - k)
else:
raise ValueError(
f"Unsupported scale: '{scale}'. Available options are: 'mel', 'erb', 'bark', 'log10', 'elelog'."
)
def audspace(
fmin: Union[float, int, torch.Tensor],
fmax: Union[float, int, torch.Tensor],
num_channels: int,
scale: str = "erb",
):
"""
Computes a vector of values equidistantly spaced on the selected auditory scale.
Parameters:
fmin (float): Minimum frequency in Hz.
fmax (float): Maximum frequency in Hz.
num_channels (int): Number of points in the output vector.
scale (str): Auditory scale (default is 'erb').
Returns:
y (torch.Tensor): Frequencies equidistantly spaced on the auditory scale.
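Example:
>>> fc = audspace(100, 8000, 40, scale='mel')
>>> fc.shape
torch.Size([40])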
"""
if num_channels <= 0:
raise ValueError("num_channels must be a positive integer.")
if fmin > fmax:
raise ValueError("fmin must be less than or equal to fmax.")
# Convert [fmin, fmax] to auditory scale
if scale == "log10" or scale == "elelog":
fmin = max(float(fmin), 1.0)
audlimits = freqtoaud(torch.tensor([fmin, fmax]), scale)
# Generate frequencies spaced evenly on the auditory scale
aud_space = torch.linspace(audlimits[0], audlimits[1], num_channels)
y = audtofreq(aud_space, scale)
# Ensure exact endpoints
y[0] = fmin
y[-1] = fmax
return y
def freqtoaud_mod(
freq: Union[float, int, torch.Tensor],
fc_low: Union[float, int, torch.Tensor],
fc_high: Union[float, int, torch.Tensor],
scale="erb",
fs=None,
):
"""
Modified auditory scale function, linear below fc_low and above fc_high.
Parameters:
freq (torch.Tensor): Frequency values in Hz.
fc_low (float): Lower transition frequency in Hz.
fc_high (float): Upper transition frequency in Hz.
scale (str): Auditory scale (default is 'erb').
fs (int, optional): Sampling frequency, required for 'elelog'.
Returns:
torch.Tensor: Values on the modified auditory scale.
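Example (illustrative):
>>> freqs = torch.linspace(0, 8000, 100)
>>> aud = freqtoaud_mod(freqs, fc_low=100.0, fc_high=6000.0)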
"""
aud_crit_low = freqtoaud(fc_low, scale, fs)
aud_crit_high = freqtoaud(fc_high, scale, fs)
slope_low = (freqtoaud(fc_low * 1.01, scale, fs) - aud_crit_low) / (fc_low * 0.01)
slope_high = (freqtoaud(fc_high * 1.01, scale, fs) - aud_crit_high) / (
fc_high * 0.01
)
linear_low = freq < fc_low
linear_high = freq > fc_high
auditory = ~(linear_low | linear_high)
aud = torch.zeros_like(freq, dtype=torch.float32)
aud[linear_low] = slope_low * (freq[linear_low] - fc_low) + aud_crit_low
aud[auditory] = freqtoaud(freq[auditory], scale, fs)
aud[linear_high] = slope_high * (freq[linear_high] - fc_high) + aud_crit_high
return aud
def audtofreq_mod(
aud: Union[float, int, torch.Tensor],
fc_low: Union[float, int, torch.Tensor],
fc_high: Union[float, int, torch.Tensor],
scale="erb",
fs=None,
):
"""
Inverse of freqtoaud_mod, mapping modified auditory scale values back to frequency.
Parameters:
aud (torch.Tensor): Auditory scale values.
fc_low (float): Lower transition frequency in Hz.
fc_high (float): Upper transition frequency in Hz.
scale (str): Auditory scale (default is 'erb').
fs (int, optional): Sampling frequency, required for 'elelog'.
Returns:
torch.Tensor: Frequency values in Hz.
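Example (illustrative; inverts freqtoaud_mod on its range):
>>> aud = torch.linspace(0.0, 30.0, 50)
>>> freqs = audtofreq_mod(aud, fc_low=100.0, fc_high=6000.0)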
"""
aud_crit_low = freqtoaud(fc_low, scale, fs)
aud_crit_high = freqtoaud(fc_high, scale, fs)
slope_low = (freqtoaud(fc_low * 1.01, scale, fs) - aud_crit_low) / (fc_low * 0.01)
slope_high = (freqtoaud(fc_high * 1.01, scale, fs) - aud_crit_high) / (
fc_high * 0.01
)
linear_low = aud < aud_crit_low
linear_high = aud > aud_crit_high
auditory_part = ~(linear_low | linear_high)
freq = torch.zeros_like(aud, dtype=torch.float32)
freq[linear_low] = (aud[linear_low] - aud_crit_low) / slope_low + fc_low
freq[auditory_part] = audtofreq(aud[auditory_part], scale, fs)
freq[linear_high] = (aud[linear_high] - aud_crit_high) / slope_high + fc_high
return freq
def audspace_mod(
fc_low: Union[float, int, torch.Tensor],
fc_high: Union[float, int, torch.Tensor],
fs: int,
num_channels: int,
scale: str = "erb",
):
"""Generate M frequency samples that are equidistant in the modified auditory scale.
Parameters:
fc_low (float): Lower transition frequency in Hz.
fc_high (float): Upper transition frequency in Hz.
fs (int): Sampling rate in Hz.
num_channels (int): Number of filters/channels (M).
scale (str): Auditory scale (default is 'erb').
Returns:
tuple: Center frequencies in Hz and their values on the modified auditory scale.
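Example:
>>> fc, fc_aud = audspace_mod(100.0, 6000.0, fs=16000, num_channels=40)
>>> fc.shape
torch.Size([40])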
"""
if fc_low > fc_high:
raise ValueError("fc_low must be less than fc_high.")
elif fc_low == fc_high:
# equidistant samples from 0 to fs/2
fc = torch.linspace(0, fs // 2, num_channels)
return fc, freqtoaud_mod(fc, fs // 2, fs // 2, scale, fs)
elif fc_low < fc_high:
# Convert [0, fs//2] to modified auditory scale
aud_min = freqtoaud_mod(torch.tensor([0]), fc_low, fc_high, scale, fs)[0]
aud_max = freqtoaud_mod(torch.tensor([fs // 2]), fc_low, fc_high, scale, fs)[0]
# Generate frequencies spaced evenly on the modified auditory scale
fc_aud = torch.linspace(aud_min, aud_max, num_channels)
# Convert back to frequency scale
fc = audtofreq_mod(fc_aud, fc_low, fc_high, scale, fs)
# Ensure exact endpoints
fc[0] = 0
fc[-1] = fs // 2
return fc, fc_aud
else:
raise ValueError("There is something wrong with fc_low and fc_high.")
def fctobw(fc: Union[float, int, torch.Tensor], scale="erb"):
"""
Computes the critical bandwidth of a filter at a given center frequency.
Parameters:
fc (float or torch.Tensor): Center frequency in Hz. Must be non-negative.
scale (str): Auditory scale. Supported values are:
- 'erb': Equivalent Rectangular Bandwidth (default)
- 'bark': Bark scale
- 'mel': Mel scale
- 'log10': Logarithmic scale
Returns:
torch.Tensor: Critical bandwidth at each center frequency.
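Example:
>>> bw = fctobw(torch.tensor([100.0, 1000.0]), scale='erb')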
"""
if isinstance(fc, (list, tuple, int, float)):
fc = torch.tensor(fc)
if not (isinstance(fc, (float, int, torch.Tensor)) and torch.all(fc >= 0)):
raise ValueError("fc must be a non-negative scalar or array.")
# Compute bandwidth based on the auditory scale
if scale == "erb":
bw = 24.7 + fc / 9.265
elif scale == "bark":
bw = 25 + 75 * (1 + 1.4e-6 * fc**2) ** 0.69
elif scale == "mel":
bw = torch.log10(torch.tensor(17 / 7)) * (700 + fc) / 1000
elif scale == "log10":
bw = fc
else:
raise ValueError(f"Unsupported auditory scale: {scale}")
return bw
def bwtofc(bw: Union[float, int, torch.Tensor], scale="erb"):
"""
Computes the center frequency corresponding to a given critical bandwidth.
Parameters:
bw (float or torch.Tensor): Critical bandwidth. Must be non-negative.
scale (str): Auditory scale. Supported values are:
- 'erb': Equivalent Rectangular Bandwidth
- 'bark': Bark scale
- 'mel': Mel scale
- 'log10': Logarithmic scale
Returns:
torch.Tensor: Center frequency corresponding to the given bandwidth.
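Example (roughly inverts fctobw at 1000 Hz on the ERB scale):
>>> fc = bwtofc(torch.tensor(132.63), scale='erb')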
"""
if isinstance(bw, (list, tuple, int, float)):
bw = torch.tensor(bw)
if not (isinstance(bw, (float, int, torch.Tensor)) and torch.all(bw >= 0)):
raise ValueError("bw must be a non-negative scalar or array.")
# Compute center frequency based on the auditory scale
if scale == "erb":
fc = (bw - 24.7) * 9.265
elif scale == "bark":
fc = torch.sqrt(((bw - 25) / 75) ** (1 / 0.69) / 1.4e-6)
elif scale == "mel":
fc = 1000 * (bw / torch.log10(torch.tensor(17 / 7))) - 700
elif scale == "log10":
fc = bw
else:
raise ValueError(f"Unsupported auditory scale: {scale}")
return fc
def firwin(kernel_size: int, padto: Union[int, None] = None):
"""
FIR window generation (symmetric Hann window, normalized to unit l1 norm).
Parameters:
kernel_size (int): Length of the window.
padto (int, optional): Length to which the window is zero-padded and centered.
Returns:
g (torch.Tensor): FIR window.
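Example:
>>> g = firwin(128, padto=256)  # Hann window of length 128, centered in 256 samples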
"""
g = torch.hann_window(kernel_size, periodic=False)
g /= torch.sum(torch.abs(g))
if padto is None or padto == kernel_size:
return g
elif padto > kernel_size:
g_padded = torch.concatenate([g, torch.zeros(padto - len(g))])
g_centered = torch.roll(g_padded, int((padto - len(g)) // 2))
return g_centered
else:
raise ValueError("padto must not be smaller than kernel_size.")
def modulate(g: torch.Tensor, fc: Union[float, int, torch.Tensor], fs: int):
"""Modulate a filter to a given center frequency.
Args:
g (torch.Tensor): Filter impulse response.
fc (float): Center frequency in Hz.
fs (int): Sampling rate in Hz.
Returns:
g_mod (torch.Tensor): Modulated (complex) filter.
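Example:
>>> g = firwin(128)
>>> g_mod = modulate(g, 1000.0, 16000)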
"""
Lg = len(g)
g_mod = g * torch.exp(2 * torch.pi * 1j * fc * torch.arange(Lg) / fs)
return g_mod
####################################################################################################
########################################### ISAC ###################################################
####################################################################################################
def audfilters(
kernel_size: Union[int, None] = None,
num_channels: int = 96,
fc_max: Union[float, int, None] = None,
fs: int = None,
L: int = None,
supp_mult: float = 1,
scale: str = "mel",
) -> Tuple[
torch.Tensor,
int,
torch.Tensor,
Union[int, float],
Union[int, float],
int,
int,
int,
torch.Tensor,
]:
"""Generate auditory-inspired FIR filterbank kernels.
Creates a bank of bandpass filters with center frequencies distributed according
to perceptual auditory scales (mel, erb, bark, etc.). Filters are designed with
variable bandwidths matching critical bands of human auditory perception.
Args:
kernel_size (int, optional): Maximum filter kernel size. If None, computed automatically. Default: None
num_channels (int): Number of frequency channels. Default: 96
fc_max (float, optional): Maximum center frequency in Hz. If None, uses fs//2. Default: None
fs (int): Sampling frequency in Hz. Default: None (required)
L (int): Signal length in samples. If None, uses fs. Default: None
supp_mult (float): Support multiplier for kernel sizing. Default: 1.0
scale (str): Auditory scale. One of {'mel', 'erb', 'bark', 'log10', 'elelog'}. Default: 'mel'
Returns:
Tuple containing:
- kernels (torch.Tensor): Filter kernels of shape (num_channels, kernel_size)
- d (int): Recommended stride (hop size), one quarter of the shortest kernel
- fc (torch.Tensor): Center frequencies in Hz
- fc_min (Union[int, float]): Minimum center frequency
- fc_max (Union[int, float]): Maximum center frequency
- kernel_min (int): Minimum kernel size used
- kernel_size (int): Maximum kernel size used
- Ls (int): Adjusted signal length
- tsupp (torch.Tensor): Time support for each filter
Raises:
ValueError: If parameters are invalid (negative values, unsupported scale, etc.)
Note:
The filterbank construction follows auditory modeling principles where:
- Low frequencies use longer filters (better frequency resolution)
- High frequencies use shorter filters (better time resolution)
- Bandwidth scales according to critical band theory
Example:
>>> kernels, stride, fc, _, _, _, _, Ls, _ = audfilters(
... kernel_size=128, num_channels=40, fs=16000, scale='mel'
... )
>>> print(f"Generated {kernels.shape[0]} filters with stride {stride}")
"""
# check if all inputs are valid
if kernel_size is not None and kernel_size <= 0:
raise ValueError("kernel_size must be a positive integer.")
if num_channels <= 0:
raise ValueError("num_channels must be a positive integer.")
# check if fs is a positive integer
if fs is None:
raise ValueError("sampling rate must be set.")
if not isinstance(fs, int) or fs <= 0:
raise ValueError("fs must be a positive integer.")
if L is None:
L = fs
if not isinstance(L, int) or L <= 0:
raise ValueError("L must be a positive integer.")
if supp_mult < 0:
raise ValueError("supp_mult must be a non-negative float.")
if scale not in ["mel", "erb", "bark", "log10", "elelog"]:
raise ValueError(
"scale must be one of 'mel', 'erb', 'bark', 'log10', or 'elelog'."
)
if fc_max is not None and (fc_max <= 0 or fc_max >= fs // 2):
raise ValueError("fc_max must be a positive number strictly less than fs/2.")
####################################################################################################
# Bandwidth conversion
####################################################################################################
probeLs = 10000
probeLg = 1000
g_probe = firwin(probeLg, probeLs)
# peak normalize
gf_probe = torch.real(
torch.fft.fft(g_probe) / torch.max(torch.abs(torch.fft.fft(g_probe)))
)
bw_probe = torch.norm(gf_probe) ** 2 * probeLg / probeLs / 2
# preset bandwidth factors to get a good condition number
if scale == "erb":
bw_factor = 0.608
elif scale == "mel":
bw_factor = 111.33
elif scale == "log10":
bw_factor = 0.2
elif scale == "bark":
bw_factor = 1.0
elif scale == "elelog":
bw_factor = 1
bw_conversion = bw_probe / bw_factor # * num_channels / 40
####################################################################################################
# Center frequencies
####################################################################################################
# checking the maximum kernel size
if scale == "elelog":
cycles = 10
kernel_max = fs // 10 * cycles # capture frequencies of 10Hz for 10 cycles
if kernel_size is None:
kernel_size = kernel_max
fc_min = 10
if fc_max is None:
fc_max = fs // 2
kernel_min = int(fs / fc_max * cycles)
else:
fsupp_min = fctobw(0, scale)
kernel_max = int(
torch.minimum(torch.round(bw_conversion / fsupp_min * fs), torch.tensor(L))
)
if kernel_size is None:
kernel_size = kernel_max
if kernel_size > kernel_max:
bw_factor = bw_probe / kernel_size / fsupp_min * fs
bw_conversion = bw_probe / bw_factor # * num_channels / 40
fsupp_min = fctobw(0, scale)
kernel_max = int(
torch.minimum(
torch.round(bw_conversion / fsupp_min * fs), torch.tensor(L)
)
)
# get the bandwidth for the maximum kernel size and the associated center frequency
fsupp_low = bw_conversion / kernel_size * fs
fc_min = bwtofc(fsupp_low, scale)
if fc_max is None:
fc_max = fs // 2
# get the bandwidth for the maximum center frequency and the associated kernel size
fsupp_high = fctobw(fc_max, scale)
kernel_min = int(torch.round(bw_conversion / fsupp_high * fs))
if fc_min >= fc_max:
fc_max = fc_min
kernel_min = kernel_size
warnings.warn(
f"fc_max was increased to {fc_min} to enable the kernel size of {kernel_size}."
)
# get center frequencies
[fc, _] = audspace_mod(fc_min, fc_max, fs, num_channels, scale)
num_low = torch.where(fc < fc_min)[0].shape[0]
num_high = torch.where(fc > fc_max)[0].shape[0]
num_aud = num_channels - num_low - num_high
####################################################################################################
# Frequency and time supports
####################################################################################################
# get time supports
tsupp_low = (torch.ones(num_low) * kernel_size).int()
tsupp_high = torch.ones(num_high) * kernel_min
if scale == "elelog":
tsupp_aud = (
torch.minimum(
torch.tensor(kernel_size),
torch.round(fs / fc[num_low : num_low + num_aud] * cycles),
)
).int()
tsupp = torch.concatenate([tsupp_low, tsupp_aud, tsupp_high]).int()
else:
if num_low + num_high == num_channels:
fsupp = fctobw(fc_max, scale)
tsupp = tsupp_low
else:
fsupp = fctobw(fc[num_low : num_low + num_aud], scale)
tsupp_aud = torch.round(bw_conversion / fsupp * fs)
tsupp = torch.concatenate([tsupp_low, tsupp_aud, tsupp_high]).int()
if supp_mult < 1:
tsupp = torch.max(
torch.round(tsupp * supp_mult), torch.ones_like(tsupp) * 8
).int()
else:
tsupp = torch.min(
torch.round(tsupp * supp_mult), torch.ones_like(tsupp) * L
).int()
kernel_min = int(tsupp.min())
kernel_size = int(tsupp.max())
# Decimation factor (stride) to get a nice frame and according signal length (lcm of d and Ls)
# d = torch.floor(torch.min(fs / fsupp))
d = max(kernel_min // 4, 1)
Ls = int(np.ceil(L / d) * d)
####################################################################################################
# Generate filters
####################################################################################################
g = torch.zeros((num_channels, kernel_size), dtype=torch.cfloat)
g[0, :] = d**0.5 * firwin(kernel_size) / 2**0.5
g[-1, :] = (
d**0.5
* modulate(firwin(int(tsupp[-1]), kernel_size), fs // 2, fs)
/ 2**0.5
)
for m in range(1, num_channels - 1):
g[m, :] = d**0.5 * modulate(firwin(int(tsupp[m]), kernel_size), fc[m], fs)
# _, B = frame_bounds(g, d, Ls)
# g = g / B**0.5
return g, int(d), fc, fc_min, fc_max, kernel_min, kernel_size, Ls, tsupp
####################################################################################################
####################################################################################################
####################################################################################################
def response(g: np.ndarray, fs: int) -> np.ndarray:
"""Compute frequency responses of filter kernels.
Args:
g (np.ndarray): Filter kernels of shape (num_channels, kernel_size)
fs (int): Sampling frequency for frequency axis scaling
Returns:
np.ndarray: Magnitude-squared frequency responses of shape (2*num_channels, fs//2)
Note:
Computes responses for both analysis and conjugate filters.
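Example (illustrative; random kernels):
>>> g = np.random.randn(40, 128)
>>> G = response(g, fs=16000)
>>> G.shape
(80, 8000)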
"""
g_full = np.concatenate([g, np.conj(g)], axis=0)
G = np.abs(np.fft.fft(g_full, fs, axis=1)[:, : fs // 2]) ** 2
return G
def plot_response(
g: np.ndarray,
fs: int,
scale: str = "mel",
plot_scale: bool = False,
fc_min: Union[float, None] = None,
fc_max: Union[float, None] = None,
kernel_min: Union[int, None] = None,
decoder: bool = False,
) -> None:
"""Plot frequency responses and auditory scale visualization of filters.
Creates comprehensive visualization showing individual filter responses,
total power spectral density, and optional auditory scale mapping.
Args:
g (np.ndarray): Filter kernels of shape (num_channels, kernel_size)
fs (int): Sampling frequency in Hz for frequency axis scaling
scale (str): Auditory scale name for scale plotting. Default: 'mel'
plot_scale (bool): Whether to plot the auditory scale mapping (requires fc_min and fc_max). Default: False
fc_min (float, optional): Lower transition frequency for scale visualization. Default: None
fc_max (float, optional): Upper transition frequency for scale visualization. Default: None
kernel_min (int, optional): Minimum kernel size for annotations. Default: None
decoder (bool): Whether filters are for synthesis (affects plot titles). Default: False
Note:
This function displays plots and does not return values.
Creates 2-3 subplots depending on plot_scale parameter.
Example:
>>> filters = np.random.randn(40, 128)
>>> plot_response(filters, fs=16000, scale='mel', plot_scale=True)
"""
num_channels = g.shape[0]
g_hat = response(g, fs)
g_hat_pos = g_hat[:num_channels, :]
g_hat_pos[np.isnan(g_hat_pos)] = 0
psd = np.sum(g_hat, axis=0)
psd[np.isnan(psd)] = 0
if plot_scale:
plt.figure(figsize=(8, 2))
freq_samples, _ = audspace_mod(fc_min, fc_max, fs, num_channels, scale)
freqs = torch.linspace(0, fs // 2, fs // 2)
auds = freqtoaud_mod(freqs, fc_min, fc_max, scale, fs).numpy()
auds_orig = freqtoaud(freqs, scale, fs).numpy()
plt.scatter(
freq_samples.numpy(),
freqtoaud_mod(freq_samples, fc_min, fc_max, scale, fs).numpy(),
color="black",
label="Center frequencies",
linewidths=0.04,
)
plt.plot(freqs, auds, color="black", label=f"ISAC {scale}-scale")
plt.plot(
freqs,
auds_orig,
color="black",
linestyle="--",
alpha=0.5,
label=f"Original {scale}-scale",
)
if fc_min is not None:
plt.axvline(fc_min, color="black", alpha=0.25)
plt.fill_betweenx(
y=[auds[0] - 1, auds[-1] * 1.1],
x1=0,
x2=fc_min,
color="gray",
alpha=0.25,
)
plt.fill_betweenx(
y=[auds[0] - 1, auds[-1] * 1.1],
x1=fc_min,
x2=fs // 2,
color="gray",
alpha=0.1,
)
if fc_max is not None:
plt.axvline(fc_max, color="black", alpha=0.25)
plt.fill_betweenx(
y=[auds[0] - 1, auds[-1] * 1.1],
x1=0,
x2=fc_max,
color="gray",
alpha=0.25,
)
plt.fill_betweenx(
y=[auds[0] - 1, auds[-1] * 1.1],
x1=fc_max,
x2=fs // 2,
color="gray",
alpha=0.1,
)
plt.xlim([0, fs // 2])
plt.ylim([auds[0] - 1, auds[-1] * 1.1])
plt.xlabel("Frequency (Hz)")
# text_x = fc_min / 2
# text_y = auds[-1]
# plt.text(text_x, text_y, 'linear', color='black', ha='center', va='center', fontsize=12, alpha=0.75)
# plt.text(text_x + fc_min - 1, text_y, 'ERB', color='black', ha='center', va='center', fontsize=12, alpha=0.75)
# plt.title(f"ISAC {scale}-scale")
plt.ylabel("Auditory Units")
plt.legend(loc="lower right")
plt.tight_layout()
plt.show()
fig, ax = plt.subplots(2, 1, figsize=(6, 3), sharex=True)
fr_id = 0
psd_id = 1
f_range = np.linspace(0, fs // 2, fs // 2)
ax[fr_id].set_xlim([0, fs // 2])
ax[fr_id].set_ylim([0, np.max(g_hat_pos) * 1.1])
ax[fr_id].plot(f_range, g_hat_pos.T)
if decoder:
ax[fr_id].set_title("PSDs of the synthesis filters")
else:
ax[fr_id].set_title("PSDs of the analysis filters")
# ax[fr_id].set_xlabel('Frequency [Hz]')
ax[fr_id].set_ylabel("Magnitude")
ax[psd_id].plot(f_range, psd)
ax[psd_id].set_xlim([0, fs // 2])
ax[psd_id].set_ylim([0, np.max(psd) * 1.1])
ax[psd_id].set_title("Total PSD")
ax[psd_id].set_xlabel("Frequency [Hz]")
ax[psd_id].set_ylabel("Magnitude")
if fc_min is not None:
ax[fr_id].fill_betweenx(
y=[0, np.max(g_hat) * 1.1], x1=0, x2=fc_min, color="gray", alpha=0.25
)
ax[fr_id].fill_betweenx(
y=[0, np.max(g_hat) * 1.1], x1=fc_min, x2=fs // 2, color="gray", alpha=0.1
)
ax[psd_id].fill_betweenx(
y=[0, np.max(psd) * 1.1], x1=0, x2=fc_min, color="gray", alpha=0.25
)
ax[psd_id].fill_betweenx(
y=[0, np.max(psd) * 1.1], x1=fc_min, x2=fs // 2, color="gray", alpha=0.1
)
if fc_max is not None:
ax[fr_id].fill_betweenx(
y=[0, np.max(g_hat) * 1.1], x1=0, x2=fc_max, color="gray", alpha=0.25
)
ax[fr_id].fill_betweenx(
y=[0, np.max(g_hat) * 1.1], x1=fc_max, x2=fs // 2, color="gray", alpha=0.1
)
ax[psd_id].fill_betweenx(
y=[0, np.max(psd) * 1.1], x1=0, x2=fc_max, color="gray", alpha=0.25
)
ax[psd_id].fill_betweenx(
y=[0, np.max(psd) * 1.1], x1=fc_max, x2=fs // 2, color="gray", alpha=0.1
)
plt.tight_layout()
plt.show()
def ISACgram(
c: torch.Tensor,
fc: Union[torch.Tensor, None] = None,
L: Union[int, None] = None,
fs: Union[int, None] = None,
fmax: Union[float, None] = None,
log_scale: bool = False,
vmin: Union[float, None] = None,
cmap: str = "inferno",
) -> None:
"""Plot time-frequency representation of filterbank coefficients.
Creates a spectrogram-like visualization with frequency on y-axis and time on x-axis.
Supports logarithmic scaling and frequency range limitation for better visualization.
Args:
c (torch.Tensor): Filterbank coefficients of shape (batch_size, num_channels, num_frames)
fc (torch.Tensor, optional): Center frequencies in Hz for y-axis labeling. Default: None
L (int, optional): Original signal length for time axis scaling. Default: None
fs (int, optional): Sampling frequency for time axis scaling. Default: None
fmax (float, optional): Maximum frequency to display in Hz. Default: None
log_scale (bool): Whether to apply log10 scaling to coefficients. Default: False
vmin (float, optional): Minimum value for dynamic range clipping. Default: None
cmap (str): Matplotlib colormap name. Default: 'inferno'
Note:
This function displays a plot and does not return values.
Only processes the first batch element if batch_size > 1.
Example:
>>> coeffs = torch.randn(1, 40, 250)
>>> fc = torch.linspace(100, 8000, 40)
>>> ISACgram(coeffs, fc=fc, L=16000, fs=16000, log_scale=True)
"""
fig, ax = plt.subplots(figsize=(10, 4))
c = c[0].detach().cpu().numpy()
if log_scale:
c = np.log10(np.abs(c) + 1e-10)
if fc is not None and fmax is not None:
c = c[: np.argmax(fc > fmax), :]
if vmin is not None:
mesh = ax.pcolor(c, cmap=cmap, vmin=np.min(c) * vmin)
else:
mesh = ax.pcolor(c, cmap=cmap)
# Add colorbar
fig.colorbar(mesh, ax=ax)
# Y-axis: frequencies
if fc is not None:
locs = np.linspace(0, c.shape[0] - 1, min(len(fc), 10)).astype(int)
ax.set_yticks(locs)
ax.set_yticklabels([int(np.round(fc[i])) for i in locs])
# X-axis: time
num_time_labels = 10
xticks = np.linspace(0, c.shape[1] - 1, num_time_labels)
ax.set_xticks(xticks)
ax.set_xticklabels(
[np.round(x, 1) for x in np.linspace(0, L / fs, num_time_labels)]
)
ax.set_ylabel("Frequency [Hz]")
ax.set_xlabel("Time [s]")
else:
ax.set_ylabel("Frequency index")
ax.set_xlabel("Time samples")
plt.tight_layout()
plt.show()