Neural Network
# 1-D conv block
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
class OneDConvBlock(nn.Module):
    def __init__(self, B, H, L, P, Sc, dil):
        '''B   - number of input channels
           L   - length of the input signal (in samples)
           H   - number of channels inside the convolutional block
           P   - kernel size of the convolutions in the block
           Sc  - number of channels in the skip connection, which gets summed up and used as the TCN output
           dil - dilation factor 1, 2, 4, 8, ...'''
        super().__init__()
        pad = math.floor(dil * (P - 1) / 2)
        # with pad = dil*(P-1)/2 and an odd P, the length L does not change
        self.pointConv1 = nn.Conv1d(B, H, P, stride=1, padding=pad, dilation=dil, padding_mode='zeros')
        # with the unbatched 2-D [channels, L] inputs used here, BatchNorm1d(L) treats dim 0 as the batch
        # and the L positions as features
        self.normalization = nn.BatchNorm1d(L)
        self.Dconv = nn.Conv1d(H, H, P, stride=1, padding=pad, dilation=dil, padding_mode='zeros')
        self.pointConv2 = nn.Conv1d(H, B, P, stride=1, padding=pad, dilation=dil, padding_mode='zeros')
        self.pointConv3 = nn.Conv1d(H, Sc, P, stride=1, padding=pad, dilation=dil, padding_mode='zeros')

    def forwardnet(self, x):
        inp = x
        x = F.rrelu(self.pointConv1(x))
        x = self.normalization(x)       # the same normalization layer is reused after both convolutions
        x = F.rrelu(self.Dconv(x))
        x = self.normalization(x)
        y = self.pointConv3(x)          # skip connection, shape [Sc, L]
        x = self.pointConv2(x) + inp    # residual output fed to the next 1-D conv block, shape [B, L]
        return x, y
x = torch.randn(10, 882)
net = OneDConvBlock(B=10, H=40, L=882, P=3, Sc=20, dil=2)
out, skip = net.forwardnet(x)
print(net)
print(x.shape)
print(out.shape)
print(skip.shape)
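To see why pad = dil*(P-1)/2 keeps the length unchanged for an odd kernel size P, here is a small standalone check with plain nn.Conv1d; the channel count and length below are arbitrary illustration values, not tied to the model.

```python
C, L, P = 10, 882, 3                      # illustration values only
x = torch.randn(1, C, L)                  # (batch, channels, length)
for dil in (1, 2, 4, 8, 16):
    pad = math.floor(dil * (P - 1) / 2)   # same formula as in OneDConvBlock
    conv = nn.Conv1d(C, C, P, stride=1, padding=pad, dilation=dil)
    print(dil, conv(x).shape)             # length stays 882 for every dilation
```

The length is preserved exactly only when dil*(P-1) is even, which is guaranteed when P is odd.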
# A single column of the separator
class separator_block(nn.Module):
    # 8 1-D conv blocks in one column; H and kernel are length-8 lists
    # (all kernel sizes should be odd so the length is preserved)
    def __init__(self, reduced_time_size, H, kernel):
        # each block in the column can use its own hidden width H[i] and kernel size kernel[i];
        # note: batch_size is a global, set outside this class
        super().__init__()
        self.TCN0 = OneDConvBlock(batch_size, H[0], reduced_time_size, kernel[0], batch_size, 1)
        self.TCN1 = OneDConvBlock(batch_size, H[1], reduced_time_size, kernel[1], batch_size, 2)
        self.TCN2 = OneDConvBlock(batch_size, H[2], reduced_time_size, kernel[2], batch_size, 4)
        self.TCN3 = OneDConvBlock(batch_size, H[3], reduced_time_size, kernel[3], batch_size, 8)
        self.TCN4 = OneDConvBlock(batch_size, H[4], reduced_time_size, kernel[4], batch_size, 16)
        self.TCN5 = OneDConvBlock(batch_size, H[5], reduced_time_size, kernel[5], batch_size, 32)
        self.TCN6 = OneDConvBlock(batch_size, H[6], reduced_time_size, kernel[6], batch_size, 64)
        self.TCN7 = OneDConvBlock(batch_size, H[7], reduced_time_size, kernel[7], batch_size, 128)
    def forwardsep(self, x):
        skip_sum = 0
        out = x
        for block in (self.TCN0, self.TCN1, self.TCN2, self.TCN3,
                      self.TCN4, self.TCN5, self.TCN6, self.TCN7):
            out, skip = block.forwardnet(out)   # residual path goes to the next block
            skip_sum = skip_sum + skip          # accumulate the skip connections of the column
        return out, skip_sum
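A quick shape check for one column. The block reads the global batch_size, which is assumed to be defined earlier in the notebook; the reduced length used below is a hypothetical value just for this check.

```python
demo_len = 882   # hypothetical reduced time length, for this check only
col = separator_block(demo_len, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
x = torch.randn(batch_size, demo_len)
out, skips = col.forwardsep(x)
print(out.shape)    # [batch_size, demo_len] - residual path, fed to the next column
print(skips.shape)  # [batch_size, demo_len] - summed skip connections of the column
```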
# creating the neural network
class convTasnet(nn.Module):
    def __init__(self):  # 8 conv1D blocks per column, 6 columns; both are hard-coded below
        super().__init__()
        reduced_time_size = math.floor(time_window_size / 8)
        # encoder maps the time axis from time_window_size down to reduced_time_size
        self.encoder = OneDConvBlock(time_window_size, reduced_time_size, batch_size, kernal_size, reduced_time_size, 1)
        self.layerNorm = nn.LayerNorm(reduced_time_size)
        self.one_X_oneconv = nn.Conv1d(batch_size, batch_size, 7, stride=1, padding=3, dilation=1, padding_mode='zeros')
        # Separator: six columns of 8 dilated 1-D conv blocks each
        self.column1 = separator_block(reduced_time_size, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
        self.column2 = separator_block(reduced_time_size, [7, 5, 10, 20, 18, 24, 12, 5], [3, 5, 3, 5, 7, 3, 9, 13])
        self.column3 = separator_block(reduced_time_size, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
        self.column4 = separator_block(reduced_time_size, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
        self.column5 = separator_block(reduced_time_size, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
        self.column6 = separator_block(reduced_time_size, [10, 5, 10, 15, 10, 5, 8, 10], [3, 5, 3, 5, 7, 3, 9, 13])
        self.one_X_oneconv2 = nn.Conv1d(batch_size, batch_size, 7, stride=1, padding=3, dilation=1, padding_mode='zeros')
        self.decoder = OneDConvBlock(reduced_time_size, reduced_time_size * 2, batch_size, kernal_size, time_window_size, 1)
    def forward(self, x):  # x = [batch_size, time_window_size] ----> we get the input in this form
        y, x = self.encoder.forwardnet(torch.transpose(x, 0, 1))  # the encoder's skip output x = [reduced_time_size, batch_size] goes on; y (residual path) is unused
        x = torch.transpose(x, 0, 1)      # x = [batch_size, reduced_time_size], the layout the separator blocks expect
        mixture = x                       # kept aside to be multiplied element-wise with the mask later
        x = self.layerNorm(x)             # layer normalization over the reduced time axis (the layerNorm defined above)
        x = self.one_X_oneconv(x)         # stride, dilation and padding are chosen so the size does not change
        x, sum1 = self.column1.forwardsep(x)  # x = [batch_size, reduced_time_size] feeds the next column; sum1 has the same shape
        x, sum2 = self.column2.forwardsep(x)
        x, sum3 = self.column3.forwardsep(x)
        x, sum4 = self.column4.forwardsep(x)
        x, sum5 = self.column5.forwardsep(x)
        x, sum6 = self.column6.forwardsep(x)
        skip_total = F.rrelu(sum1 + sum2 + sum3 + sum4 + sum5 + sum6)
        mask = torch.sigmoid(self.one_X_oneconv2(skip_total))  # mask = [batch_size, reduced_time_size]; uses the second conv defined above
        decod = torch.mul(mask, mixture)                       # masked encoder output, [batch_size, reduced_time_size]
        out, sep = self.decoder.forwardnet(torch.transpose(decod, 0, 1))
        return torch.transpose(sep, 0, 1)                      # [batch_size, time_window_size]
#x=torch.randn(batch_size,time_window_size)
net = convTasnet()
#out = net(x)
#print(net)
#print(x.shape)
#print(torch.transpose(x,0,1).shape)
#print(out.shape)
#print(skip.shape)
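The commented-out lines above become a usable smoke test once the globals from the earlier cells (batch_size, time_window_size, kernal_size) are in scope; a minimal sketch:

```python
x = torch.randn(batch_size, time_window_size)
out = net(x)
print(x.shape)    # [batch_size, time_window_size]
print(out.shape)  # [batch_size, time_window_size] - the estimate has the same shape as the input
```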
# loss function and optimizer
import torch.optim as optim

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.00001)

for epoch in range(5):                        # 5 full passes over the data
    for data in range(len(train_set[0])):     # `data` indexes one batch of the training set
        x = train_set[0][data]                # x is the batch of mixtures, y is the batch of targets
        y = train_set[1][data]
        net.zero_grad()                       # reset the gradients before the loss calculation of every step
        output = net(x)                       # forward pass: x = [batch_size, time_window_size]
        loss = loss_function(output, y)       # compute the loss for this batch
        loss.backward()                       # backpropagate the loss through the network's parameters
        optimizer.step()                      # update the weights using the gradients
    print(loss)
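Cross-entropy treats the network output as class scores, while the output here is a waveform estimate; the original Conv-TasNet paper instead trains with a scale-invariant SNR (SI-SNR) objective. A minimal sketch of such a loss, in case you want to experiment with it (si_snr_loss is not part of the notebook above):

```python
def si_snr_loss(estimate, target, eps=1e-8):
    # zero-mean both signals along the time axis
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # project the estimate onto the target signal
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    s_target = dot / (torch.sum(target ** 2, dim=-1, keepdim=True) + eps) * target
    e_noise = estimate - s_target
    # SI-SNR in dB; negate so that minimising the loss maximises SI-SNR
    si_snr = 10 * torch.log10(torch.sum(s_target ** 2, dim=-1) /
                              (torch.sum(e_noise ** 2, dim=-1) + eps))
    return -si_snr.mean()
```

It drops into the loop above by replacing loss_function(output, y) with si_snr_loss(output, y), provided y is the target waveform with the same shape as output.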