Our Convolutional Neural Network (CNN) was built in MATLAB using the Deep Learning Toolbox. It consists of 18 layers, including 4 convolutional layers, and was trained on 583 scalograms from each sleep state. The model was trained with the Adaptive Moment Estimation (Adam) optimization algorithm for 30 training epochs. The training progress is shown below.
The model has an accuracy of 97.33%. The confusion matrix generated when testing the model is shown below; the high accuracy across all three classes suggests no systematic bias toward any one sleep state. However, the model was trained on data from only one rat over the course of one day and has not yet been tested on another dataset. Its generalizability is therefore unproven, and future work would involve evaluating the model on data from other rats and, ideally, other mammals.
When moving our analysis to the frequency domain, we first found that the traditional Fourier transform techniques from our course provided little meaningful information about our data. The Fourier transform generally performs poorly on non-stationary signals such as LFPs. The two techniques most commonly used for analyzing non-stationary signals are the Short-Time Fourier Transform (STFT) and the wavelet transform. The STFT is the simpler technique, so we looked into this option first. However, we anticipated poor results because the STFT has a fixed resolution in time and frequency, a poor match for the dynamic nature of LFP signals. This prediction was confirmed in practice: we generated the following spectrogram from an N-REM stage of sleep on the first channel of our dataset.
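For readers who want the mechanics, a minimal STFT sketch makes the fixed-resolution limitation concrete: one window length sets the time/frequency trade-off for every frequency at once. This is illustrative Python/NumPy, not the MATLAB code we actually used.

```python
import numpy as np

def stft_magnitude(x, win_len, hop):
    """Magnitude spectrogram via a sliding Hann window.
    The single fixed win_len sets one time/frequency resolution for ALL
    frequencies -- the limitation we ran into with LFP signals."""
    window = np.hanning(win_len)
    n_frames = 1 + (len(x) - win_len) // hop
    frames = np.stack([x[i * hop : i * hop + win_len] * window
                       for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=1)).T  # (freq bins, frames)

# Toy non-stationary signal: a 10 Hz tone followed by a 40 Hz tone.
fs = 500
t = np.arange(fs) / fs
x = np.concatenate([np.sin(2 * np.pi * 10 * t), np.sin(2 * np.pi * 40 * t)])

S = stft_magnitude(x, win_len=128, hop=64)
print(S.shape)  # (65, 14)
```

Shortening `win_len` sharpens time localization but coarsens every frequency bin, and vice versa; there is no per-frequency adjustment, which is exactly what the wavelet transform adds.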
The wavelet transform instead divides the data into different frequency components and studies each component with a resolution matched to its scale. There are two main variants of this technique: the discrete wavelet transform and the continuous wavelet transform. Although computationally efficient, the discrete wavelet transform can miss bursts, oscillations, shifts, and other transient behavior. For LFP analysis, our group decided the continuous wavelet transform would produce the best results. The equation for the continuous wavelet transform is included below.
W(a, b) = (1/√a) ∫ x(t) ψ*((t − b)/a) dt

where ψ is the mother wavelet (ψ* denotes its complex conjugate), a is the scale parameter, b is the translation parameter, and x is the input signal.
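A direct discretization of this equation can be sketched as follows. This is illustrative Python/NumPy, not our pipeline (which used MATLAB's built-in CWT); the mother wavelet here is a Morlet, discussed below, and the centered-convolution implementation is our own shortcut for sliding the translation parameter b.

```python
import numpy as np

def morlet(t, w0=5.0):
    """Morlet mother wavelet: a complex sinusoid under a Gaussian window."""
    return np.pi ** -0.25 * np.exp(1j * w0 * t) * np.exp(-t ** 2 / 2)

def cwt(x, scales, dt=1.0):
    """W(a, b) = (1/sqrt(a)) * integral of x(t) * psi*((t - b)/a) dt,
    approximated by a centered discrete convolution at each scale a."""
    t = (np.arange(len(x)) - len(x) // 2) * dt
    out = np.empty((len(scales), len(x)), dtype=complex)
    for i, a in enumerate(scales):
        kernel = np.conj(morlet(-t / a)) / np.sqrt(a)
        out[i] = np.convolve(x, kernel, mode="same") * dt
    return out

# Sanity check: a 5 Hz sine should respond most at the matching scale.
fs = 100
t = np.arange(400) / fs
x = np.sin(2 * np.pi * 5 * t)
scales = [5 / (2 * np.pi * f) for f in (2, 5, 10)]  # scale <-> frequency for w0=5
W = cwt(x, scales, dt=1 / fs)
power = np.abs(W)[:, 150:250].mean(axis=1)
print(int(np.argmax(power)))  # 1  (the 5 Hz scale)
```

Each row of the output is the signal analyzed at one scale, which is what a scalogram displays: small scales resolve fast transients finely in time, large scales resolve slow rhythms finely in frequency.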
Next, we selected a wavelet to use in our wavelet transform. LFP data is traditionally analyzed with complex Gaussian-windowed wavelets, whose Gaussian envelope gives a well-balanced resolution in time and frequency. Our team decided to process the data with the Morlet wavelet, the wavelet most commonly used in similar studies involving LFP analysis. The Morlet wavelet is a sinusoid modulated by a Gaussian window and is favored for its relative simplicity. The following plot shows an example Morlet wavelet generated by our team in MATLAB.
The Morlet wavelet that our team used had a center frequency of 5 and a scaling factor of 1. Using the same dataset as before, we generated the following scalogram, which shows markedly better results than the STFT spectrogram.
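Numerically, the Morlet wavelet is easy to write down. The sketch below (Python/NumPy; our actual plot came from MATLAB) takes the stated center frequency of 5 as the angular frequency of the complex exponential, an assumption about the convention, and checks two defining wavelet properties: unit energy and (near-)zero mean.

```python
import numpy as np

w0 = 5.0                      # center (angular) frequency, as in our setup
t = np.linspace(-4, 4, 2001)
dt = t[1] - t[0]

# Morlet wavelet: complex sinusoid modulated by a Gaussian window.
psi = np.pi ** -0.25 * np.exp(1j * w0 * t) * np.exp(-t ** 2 / 2)

# Two defining properties (to numerical precision):
energy = (np.abs(psi) ** 2).sum() * dt   # should integrate to 1
mean = psi.sum() * dt                    # should be (near) zero
print(round(energy, 3))   # 1.0
print(abs(mean) < 1e-4)   # True
```

The near-zero mean (admissibility) is what makes this a valid analyzing wavelet, and the Gaussian envelope is what gives the balanced time-frequency localization described above.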
Our team additionally looked into improving upon this transform by implementing an even more flexible wavelet. We found the generalized Morse wavelet, which behaves similarly to the Morlet wavelet in this application but offers greater flexibility in its parameters, allowing the wavelet to match the signal more closely. In the frequency domain, the Morse wavelet is

Ψ(ω) = U(ω) a(β, γ) ω^β e^(−ω^γ)

where U(ω) is the unit step function, a(β, γ) is a normalizing constant, and the parameters β and γ control the wavelet's time-frequency spread and symmetry.
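The extra flexibility is easy to see numerically: the wavelet's peak frequency moves with the choice of β and γ, landing at ω = (β/γ)^(1/γ). A minimal, unnormalized Python/NumPy sketch (parameter names `beta` and `gamma` are ours; our analysis used MATLAB):

```python
import numpy as np

def morse_freq(omega, beta, gamma):
    """Generalized Morse wavelet in the frequency domain (unnormalized):
    Psi(omega) = omega**beta * exp(-omega**gamma) for omega > 0, else 0."""
    omega = np.asarray(omega, dtype=float)
    return np.where(omega > 0, omega ** beta * np.exp(-omega ** gamma), 0.0)

omega = np.linspace(0, 4, 4001)
psi = morse_freq(omega, beta=3, gamma=3)

# Peak frequency is (beta/gamma)**(1/gamma); for beta = gamma = 3 that is 1.
peak = omega[np.argmax(psi)]
print(round(peak, 2))  # 1.0
```

Because the wavelet vanishes for negative frequencies, it is analytic; tuning β and γ reshapes its bandwidth and symmetry, which is the flexibility the Morlet wavelet lacks.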
Additionally, the Morse wavelet can be implemented via the same methods as the Morlet wavelet in MATLAB. The following results were achieved by applying the Morse wavelet to the same data.
Although the difference is subtle, the Morse scalogram has slightly higher resolution than the Morlet scalogram, so our group decided to proceed with the Morse wavelet to process our signals.
Shown above are examples of scalograms that were used to train our network. The scalograms show how present a given frequency is at a given time, and thus there are certain patterns associated with each sleep state that the model was taught to recognize. More information on these differences can be found on the Data page.
The CNN used a layered network architecture to process and analyze visual inputs: RGB scalogram images of the LFP data for each time period. Batch normalization layers inserted between the convolutional and ReLU layers significantly accelerate training and reduce the network's sensitivity to its initial weights. The ReLU activation function injects non-linearity into the network by setting any input value below zero to zero. Following these layers, max pooling further refines the feature extraction: it simplifies the output by dividing the input into sections and keeping only the maximum value from each section. After the pooling stages, the network integrates fully connected layers, which perform critical synthesis by multiplying their input with a weight matrix and adding a bias vector. Each neuron connects to all neurons in the previous layer, allowing the network to recognize the larger patterns used for classification. The final fully connected layer feeds a softmax layer, which converts the network's outputs into a probability distribution in which each class's probability reflects the likelihood that the input belongs to that class.
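The non-convolutional operations described above are simple enough to sketch directly. The following is illustrative Python/NumPy, not our network (which used MATLAB's Deep Learning Toolbox layers):

```python
import numpy as np

def relu(x):
    """ReLU: zero out negative inputs (the non-linearity described above)."""
    return np.maximum(x, 0)

def max_pool(x, k=2):
    """k-by-k max pooling: keep the maximum value of each block."""
    h, w = x.shape[0] // k * k, x.shape[1] // k * k
    return x[:h, :w].reshape(h // k, k, w // k, k).max(axis=(1, 3))

def softmax(z):
    """Convert raw class scores into a probability distribution."""
    e = np.exp(z - z.max())          # subtract max for numerical stability
    return e / e.sum()

x = np.array([[1., -2., 3., 0.],
              [-1., 5., -3., 2.],
              [0., 1., 2., -4.],
              [3., -1., 0., 6.]])
pooled = max_pool(relu(x))
print(pooled)  # [[5. 3.]
               #  [3. 6.]]
probs = softmax(np.array([2.0, 1.0, 0.5]))  # sums to 1 across classes
```

In the real network these operations run per feature map after each convolution, and the softmax output over the three classes (Awake, Non-REM, REM) is what the confusion matrix summarizes.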
We took several approaches to improve the accuracy of our network. After generating scalograms for one channel of LFP data, we had only 31 images associated with REM sleep, meaning we could train the model with at most 31 images from each sleep stage. As expected, this was not very effective and gave us an initial accuracy of 52.94%. We therefore generated scalograms from several other electrode channels, for a final set of 1,234 Awake scalograms, 1,140 Non-REM scalograms, and 583 REM scalograms. When reading this data into our CNN, we under-sampled the Awake and Non-REM subfolders so that the model was trained on an equal number of images from each state, using 70% of the images for training and 30% for testing. Since training the model took a long time (~30 min to 1 hr), we wanted to keep the complexity low, so we started with a 14-layer network containing 3 convolutional layers; in the interest of improving accuracy, this was later increased to 18 layers with the addition of a convolutional layer and its accompanying normalization, activation, and pooling layers. The last major change was switching optimization algorithms: the model was initially trained with the Stochastic Gradient Descent with Momentum (SGDM) algorithm, but our research suggested the Adaptive Moment Estimation (Adam) algorithm could be more effective for classifying scalograms, and once implemented it produced much better results.
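The balancing and splitting step can be sketched as follows. This is illustrative Python with hypothetical file names mirroring our final image counts; our pipeline did this in MATLAB on image subfolders.

```python
import random

def balance_and_split(paths_by_state, train_frac=0.7, seed=0):
    """Under-sample every class to the smallest class size, then make a
    70/30 train/test split -- the balancing scheme described above."""
    rng = random.Random(seed)
    n = min(len(v) for v in paths_by_state.values())   # smallest class (REM: 583)
    train, test = [], []
    for state, paths in paths_by_state.items():
        sample = rng.sample(paths, n)                  # under-sample this class
        cut = int(train_frac * n)
        train += [(p, state) for p in sample[:cut]]
        test += [(p, state) for p in sample[cut:]]
    return train, test

# Hypothetical file lists mirroring our final scalogram counts.
data = {"Awake":  [f"awake_{i}.png" for i in range(1234)],
        "NonREM": [f"nrem_{i}.png" for i in range(1140)],
        "REM":    [f"rem_{i}.png" for i in range(583)]}

train, test = balance_and_split(data)
print(len(train), len(test))  # 1224 525
```

With 583 images per class, each class contributes 408 training and 175 test images, so every sleep state is equally represented in both splits.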
Shown below is a visual demonstration of how these changes improved our network.
Network trained with 200 images per state (62.52% accuracy)
Complexity increased from 14 to 18 layers (83.70% accuracy)
Network trained with 583 images per state (88.95% accuracy)
Optimization algorithm switched from SGDM to Adam (97.33% accuracy)
We learned about brain waves and their relationship to sleep states; reading NWB data; techniques for analyzing non-stationary signals in the frequency domain; types of wavelets and their implications for frequency-domain analysis; image representations relating time and frequency; how to construct, train, and optimize convolutional neural networks; and how CNNs work through the function of each of their layers.
From the course, we utilized wavelet transforms. In lecture, this technique was described in the context of image compression; we instead applied it to improving the frequency resolution of the LFP signal. The lecture's advice to use the continuous wavelet transform influenced how our group approached performing the transform.
Another tool we used from the course was convolution. We learned how convolution works, as well as how it can be specifically applied in filtering to identify particular features of an image. In our CNN, our convolutional layers convolve filters across our scalograms to identify their unique features, and having the background knowledge we developed in class about this process was very helpful in knowing exactly what our network was doing.
One example of how our group went beyond the course material was in selecting the specific wavelet. We analyzed various wavelet types that could be applied with the wavelet transform technique from class, researching and selecting wavelets that matched our goal of resolving the LFP signal in the frequency domain beyond what the Fourier transform could provide. This involved examining complex Gaussian wavelets in particular, including the Morlet and Morse wavelets described above. For more examples of techniques our group applied from outside the scope of the class, see the Methods section above.