# Source code for hybra.isac

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from hybra.utils import audfilters, condition_number
from hybra.utils import plot_response as plot_response_
from hybra.utils import ISACgram as ISACgram_
from hybra._fit_dual import fit, tight

class ISAC(nn.Module):
    """Filterbank module whose kernels are produced by ``audfilters``.

    The encoder (``forward``) applies the real and imaginary kernel parts as
    strided 1-D convolutions on a circularly padded signal. When
    ``use_decoder`` is set, a second set of kernels obtained from ``fit`` is
    registered and applied by ``decoder`` through transposed convolutions.
    """

    def __init__(self,
                 kernel_size: Union[int, None] = 128,
                 num_channels: int = 40,
                 fc_max: Union[float, int, None] = None,
                 stride: Union[int, None] = None,
                 fs: int = 16000,
                 L: int = 16000,
                 supp_mult: float = 1,
                 scale: str = 'erb',
                 tighten=False,
                 is_encoder_learnable=False,
                 use_decoder=False,
                 is_decoder_learnable=False,):
        """Build the kernels and register them on the module.

        Parameters:
        -----------
        kernel_size, num_channels, fc_max, fs, L, supp_mult, scale -
            forwarded unchanged to ``audfilters`` (see that function for
            their exact meaning).
        stride (int or None) - if given, overrides the stride returned by
            ``audfilters``; the output length is then ``L`` rounded up to a
            multiple of ``stride``.
        tighten (bool) - if True, the kernels are passed through ``tight``
            before being registered.
        is_encoder_learnable (bool) - register the encoder kernels as
            parameters (True) or as buffers (False).
        use_decoder (bool) - if True, decoder kernels are computed with
            ``fit`` and registered.
        is_decoder_learnable (bool) - register the decoder kernels as
            parameters (True) or as buffers (False).
        """
        super().__init__()

        [kernels, d, fc, fc_min, fc_max, kernel_min, kernel_size, Ls] = audfilters(
            kernel_size=kernel_size, num_channels=num_channels, fc_max=fc_max,
            fs=fs, L=L, supp_mult=supp_mult, scale=scale
        )

        print(f"Max kernel size: {kernel_size}")

        if stride is not None:
            if stride > d:
                print(f"Using stride {stride} instead of the optimal {d} may affect the condition number 🌪️.")
            # A user-supplied stride replaces the one proposed by audfilters,
            # so the output length has to be recomputed for it.
            d = stride
            Ls = self._output_length(L, d)
            print(f"Output length: {Ls}")
        else:
            print(f"Optimal stride: {d}\nOutput length: {Ls}")

        self.kernels = kernels
        self.stride = d
        self.fc = fc
        self.fc_min = fc_min
        self.fc_max = fc_max
        self.kernel_min = kernel_min
        self.kernel_size = kernel_size
        self.Ls = Ls
        self.fs = fs
        self.scale = scale

        # Real and imaginary parts are stored separately as float32 tensors.
        k_real = kernels.real.to(torch.float32)
        k_imag = kernels.imag.to(torch.float32)

        if tighten:
            max_iter = 1000
            fit_eps = 1.01
            k_real, k_imag, _ = tight(k_real + 1j * k_imag, d, Ls, fs, fit_eps, max_iter)

        if is_encoder_learnable:
            self.register_parameter('kernels_real', nn.Parameter(k_real, requires_grad=True))
            self.register_parameter('kernels_imag', nn.Parameter(k_imag, requires_grad=True))
        else:
            self.register_buffer('kernels_real', k_real)
            self.register_buffer('kernels_imag', k_imag)

        self.use_decoder = use_decoder
        if use_decoder:
            max_iter = 1000  # TODO: should we do something like that?
            decoder_fit_eps = 1e-6
            decoder_kernels_real, decoder_kernels_imag, _, _ = fit(
                k_real + 1j * k_imag, d, Ls, fs, decoder_fit_eps, max_iter
            )

            if is_decoder_learnable:
                self.register_parameter('decoder_kernels_real', nn.Parameter(decoder_kernels_real, requires_grad=True))
                self.register_parameter('decoder_kernels_imag', nn.Parameter(decoder_kernels_imag, requires_grad=True))
            else:
                self.register_buffer('decoder_kernels_real', decoder_kernels_real)
                self.register_buffer('decoder_kernels_imag', decoder_kernels_imag)

    @staticmethod
    def _output_length(L, d):
        """Return ``L`` rounded up to the next multiple of the stride ``d``."""
        # Exact ceil-division. Routing L / d through a float32 tensor can
        # round the quotient down for long signals and yield a length < L.
        return int(-(-L // d) * d)

    def forward(self, x):
        """Filterbank analysis.

        Parameters:
        -----------
        x (torch.Tensor) - input signal; a channel axis is inserted at
            dim 1, so a (batch_size, signal_length) layout is assumed
            (TODO confirm against callers).

        Returns:
        --------
        torch.Tensor - complex coefficients ``out_real + 1j * out_imag``
        """
        half = self.kernel_size // 2
        x = F.pad(x.unsqueeze(1), (half, half), mode='circular')

        out_real = F.conv1d(x, self.kernels_real.to(x.device).unsqueeze(1), stride=self.stride)
        out_imag = F.conv1d(x, self.kernels_imag.to(x.device).unsqueeze(1), stride=self.stride)

        return out_real + 1j * out_imag

    def decoder(self, x_real: torch.Tensor, x_imag: torch.Tensor) -> torch.Tensor:
        """Filterbank synthesis.

        Parameters:
        -----------
        x_real (torch.Tensor) - real part of the coefficients, of shape
            (batch_size, num_channels, signal_length//hop_length)
        x_imag (torch.Tensor) - imaginary part of the coefficients, same shape

        Returns:
        --------
        x (torch.Tensor) - output tensor of shape (batch_size, signal_length)

        Raises:
        -------
        NotImplementedError - if the module was built with ``use_decoder=False``
        """
        if not self.use_decoder:
            # Same error as plot_decoder_response; the decoder kernels were
            # never registered, so there is nothing to synthesize with.
            raise NotImplementedError("No decoder configured")

        L_in = x_real.shape[-1]
        L_out = self.Ls

        kernel_size = self.kernel_size
        padding = kernel_size // 2

        # conv_transpose1d (dilation = 1) produces
        #   L_out = (L_in - 1) * stride - 2 * padding + (kernel_size - 1) + output_padding + 1
        # Solve for output_padding so that the output has exactly L_out samples.
        output_padding = L_out - (L_in - 1) * self.stride + 2 * padding - kernel_size

        x = (
            F.conv_transpose1d(
                x_real,
                self.decoder_kernels_real.to(x_real.device).unsqueeze(1),
                stride=self.stride,
                padding=padding,
                output_padding=output_padding
            )
            + F.conv_transpose1d(
                x_imag,
                self.decoder_kernels_imag.to(x_imag.device).unsqueeze(1),
                stride=self.stride,
                padding=padding,
                output_padding=output_padding
            )
        )

        return x.squeeze(1)

    def plot_response(self):
        """Pass the complex encoder kernels to ``plot_response_`` for plotting."""
        plot_response_(
            g=(self.kernels_real + 1j * self.kernels_imag).detach().cpu().numpy(),
            fs=self.fs,
            scale=self.scale,
            plot_scale=True,
            fc_min=self.fc_min,
            fc_max=self.fc_max,
            kernel_min=self.kernel_min,
        )

    def plot_decoder_response(self):
        """Pass the complex decoder kernels to ``plot_response_`` for plotting.

        Raises:
        -------
        NotImplementedError - if the module was built with ``use_decoder=False``
        """
        if not self.use_decoder:
            raise NotImplementedError("No decoder configured")
        plot_response_(
            g=(self.decoder_kernels_real + 1j * self.decoder_kernels_imag).detach().cpu().numpy(),
            fs=self.fs,
            scale=self.scale,
            decoder=True,
        )

    def ISACgram(self, x):
        """Compute the coefficients of ``x`` without gradients and hand them to ``ISACgram_``."""
        with torch.no_grad():
            coefficients = self.forward(x)
        ISACgram_(coefficients, self.fc, self.Ls, self.fs)

    @property
    def condition_number(self):
        """Result of ``condition_number`` for the encoder kernels, stride and ``Ls``."""
        kernels = (self.kernels_real + 1j * self.kernels_imag).squeeze()
        return condition_number(kernels, int(self.stride), self.Ls)

    @property
    def condition_number_decoder(self):
        """Result of ``condition_number`` for the decoder kernels, stride and ``Ls``."""
        kernels = (self.decoder_kernels_real + 1j * self.decoder_kernels_imag).squeeze()
        return condition_number(kernels, int(self.stride), self.Ls)