Source code for hybra.isac_mel

from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from hybra.utils import audfilters
from hybra.utils import plot_response as plot_response_
from hybra.utils import ISACgram as ISACgram_

class ISACMelSpectrogram(nn.Module):
    def __init__(self, kernel_size:Union[int,None]=None, num_channels:int=40, fc_max:Union[float,int,None]=None, stride:Union[int,None]=None, fs:int=16000, L:int=16000, bw_multiplier:float=1, scale:str='erb', is_encoder_learnable=False, is_averaging_kernel_learnable=False, is_log=False):
        super().__init__()

        [kernels, d, fc, fc_min, fc_max, kernel_min, kernel_size, Ls] = audfilters(
            kernel_size=kernel_size, num_channels=num_channels, fc_max=fc_max, fs=fs, L=L, bw_multiplier=bw_multiplier, scale=scale
        )

        if stride is not None:
            d = stride
            Ls = int(torch.ceil(torch.tensor(L / d)) * d)
            print(f"The output length is set to {Ls}.")

        self.kernels = kernels
        self.stride = d
        self.kernel_size = kernel_size
        self.kernel_min = kernel_min
        self.fs = fs
        self.fc = fc
        self.fc_min = fc_min
        self.fc_max = fc_max
        self.num_channels = num_channels
        self.Ls = Ls

        self.time_avg = self.kernel_size // self.stride
        self.time_avg_stride = self.time_avg // 2

        kernels_real = kernels.real.to(torch.float32)
        kernels_imag = kernels.imag.to(torch.float32)

        self.is_log = is_log

        if is_encoder_learnable:
            self.register_parameter('kernels_real', nn.Parameter(kernels_real, requires_grad=True))
            self.register_parameter('kernels_imag', nn.Parameter(kernels_imag, requires_grad=True))
        else:
            self.register_buffer('kernels_real', kernels_real)
            self.register_buffer('kernels_imag', kernels_imag)

        if is_averaging_kernel_learnable:
            self.register_parameter('averaging_kernel', nn.Parameter(torch.ones([self.num_channels, 1, self.time_avg]), requires_grad=True))
        else:
            self.register_buffer('averaging_kernel', torch.ones([self.num_channels, 1, self.time_avg]))
    def forward(self, x:torch.Tensor) -> torch.Tensor:
        # Squared magnitude of the analytic filterbank response: circularly pad the
        # input, convolve with the real and imaginary kernels separately, sum the squares.
        x = F.conv1d(
            F.pad(x, (self.kernel_size//2, self.kernel_size//2), mode='circular'),
            self.kernels_real.unsqueeze(1),
            stride=self.stride,
        )**2 + F.conv1d(
            F.pad(x, (self.kernel_size//2, self.kernel_size//2), mode='circular'),
            self.kernels_imag.unsqueeze(1),
            stride=self.stride,
        )**2

        # Temporal averaging via a channel-wise (grouped) convolution.
        output = F.conv1d(
            x,
            self.averaging_kernel.to(x.device),
            groups=self.num_channels,
            stride=self.time_avg_stride
        )

        if self.is_log:
            output = torch.log10(output)

        return output
    def ISACgram(self, x):
        with torch.no_grad():
            coefficients = self.forward(x)
            ISACgram_(coefficients, self.fc, self.Ls, self.fs)
    def plot_response(self):
        plot_response_(g=(self.kernels_real + 1j*self.kernels_imag).detach().numpy(), fs=self.fs, scale=True, fc_min=self.fc_min, fc_max=self.fc_max, kernel_min=self.kernel_min)
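
A minimal usage sketch, not part of the module source: it assumes the class is imported directly from hybra.isac_mel, the default ERB-scale configuration, and a (batch, 1, time) input layout; the number of output frames depends on the stride and averaging stride computed in __init__.

import torch
from hybra.isac_mel import ISACMelSpectrogram

# 1 second of audio at 16 kHz, shaped (batch, channels=1, time) -- assumed layout
x = torch.randn(1, 1, 16000)

# 40-channel filterbank on the default ERB scale
isac_mel = ISACMelSpectrogram(num_channels=40, fs=16000, L=16000, is_log=False)

coeffs = isac_mel(x)      # (batch, num_channels, num_frames)
print(coeffs.shape)

isac_mel.plot_response()  # plot the filter frequency responses
isac_mel.ISACgram(x)      # plot the averaged spectrogram coefficients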