import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
def flip(x, dim):
    # Reverse a tensor along `dim` (predates torch.flip; kept for the `sinc`
    # helper below).
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    x = x.contiguous()
    x = x.view(-1, *xsize[dim:])
    x = x.view(x.size(0), x.size(1), -1)[
        :, torch.arange(x.size(1) - 1, -1, -1, device=x.device).long(), :]
    return x.view(xsize)
def sinc(band, t_right):
    # Right half of a sinc; the left half follows by symmetry and the center
    # sample is 1. Built on the input's device instead of hard-coded .cuda().
    y_right = torch.sin(2 * math.pi * band * t_right) / (2 * math.pi * band * t_right)
    y_left = flip(y_right, 0)
    y = torch.cat([y_left, torch.ones(1, device=t_right.device), y_right])
    return y
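# Sanity-check sketch for sinc(): the output has odd length 2*len(t_right) + 1
# and is symmetric about its center (illustrative values):
# >>> t_right = torch.linspace(1, 50, steps=50) / 16000
# >>> y = sinc(torch.tensor(100.0), t_right)
# >>> y.shape                          # torch.Size([101])
# >>> torch.allclose(y, y.flip(0))     # True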
class SincConv_fast(nn.Module):
    """Sinc-based convolution

    Parameters
    ----------
    in_channels : `int`
        Number of input channels. Must be 1.
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.

    Usage
    -----
    See `torch.nn.Conv1d`

    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    """

    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)
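    # Round-trip sketch for the mel conversions (illustrative values):
    # >>> SincConv_fast.to_mel(1000.0)                        # ~1000 mel
    # >>> SincConv_fast.to_hz(SincConv_fast.to_mel(4000.0))   # ~4000.0 Hz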
    def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1,
                 min_low_hz=50, min_band_hz=50):
        super(SincConv_fast, self).__init__()
        if in_channels != 1:
            msg = (f'SincConv only supports one input channel '
                   f'(here, in_channels = {in_channels:d}).')
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # Force the filters to have odd length (i.e., perfectly symmetric).
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz

        # Initialize the filterbanks so that they are equally spaced on the mel scale.
        low_hz = 30
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)

        # filter lower frequency (out_channels, 1)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1))
        # filter frequency band (out_channels, 1)
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1))

        # Hamming window; only half of it is computed, thanks to symmetry.
        n_lin = torch.linspace(0, (self.kernel_size / 2) - 1,
                               steps=int(self.kernel_size / 2))
        self.window_ = 0.54 - 0.46 * torch.cos(2 * math.pi * n_lin / self.kernel_size)

        # (1, kernel_size/2); only half of the time axis is needed, by symmetry.
        n = (self.kernel_size - 1) / 2.0
        self.n_ = 2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
    def forward(self, waveforms):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.

        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filter activations.
        """
        self.n_ = self.n_.to(waveforms.device)
        self.window_ = self.window_.to(waveforms.device)

        low = self.min_low_hz + torch.abs(self.low_hz_)
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),
                           self.min_low_hz, self.sample_rate / 2)
        band = (high - low)[:, 0]

        f_times_t_low = torch.matmul(low, self.n_)
        f_times_t_high = torch.matmul(high, self.n_)

        # Equivalent of Eq. 4 in the SincNet paper: the sinc terms are expanded
        # and simplified, which avoids several redundant computations.
        band_pass_left = ((torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
                          / (self.n_ / 2)) * self.window_
        band_pass_center = 2 * band.view(-1, 1)
        band_pass_right = torch.flip(band_pass_left, dims=[1])

        band_pass = torch.cat([band_pass_left, band_pass_center, band_pass_right], dim=1)
        band_pass = band_pass / (2 * band[:, None])

        self.filters = band_pass.view(self.out_channels, 1, self.kernel_size)

        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                        bias=None, groups=1)
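# Minimal usage sketch for SincConv_fast (illustrative shapes):
# >>> conv = SincConv_fast(out_channels=80, kernel_size=251, sample_rate=16000)
# >>> wav = torch.randn(4, 1, 16000)          # (batch, 1, n_samples)
# >>> conv(wav).shape                         # torch.Size([4, 80, 15750])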
class sinc_conv(nn.Module):
    # Original loop-based SincNet layer; SincConv_fast above is the vectorized version.

    def __init__(self, N_filt, Filt_dim, fs):
        super(sinc_conv, self).__init__()
        # Mel initialization of the filterbanks
        low_freq_mel = 80
        high_freq_mel = 2595 * np.log10(1 + (fs / 2) / 700)  # Convert Hz to mel
        mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt)  # Equally spaced on the mel scale
        f_cos = 700 * (10 ** (mel_points / 2595) - 1)  # Convert mel to Hz
        b1 = np.roll(f_cos, 1)
        b2 = np.roll(f_cos, -1)
        b1[0] = 30
        b2[-1] = (fs / 2) - 100
        self.freq_scale = fs * 1.0
        self.filt_b1 = nn.Parameter(torch.from_numpy(b1 / self.freq_scale))
        self.filt_band = nn.Parameter(torch.from_numpy((b2 - b1) / self.freq_scale))
        self.N_filt = N_filt
        self.Filt_dim = Filt_dim
        self.fs = fs

    def forward(self, x):
        # Build the filters on the same device as the input (the original code
        # hard-coded .cuda() here).
        filters = torch.zeros((self.N_filt, self.Filt_dim), device=x.device)
        N = self.Filt_dim
        t_right = (torch.linspace(1, (N - 1) / 2, steps=int((N - 1) / 2)) / self.fs).to(x.device)
        min_freq = 50.0
        min_band = 50.0
        filt_beg_freq = torch.abs(self.filt_b1) + min_freq / self.freq_scale
        filt_end_freq = filt_beg_freq + (torch.abs(self.filt_band) + min_band / self.freq_scale)
        n = torch.linspace(0, N, steps=N)
        # Filter window (Hamming)
        window = (0.54 - 0.46 * torch.cos(2 * math.pi * n / N)).float().to(x.device)
        for i in range(self.N_filt):
            low_pass1 = 2 * filt_beg_freq[i].float() * sinc(filt_beg_freq[i].float() * self.freq_scale, t_right)
            low_pass2 = 2 * filt_end_freq[i].float() * sinc(filt_end_freq[i].float() * self.freq_scale, t_right)
            band_pass = low_pass2 - low_pass1
            band_pass = band_pass / torch.max(band_pass)
            filters[i, :] = band_pass * window
        out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim))
        return out
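# Usage sketch for the loop-based layer (illustrative shapes only; only
# SincConv_fast is actually used further down):
# >>> old = sinc_conv(N_filt=80, Filt_dim=251, fs=16000)
# >>> old(torch.randn(4, 1, 16000)).shape     # torch.Size([4, 80, 15750])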
def act_fun(act_type):
    # Maps an activation name to its nn.Module; raises on unknown names
    # instead of silently returning None.
    if act_type == "relu":
        return nn.ReLU()
    if act_type == "tanh":
        return nn.Tanh()
    if act_type == "sigmoid":
        return nn.Sigmoid()
    if act_type == "leaky_relu":
        return nn.LeakyReLU(0.2)
    if act_type == "elu":
        return nn.ELU()
    if act_type == "softmax":
        return nn.LogSoftmax(dim=1)
    if act_type == "linear":
        # Initialized like this, but not used in forward!
        return nn.LeakyReLU(1)
    raise ValueError("Unknown activation: %s" % act_type)
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta
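# LayerNorm sketch (illustrative): normalizes over the last dimension, so each
# sample has mean ~0 and std ~1 before the affine (gamma, beta) transform.
# >>> ln = LayerNorm(8000)
# >>> y = ln(torch.randn(4, 8000))
# >>> y.mean(-1)    # ~0 for each of the 4 rows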
class MLP(nn.Module):
    def __init__(self, options):
        super(MLP, self).__init__()
        self.input_dim = int(options['input_dim'])
        self.fc_lay = options['fc_lay']
        self.fc_drop = options['fc_drop']
        self.fc_use_batchnorm = options['fc_use_batchnorm']
        self.fc_use_laynorm = options['fc_use_laynorm']
        self.fc_use_laynorm_inp = options['fc_use_laynorm_inp']
        self.fc_use_batchnorm_inp = options['fc_use_batchnorm_inp']
        self.fc_act = options['fc_act']

        self.wx = nn.ModuleList([])
        self.bn = nn.ModuleList([])
        self.ln = nn.ModuleList([])
        self.act = nn.ModuleList([])
        self.drop = nn.ModuleList([])

        # input layer normalization
        if self.fc_use_laynorm_inp:
            self.ln0 = LayerNorm(self.input_dim)
        # input batch normalization (num_features must be an int, not a list)
        if self.fc_use_batchnorm_inp:
            self.bn0 = nn.BatchNorm1d(self.input_dim, momentum=0.05)

        self.N_fc_lay = len(self.fc_lay)
        current_input = self.input_dim

        # Initialization of hidden layers
        for i in range(self.N_fc_lay):
            # dropout
            self.drop.append(nn.Dropout(p=self.fc_drop[i]))
            # activation
            self.act.append(act_fun(self.fc_act[i]))

            add_bias = True
            # layer norm initialization
            self.ln.append(LayerNorm(self.fc_lay[i]))
            self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05))
            if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]:
                add_bias = False

            # Linear operations
            self.wx.append(nn.Linear(current_input, self.fc_lay[i], bias=add_bias))

            # weight initialization
            bound = np.sqrt(0.01 / (current_input + self.fc_lay[i]))
            self.wx[i].weight = nn.Parameter(
                torch.Tensor(self.fc_lay[i], current_input).uniform_(-bound, bound))
            if add_bias:
                # The original code re-assigned a bias unconditionally, which
                # silently re-added one to layers created with bias=False.
                self.wx[i].bias = nn.Parameter(torch.zeros(self.fc_lay[i]))

            current_input = self.fc_lay[i]
    def forward(self, x):
        # Apply input layer/batch norm
        if bool(self.fc_use_laynorm_inp):
            x = self.ln0(x)
        if bool(self.fc_use_batchnorm_inp):
            x = self.bn0(x)

        for i in range(self.N_fc_lay):
            if self.fc_act[i] != 'linear':
                if self.fc_use_laynorm[i]:
                    x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x))))
                elif self.fc_use_batchnorm[i]:
                    x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x))))
                else:
                    x = self.drop[i](self.act[i](self.wx[i](x)))
            else:
                if self.fc_use_laynorm[i]:
                    x = self.drop[i](self.ln[i](self.wx[i](x)))
                elif self.fc_use_batchnorm[i]:
                    x = self.drop[i](self.bn[i](self.wx[i](x)))
                else:
                    x = self.drop[i](self.wx[i](x))
        return x
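# MLP usage sketch (hypothetical options; keys mirror the SincNet repo config):
# >>> mlp_opts = {'input_dim': 120, 'fc_lay': [256, 256], 'fc_drop': [0.0, 0.0],
# ...             'fc_use_batchnorm': [True, True], 'fc_use_laynorm': [False, False],
# ...             'fc_use_laynorm_inp': False, 'fc_use_batchnorm_inp': False,
# ...             'fc_act': ['relu', 'relu']}
# >>> mlp = MLP(mlp_opts)
# >>> mlp(torch.randn(4, 120)).shape          # torch.Size([4, 256])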
class SincNet(nn.Module):
    def __init__(self, options):
        super(SincNet, self).__init__()
        self.cnn_N_filt = options['cnn_N_filt']
        self.cnn_len_filt = options['cnn_len_filt']
        self.cnn_max_pool_len = options['cnn_max_pool_len']
        self.cnn_act = options['cnn_act']
        self.cnn_drop = options['cnn_drop']
        self.cnn_use_laynorm = options['cnn_use_laynorm']
        self.cnn_use_batchnorm = options['cnn_use_batchnorm']
        self.cnn_use_laynorm_inp = options['cnn_use_laynorm_inp']
        self.cnn_use_batchnorm_inp = options['cnn_use_batchnorm_inp']
        self.input_dim = int(options['input_dim'])
        self.fs = options['fs']
        self.N_cnn_lay = len(options['cnn_N_filt'])

        self.conv = nn.ModuleList([])
        self.bn = nn.ModuleList([])
        self.ln = nn.ModuleList([])
        self.act = nn.ModuleList([])
        self.drop = nn.ModuleList([])

        if self.cnn_use_laynorm_inp:
            self.ln0 = LayerNorm(self.input_dim)
        if self.cnn_use_batchnorm_inp:
            self.bn0 = nn.BatchNorm1d(self.input_dim, momentum=0.05)

        current_input = self.input_dim

        for i in range(self.N_cnn_lay):
            N_filt = int(self.cnn_N_filt[i])
            len_filt = int(self.cnn_len_filt[i])

            # dropout
            self.drop.append(nn.Dropout(p=self.cnn_drop[i]))
            # activation
            self.act.append(act_fun(self.cnn_act[i]))
            # layer norm initialization
            self.ln.append(LayerNorm([N_filt,
                                      int((current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i])]))
            # The original code passed the sequence length as the second
            # positional argument, which BatchNorm1d interprets as `eps`.
            self.bn.append(nn.BatchNorm1d(N_filt, momentum=0.05))

            if i == 0:
                self.conv.append(SincConv_fast(self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs))
            else:
                self.conv.append(nn.Conv1d(self.cnn_N_filt[i - 1], self.cnn_N_filt[i], self.cnn_len_filt[i]))

            current_input = int((current_input - self.cnn_len_filt[i] + 1) / self.cnn_max_pool_len[i])

        self.out_dim = current_input * N_filt
    def forward(self, x):
        batch = x.shape[0]
        seq_len = x.shape[1]

        if bool(self.cnn_use_laynorm_inp):
            x = self.ln0(x)
        if bool(self.cnn_use_batchnorm_inp):
            x = self.bn0(x)

        x = x.view(batch, 1, seq_len)

        for i in range(self.N_cnn_lay):
            if self.cnn_use_laynorm[i]:
                if i == 0:
                    # The sinc layer output is rectified with abs() before pooling.
                    x = self.drop[i](self.act[i](self.ln[i](
                        F.max_pool1d(torch.abs(self.conv[i](x)), self.cnn_max_pool_len[i]))))
                else:
                    x = self.drop[i](self.act[i](self.ln[i](
                        F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))))
            elif self.cnn_use_batchnorm[i]:
                x = self.drop[i](self.act[i](self.bn[i](
                    F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i]))))
            else:
                x = self.drop[i](self.act[i](
                    F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))

        x = x.view(batch, -1)
        return x
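# SincNet usage sketch (illustrative; CNN_arch is the options dict defined
# further down in this file):
# >>> net = SincNet(CNN_arch)
# >>> emb = net(torch.randn(4, CNN_arch['input_dim']))    # (batch, wlen) in
# >>> emb.shape            # torch.Size([4, net.out_dim]), i.e. (4, 120) here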
class MutualInformation(nn.Module):
    def __init__(self, sigma=0.1, num_bins=256, normalize=True):
        super(MutualInformation, self).__init__()
        self.sigma = sigma
        self.num_bins = num_bins
        self.normalize = normalize
        self.epsilon = 1e-10
        # A buffer (rather than a Parameter tied to a global `device`) follows
        # the module to whatever device it is moved to.
        self.register_buffer('bins', torch.linspace(0, 255, num_bins).float())

    def marginalPdf(self, values):
        residuals = values - self.bins.unsqueeze(0).unsqueeze(0)
        kernel_values = torch.exp(-0.5 * (residuals / self.sigma).pow(2))
        pdf = torch.mean(kernel_values, dim=1)
        normalization = torch.sum(pdf, dim=1).unsqueeze(1) + self.epsilon
        pdf = pdf / normalization
        return pdf, kernel_values

    def jointPdf(self, kernel_values1, kernel_values2):
        joint_kernel_values = torch.matmul(kernel_values1.transpose(1, 2), kernel_values2)
        normalization = torch.sum(joint_kernel_values, dim=(1, 2)).view(-1, 1, 1) + self.epsilon
        pdf = joint_kernel_values / normalization
        return pdf
    def getMutualInformation(self, input1, input2):
        '''
        input1: B, C, H, W
        input2: B, C, H, W
        return: (B,) tensor of (normalized) mutual information values
        '''
        # Inputs are expected as torch tensors scaled to (0, 1).
        input1 = input1 * 255
        input2 = input2 * 255

        B, C, H, W = input1.shape
        assert input1.shape == input2.shape

        # Permute before reshaping so pixels stay paired with their channel;
        # a plain .view(B, H*W, C) scrambles dimensions when C > 1. (The
        # Gaussian soft-binning in marginalPdf broadcasts over a single
        # channel, so C is expected to be 1 anyway.)
        x1 = input1.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x2 = input2.permute(0, 2, 3, 1).reshape(B, H * W, C)

        pdf_x1, kernel_values1 = self.marginalPdf(x1)
        pdf_x2, kernel_values2 = self.marginalPdf(x2)
        pdf_x1x2 = self.jointPdf(kernel_values1, kernel_values2)

        H_x1 = -torch.sum(pdf_x1 * torch.log2(pdf_x1 + self.epsilon), dim=1)
        H_x2 = -torch.sum(pdf_x2 * torch.log2(pdf_x2 + self.epsilon), dim=1)
        H_x1x2 = -torch.sum(pdf_x1x2 * torch.log2(pdf_x1x2 + self.epsilon), dim=(1, 2))

        mutual_information = H_x1 + H_x2 - H_x1x2
        if self.normalize:
            mutual_information = 2 * mutual_information / (H_x1 + H_x2)
        return mutual_information
    def forward(self, input1, input2):
        '''
        input1: B, C, H, W
        input2: B, C, H, W
        return: (B,) tensor of (normalized) mutual information values
        '''
        return self.getMutualInformation(input1, input2)
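# MutualInformation sketch (illustrative; single-channel images in (0, 1)):
# >>> mi = MutualInformation(sigma=0.1, num_bins=256, normalize=True)
# >>> img = torch.rand(2, 1, 32, 32)
# >>> mi(img, img)                           # ~1.0 per image (identical inputs)
# >>> mi(img, torch.rand(2, 1, 32, 32))      # noticeably lower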
class NetworkModel(nn.Module):
    def __init__(self, batch_size, chunk_size, N_chunks, CNN_arch):
        super(NetworkModel, self).__init__()
        # Three SincNet encoders: two for chunks of the same utterance,
        # one for chunks of different utterances.
        self.sincnet1 = SincNet(CNN_arch)
        self.sincnet2 = SincNet(CNN_arch)
        self.sincnet3 = SincNet(CNN_arch)
        self.chunk_size = chunk_size
        self.N_chunks = N_chunks
        self.batch_size = batch_size

    def forward(self, C1, C2, Crnd):
        # C1, C2, Crnd have shape (batch_size, n_samples); C1 and C2 come from
        # the same utterance, while Crnd comes from different utterances.
        z1 = torch.unsqueeze(self.sincnet1(C1), 0)
        z2 = torch.unsqueeze(self.sincnet2(C2), 0)
        zrnd = torch.unsqueeze(self.sincnet3(Crnd), 0)

        # Positive pairs: sample N_chunks random chunks from (z1, z2).
        Xp = torch.cat((z1, z2), 0).view(2, self.batch_size, -1, self.chunk_size)
        Xprand = torch.randint(Xp.size(dim=2), (self.N_chunks,))
        Xp = Xp[:, :, Xprand].view(2, self.batch_size, -1)

        # Negative pairs: sample N_chunks random chunks from (z1, zrnd).
        Xn = torch.cat((z1, zrnd), 0).view(2, self.batch_size, -1, self.chunk_size)
        Xnrand = torch.randint(Xn.size(dim=2), (self.N_chunks,))
        Xn = Xn[:, :, Xnrand].view(2, self.batch_size, -1)

        # Squash to (0, 1) so the BCE-style scores below are well defined.
        Xp = torch.sigmoid(Xp)
        Xn = torch.sigmoid(Xn)
        return Xp, Xn
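# Shape walkthrough for NetworkModel.forward (with the settings used below:
# out_dim = 120, chunk_size = 8, N_chunks = 8, batch_size = 128):
#   z1, z2, zrnd        : (1, 128, 120)
#   Xp before sampling  : (2, 128, 15, 8)    # 120 = 15 chunks of size 8
#   Xp after sampling   : (2, 128, 64)       # 8 chunks * chunk_size 8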
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
wlen = 8000  # waveform length in samples (1 s at fs = 8000)
fs = 8000
cnn_N_filt = 80, 60, 60
cnn_len_filt = 251, 5, 5
cnn_max_pool_len = 100, 5, 5
cnn_use_laynorm_inp = True
cnn_use_batchnorm_inp = False
cnn_use_laynorm = True, True, True
cnn_use_batchnorm = False, False, False
cnn_act = ['relu', 'relu', 'relu']
cnn_drop = 0.0, 0.0, 0.0
# Feature extractor CNN
CNN_arch = {'input_dim': wlen,
            'fs': fs,
            'cnn_N_filt': cnn_N_filt,
            'cnn_len_filt': cnn_len_filt,
            'cnn_max_pool_len': cnn_max_pool_len,
            'cnn_use_laynorm_inp': cnn_use_laynorm_inp,
            'cnn_use_batchnorm_inp': cnn_use_batchnorm_inp,
            'cnn_use_laynorm': cnn_use_laynorm,
            'cnn_use_batchnorm': cnn_use_batchnorm,
            'cnn_act': cnn_act,
            'cnn_drop': cnn_drop,
            }
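# With this configuration, SincNet's output length per layer is
# (valid conv length) // (pool length):
#   layer 0: (8000 - 251 + 1) // 100 = 77
#   layer 1: (77 - 5 + 1) // 5       = 14
#   layer 2: (14 - 5 + 1) // 5       = 2
# so out_dim = 2 * 60 = 120 features per waveform.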
batch_size = 128
chunk_size = 8
N_chunks = 8
n_epochs = 10
number_of_batches = 2

# Load the dataset in this shape (random stand-ins here):
C1 = torch.randn(number_of_batches, batch_size, wlen)    # chunks from the same audio
C2 = torch.randn(number_of_batches, batch_size, wlen)    # chunks from the same audio
Crnd = torch.randn(number_of_batches, batch_size, wlen)  # chunks from different audio

model = NetworkModel(batch_size, chunk_size, N_chunks, CNN_arch).to(device)
# maximize=True makes Adam ascend the objective computed below.
optimizer = optim.Adam(model.parameters(), lr=0.001, maximize=True)
Xploss = nn.BCELoss(reduction='none')
Xnloss = nn.BCELoss(reduction='none')
for epoch in range(n_epochs):  # n_epochs full passes over the data
    for data in range(number_of_batches):  # `data` indexes one batch
        model.zero_grad()  # reset gradients before each backward pass
        Xp, Xn = model(C1[data].to(device), C2[data].to(device), Crnd[data].to(device))
        g_Xp = Xploss(Xp[0], Xp[1])
        g_Xn = Xnloss(Xn[0], Xn[1])
        # Objective (maximized): mean log-score over positive pairs plus mean
        # log(1 - score) over negative pairs. Note that g_Xn can exceed 1, in
        # which case log(1 - g_Xn) is undefined; clamping may be needed in practice.
        loss = torch.mean(torch.log(g_Xp)) + torch.mean(torch.log(1 - g_Xn))
        loss.backward()  # retain_graph is unnecessary: the graph is rebuilt each step
        optimizer.step()
        print(epoch, loss.item())