Learn how to create applications that use Machine Learning. This project uses a touch LCD screen where the user can draw a digit from 0 to 9, and the model will be able to tell what digit it is.
https://github.com/benhaub/EdgeInference/commit/187b6f8a21b117f97f4a56ed3629316039ecc041
This was the first time I wrote Python code to train a model. I started off with this series of exercises to get familiar with what the code looks like. I wasn't able to develop a sense of which activation functions, optimizers, or loss functions fit which scenarios, but I at least familiarized myself with the concepts. I spent a good deal of time trying to unconvolute the idea and purpose behind convolutions.
I then downloaded a handwritten 0 to 9 dataset from MNIST. Kaggle was useful as it also provided community-led code examples, so I followed along with this code since I still wasn't really sure how to create my model based on the Google codelabs I had done.
Once I trained the model, I saved a quantized tflite inference model and converted the whole thing to a std::array so that I could embed it into the binary of my project.
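The embedding step can be sketched in Python. This is a hypothetical helper, not the script from the project; the array name and output formatting are made up for illustration.

```python
# Sketch: turn the bytes of a quantized .tflite flatbuffer into a C++
# std::array definition that can be compiled into the firmware binary.
# The name "modelData" is a placeholder, not the project's actual name.
def bytes_to_std_array(model_bytes: bytes, name: str = "modelData") -> str:
    hex_bytes = ", ".join(f"0x{b:02x}" for b in model_bytes)
    return (
        "#include <array>\n"
        "#include <cstdint>\n"
        f"constexpr std::array<uint8_t, {len(model_bytes)}> {name} = {{ {hex_bytes} }};\n"
    )
```

In practice the input would come from reading the saved model file, e.g. open("model.tflite", "rb").read(), and the generated header would be written to a file in the firmware source tree.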
Next up was integrating the LCD screen. I had already done some work on this previously.
One mistake I made the first time I worked with the LCD was focusing on the maximum GPIO current I could draw from my devkit. I wasn't actually powering the board from GPIOs, but from the 3V3 and 5V pins of the BoosterPack, which are supplied by a TPS73733 and a TPS2052 respectively. They limit current at 0.9A, which is enough for the backlight.
I had a lot of trouble getting the LCD screen to turn on. I looked at many GitHub repositories trying to figure out the correct initialization procedure. Each repository had its own distinct way of initializing the screen, so they weren't much help, and I couldn't risk integrating them because the code didn't look very good and they all had quite large barriers to entry (all required work to port), so I decided to create my own implementation. I eventually discovered that the screen would not turn on because of an invalid VSYNC setting. After hooking up my oscilloscope to view the SPI signals, I also discovered that my chip select was deasserting a few bits early. I had to add a call to the SPI driver to wait until the TX buffer was empty before returning.
I then referenced Bridgetek's GitHub repository to get some of the commands and display lists working. Bridgetek had by far the best repository to reference.
One important capability was copying the screen contents. I used the Hex Editor tool from Visual Studio to display bytes copied from a buffer. I could then copy those bytes as hexadecimal in ASCII encoding and convert them back to raw bytes using xxd -r -ps asciiHexBytes rawHexBytes. From this I created a raw pixel file and uploaded it to this online tool to verify that the screen had been copied correctly. This was a vital workflow throughout the project, as it allowed me to validate what was being sent to the model for inference.
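The xxd -r -ps step can also be reproduced in a few lines of Python, which is handy on machines without xxd. A minimal sketch:

```python
# Equivalent of `xxd -r -ps`: take ASCII hex copied out of a hex editor
# and reverse it into raw bytes for a raw pixel file.
def ascii_hex_to_raw(ascii_hex: str) -> bytes:
    # bytes.fromhex ignores ASCII whitespace, much like xxd -r -ps does
    return bytes.fromhex(ascii_hex)

# Writing the result out would then just be:
#   open("rawHexBytes", "wb").write(ascii_hex_to_raw(text))
```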
I believe my physical SPI setup was also causing a lot of transmission errors. I tried reducing the clock frequency, but it didn't help. I think the physical setup and the wires I was using were degrading the signal quality. I had to implement retry and read-back mechanisms in the LCD driver to make using the LCD screen a good user/developer experience.
The first time I ran it, I did not have all the proper operations registered. I had to use netron to see what my model needed, as well as the tflm_benchmark to verify the arena size was correct and to further verify which operations needed to be registered.
tflite was still hanging while trying to AllocateTensors. After some debugging I found that something was being corrupted because I had wrapped the tflite::MicroInterpreter in a std::optional as a member variable and was calling std::optional::emplace on it later. My suspicion was that something in the MicroInterpreter was getting corrupted while C++ performed moves or copies in the background. I wasn't satisfied with just removing std::optional, as I felt this should work, so I decided to investigate the tflite code for bugs. I gave up when I realized the code is quite difficult to read: it uses a lot of C-style raw pointers to hold references to many different things, and it was all too much to keep track of. My guess is that the classes are not moveable because the pointers are used as views into other objects, and if those objects change address after construction the pointers are left dangling. The tflite maintainers should be paying more attention to which constructors and operators are left un-deleted when they are not supported.
I had to deal with consistency between training and what the device was expected to receive. My model was originally trained on 28x28 handwritten digits, but these were far too small to write on the LCD screen, so I had to play around with resizing in my model. I also had to cut the sketch area in half, since you didn't need the whole screen to comfortably draw a number, and copying a 320x240 region of screen data was too much for the 256kiB of SRAM I had available. Then I had to convert the pixels to greyscale, since the MNIST handwritten digit dataset I worked with was in 8-bit greyscale but the LCD screen was saving the bitmaps as ARGB4. Lastly, the MNIST data was trained with very thick, white numbers, and the LCD screen was drawing them too thin for my model to recognize. I wasn't able to figure out a way to adjust the "brushstroke" size, so I just went over my numbers a few times with my finger to thicken them up. Once I did that, the model started guessing somewhat well at what numbers I was drawing.
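The ARGB4-to-greyscale conversion can be sketched like this. It assumes an ARGB4444 layout with alpha in the top nibble; the real driver's bit layout may differ, and the luminance weights are the common Rec. 601 values, not necessarily what the project used.

```python
# Convert one 16-bit ARGB4444 pixel to an 8-bit greyscale value.
def argb4_to_grey(pixel16: int) -> int:
    # Extract the 4-bit channels and scale each to 8 bits (0xF -> 0xFF)
    r = ((pixel16 >> 8) & 0xF) * 17
    g = ((pixel16 >> 4) & 0xF) * 17
    b = (pixel16 & 0xF) * 17
    # Integer luminance weights summing to 256 (approx. 0.299/0.587/0.114),
    # so the shift by 8 divides the weighted sum back down exactly
    return (77 * r + 150 * g + 29 * b) >> 8
```

The integer weights avoid floating point on the device, which matters when the conversion runs per pixel on a small MCU.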
Raw Pixel image of a "2" that was drawn on the screen and whose bytes were copied to a file.
What drawing numbers on the LCD looks like
Either the model is poorly trained, or the MNIST dataset and the numbers on the LCD screen are too different for high-accuracy inference.
I drew each number 3 times, pasted the inference results into Gemini, and had it write this table and summary:
Success Rate: 10/30 (33.3%)
Top Performing Digits: 7 (66% success), 5 (66% success), and 6 (66% success).
Most Frequent Confusion: Drawing 0, 1, 4, or 8 often results in the model predicting 5, 7, 6, or 2.
I tried adding a convolutional layer to see if I could improve the performance. It didn't work. I then decided to stop manipulating the dataset in Python by upscaling the MNIST digits to match the size of my screen. It was making the neural network quite a bit larger, since more inputs were needed for more pixels, and I felt that increasing the size of the pictures was causing some blurriness. I decided to do the preprocessing on target instead. I implemented nearest-neighbour resizing to reduce the image to match the MNIST size. It didn't work, and it looked a little choppy in my image viewer, so I tried bilinear interpolation. I think the image had a slight improvement in clarity, but my inference was still inaccurate.
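The on-target nearest-neighbour resize can be sketched like this, using plain Python lists in place of the device's pixel buffers (the real implementation is C++ on the MCU):

```python
# Nearest-neighbour resize: each destination pixel just copies the one
# source pixel its coordinates map back onto. Fast, but it skips most
# source pixels when downscaling, which is why the result looked choppy.
def nearest_neighbour_resize(img, src_w, src_h, dst_w, dst_h):
    out = []
    for y in range(dst_h):
        sy = y * src_h // dst_h   # nearest source row
        row = []
        for x in range(dst_w):
            sx = x * src_w // dst_w  # nearest source column
            row.append(img[sy][sx])
        out.append(row)
    return out
```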
At this point I was at the limit of what little I knew about ML inference and training. It was tempting to feed my code into AI and see if it could tell me what was wrong, but I would be doing it blindly. I decided to go back to tutorials. I found another Google course to do and also discovered this video with links to this site. So this project is paused until I feel like I know enough to try and fix it.
When inference is inaccurate, the culprits are typically:
Input data drift. The data no longer matches what the model was trained on because the context has changed over time. This requires the model to be retrained with updated data.
Overfitting. During training, the model memorized the data instead of learning it. This could be caused by poorly curated datasets or too large a batch size during training.
Inadequate pre-processing. The data did not undergo the proper pre-processing before being given to the model for inference.
The most likely thing going wrong for me is pre-processing. I found and fixed:
Issues with copying the screen and doing unnecessary pixel conversion, which led to poorer-quality resized images.
Implemented algorithms for finding image seeds (important parts of the image) based on given criteria, and for sharpening connected pixels to enhance the digit. This meant I no longer had to draw over my numbers many times, since I could now sharpen the faint, thin lines.
Experimented with other resampling techniques for resizing. Box resampling produced by far the best results. Of course there may be a bug in my bilinear and nearest-neighbour algorithms, but I figured the results weren't great because those functions skip over a lot of pixels, whereas box resampling considers every pixel.
Bilinear Interpolation
Nearest Neighbour
Box
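Box resampling can be sketched as follows, assuming an integer downscale (destination no larger than the source): every destination pixel is the mean of the whole source box that maps onto it, so no source pixel gets skipped.

```python
# Box resampling for downscaling a greyscale image stored as rows of ints.
def box_resize(img, src_w, src_h, dst_w, dst_h):
    out = []
    for y in range(dst_h):
        # Source row range covered by this destination row
        y0, y1 = y * src_h // dst_h, (y + 1) * src_h // dst_h
        row = []
        for x in range(dst_w):
            # Source column range covered by this destination column
            x0, x1 = x * src_w // dst_w, (x + 1) * src_w // dst_w
            total = sum(img[sy][sx] for sy in range(y0, y1)
                                    for sx in range(x0, x1))
            # Average over the whole box
            row.append(total // ((y1 - y0) * (x1 - x0)))
        out.append(row)
    return out
```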
Connected Pixel Sharpening Algorithm
The middle digit is AI's output from doing binarization on the pixels.
I provided some sample images of sharpened digits to AI, asking if it could "thin" the image out to make the numbers a bit smoother. It started doing some preprocessing with what it referred to as binarization, using Otsu's Method to determine a threshold for the image and then setting each pixel to either full or minimum intensity. I had never heard of binarization before, so I implemented it, and it improved the quality of my preprocessing by a considerable amount.
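A sketch of Otsu's method as described: build a 256-bin histogram, pick the threshold that maximizes the between-class variance, and binarize against it. This is the textbook algorithm, not the project's exact code.

```python
# Otsu's method over a flat list of 8-bit greyscale pixel values.
def otsu_threshold(pixels):
    hist = [0] * 256
    for p in pixels:
        hist[p] += 1
    total = len(pixels)
    sum_all = sum(i * hist[i] for i in range(256))
    sum_bg, weight_bg = 0, 0
    best_t, best_var = 0, -1.0
    for t in range(256):
        weight_bg += hist[t]          # pixels at or below threshold t
        if weight_bg == 0:
            continue
        weight_fg = total - weight_bg
        if weight_fg == 0:
            break
        sum_bg += t * hist[t]
        mean_bg = sum_bg / weight_bg
        mean_fg = (sum_all - sum_bg) / weight_fg
        # Between-class variance; maximizing it separates the classes best
        var_between = weight_bg * weight_fg * (mean_bg - mean_fg) ** 2
        if var_between > best_var:
            best_var, best_t = var_between, t
    return best_t

def binarize(pixels, t):
    # Full intensity above the threshold, minimum intensity otherwise
    return [255 if p > t else 0 for p in pixels]
```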
Earlier, when I tried convolutions, the inference accuracy did not improve, but that was before my pre-processing had improved. With no convolutions, I think the model relies too much on my numbers looking exactly like the dataset. What I want is for the model to notice the correct features instead. With the improvements in pre-processing gained from box resampling and binarization, I felt I would really benefit from adding some convolutions to the neural network.
I grabbed a set of five 9's from my LCD screen, since 9 performed the worst, and started using an inference script I made with AI to test the accuracy.
I fed the output to gemini. It noted something about how the training loss and validation loss diverged at epoch 6 where the validation loss began to rise while the training loss continued to decrease. I have made a note to look into this concept in more detail.
Epoch 1/10
469/469 [==============================] - 12s 25ms/step - loss: 0.2874 - accuracy: 0.9146 - val_loss: 0.1076 - val_accuracy: 0.9683
Epoch 2/10
469/469 [==============================] - 12s 25ms/step - loss: 0.0810 - accuracy: 0.9766 - val_loss: 0.0728 - val_accuracy: 0.9777
Epoch 3/10
469/469 [==============================] - 12s 26ms/step - loss: 0.0515 - accuracy: 0.9850 - val_loss: 0.0660 - val_accuracy: 0.9803
Epoch 4/10
469/469 [==============================] - 11s 24ms/step - loss: 0.0359 - accuracy: 0.9890 - val_loss: 0.0570 - val_accuracy: 0.9833
Epoch 5/10
469/469 [==============================] - 12s 25ms/step - loss: 0.0255 - accuracy: 0.9923 - val_loss: 0.0621 - val_accuracy: 0.9811
Epoch 6/10
469/469 [==============================] - 12s 26ms/step - loss: 0.0203 - accuracy: 0.9934 - val_loss: 0.0543 - val_accuracy: 0.9841
Epoch 7/10
469/469 [==============================] - 13s 27ms/step - loss: 0.0141 - accuracy: 0.9955 - val_loss: 0.0584 - val_accuracy: 0.9826
Epoch 8/10
469/469 [==============================] - 12s 26ms/step - loss: 0.0129 - accuracy: 0.9956 - val_loss: 0.0709 - val_accuracy: 0.9827
Epoch 9/10
469/469 [==============================] - 12s 26ms/step - loss: 0.0092 - accuracy: 0.9971 - val_loss: 0.0767 - val_accuracy: 0.9816
Epoch 10/10
469/469 [==============================] - 12s 26ms/step - loss: 0.0102 - accuracy: 0.9965 - val_loss: 0.0768 - val_accuracy: 0.9826

It suggested adding dropout layers and some manipulation layers to rotate the images, but I knew from reading my textbook that overfitting can also result from too large a batch size or too many iterations, so I reduced my iterations from 10 to 5. I also made the network smaller, from 128 -> 64 -> 32 -> 10 to 64 -> 32 -> 16 -> 10.
The new results were:
Epoch 1/5
469/469 [==============================] - 13s 27ms/step - loss: 0.2900 - accuracy: 0.9133 - val_loss: 0.1071 - val_accuracy: 0.9683
Epoch 2/5
469/469 [==============================] - 11s 24ms/step - loss: 0.0735 - accuracy: 0.9780 - val_loss: 0.0794 - val_accuracy: 0.9752
Epoch 3/5
469/469 [==============================] - 12s 25ms/step - loss: 0.0464 - accuracy: 0.9857 - val_loss: 0.0562 - val_accuracy: 0.9840
Epoch 4/5
469/469 [==============================] - 12s 27ms/step - loss: 0.0320 - accuracy: 0.9902 - val_loss: 0.0641 - val_accuracy: 0.9799
Epoch 5/5
469/469 [==============================] - 11s 24ms/step - loss: 0.0239 - accuracy: 0.9926 - val_loss: 0.0630 - val_accuracy: 0.9822

Which is quite a bit better.
The model size increased quite a bit from doing this, going from 8kiB to 50kiB. I added some max pooling layers, which dropped the size back down to 34kiB. To get rid of the white specks (which I think are corruption from the SPI transfer), I used a vertical strip filter. AI was super helpful for debugging this: it was able to take my code, run it on the images I provided, and show me what the results would look like.
The digits are quite broken up. This does sometimes cause the model to confidently infer 7 from 9 since the tail and the bottom loop of the 9 do have similar characteristics to a 7 especially if the top of the loop is disconnected from the bottom.
I thought back to the 3 that I had drawn and resampled with the box algorithm, and noticed it was the digit that stayed most connected through preprocessing. Here is a 9.
I decided I needed a way to make sure that digits are fully connected and bold. I implemented 3 more pre-processing stages into the preparation for inference:
Dilation - 2x2 kernel that expands the white dots out into a 2x2 grid.
Pixel Gap Filling - Filling in the gaps so that the digit is fully connected. The gap size is a max of 1, so if two white dots are separated by more than that they won't be connected.
Largest Island extraction - Find all islands and remove all but the largest one. This helps remove the remaining specks that got past the vertical strip filter.
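The dilation and largest-island stages above can be sketched in Python on a binary image stored as rows of 0/255 values (gap filling is omitted for brevity; the on-device versions are C++):

```python
# Dilation with a 2x2 kernel: every white pixel expands down and right
# into a 2x2 block, thickening thin strokes.
def dilate_2x2(img):
    h, w = len(img), len(img[0])
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            if img[y][x]:
                for dy in (0, 1):
                    for dx in (0, 1):
                        if y + dy < h and x + dx < w:
                            out[y + dy][x + dx] = 255
    return out

# Largest-island extraction: flood-fill each 4-connected blob of white
# pixels, keep only the biggest one, and clear everything else.
def largest_island(img):
    h, w = len(img), len(img[0])
    seen = [[False] * w for _ in range(h)]
    best = []
    for y in range(h):
        for x in range(w):
            if img[y][x] and not seen[y][x]:
                stack, island = [(y, x)], []
                seen[y][x] = True
                while stack:
                    cy, cx = stack.pop()
                    island.append((cy, cx))
                    for ny, nx in ((cy + 1, cx), (cy - 1, cx),
                                   (cy, cx + 1), (cy, cx - 1)):
                        if 0 <= ny < h and 0 <= nx < w and img[ny][nx] and not seen[ny][nx]:
                            seen[ny][nx] = True
                            stack.append((ny, nx))
                if len(island) > len(best):
                    best = island
    out = [[0] * w for _ in range(h)]
    for y, x in best:
        out[y][x] = 255
    return out
```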
I also dropped the SPI clock frequency. It was at 20MHz at the beginning of the project. I dropped it to 10MHz to try to fix LCD display issues I was having. Then, to clean up the images, I tried reducing it again to 2MHz, then to 250kHz, and then to 50kHz, which is where it is right now. Each time I've reduced the clock rate I haven't witnessed any noticeable behavior changes, and I still have a lot of noise in my images instead of a pure black background with a white digit on it.
Up 10% from last time. The good thing about the MNIST project is that it has been attempted many, many times by other people, and I'm certainly not the only one to have had low accuracy with MNIST. When others struggled with accuracy, the consensus on the cause was a mismatch between the training data and the input data for inference. I am now attempting to manipulate the MNIST dataset to look more like the images that I pull from the LCD screen by making them more blocky. I had help from AI to develop an algorithm for making the digits more blocky. It suggested binarizing, decreasing the size by half, then increasing the size again.
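The suggested "blockify" manipulation can be sketched as: binarize, halve the resolution with nearest-neighbour sampling, then double it back so every stroke becomes 2x2 blocks. The threshold value and sampling choices here are my assumptions, not the exact parameters used.

```python
# Make a greyscale image blocky: binarize, downscale by 2, upscale by 2.
def blockify(img, threshold=128):
    h, w = len(img), len(img[0])
    # 1. Binarize against a fixed threshold
    binary = [[255 if p >= threshold else 0 for p in row] for row in img]
    # 2. Downscale by 2, taking the top-left pixel of each 2x2 block
    half = [[binary[y * 2][x * 2] for x in range(w // 2)]
            for y in range(h // 2)]
    # 3. Upscale back by 2, duplicating each pixel into a 2x2 block
    return [[half[y // 2][x // 2] for x in range(w)] for y in range(h)]
```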
If we take into account the confidence in each guess, then the perfect score would be 50 (50 correct guesses at 100%) and the worst score would be -50 (50 incorrect guesses at 100%). The scores for analyses 3 and 4 were 2.32 and -4.14 respectively.
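That scoring scheme is simple to state in code: each guess contributes plus its confidence when correct and minus its confidence when wrong, so 50 guesses range from -50 to +50.

```python
# Confidence-weighted score over a run of inferences.
# results: list of (correct: bool, confidence: float in [0, 1]) pairs.
def confidence_score(results):
    return sum(conf if correct else -conf for correct, conf in results)
```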
I'm thinking that maybe my pre-processing is too unpredictable. Sometimes it might produce a good result; sometimes it does more harm than good. I'd like to try to manipulate the images so that I have to do less, and more predictable, pre-processing on the device and can have a more consistent input.
While playing around with the original-sized 120x120 image, using the same technique as before where I give Gemini all my code and tell it to adjust parameters and show me the results, I had an idea to have it do an island filter instead of a vertical strip filter. The results were nice, probably the best I've had yet. I think this filter works best on the original-sized image because the differences between the areas where I draw and the noise picked up during the capture are more stark.
So the pre-processing pipeline is now Binarize (120x120) -> IslandFilter (120x120) -> Downsize (28x28) -> Binarize (28x28) -> Dilate (28x28). This matches the "blockified" dataset quite nicely. So I'll pass this off to my model input for inference.
Much better results. The model is now usable, but still has room for improvement.
The raw accuracy is the number of correct guesses out of the total guesses. The confidence score takes into account how sure the model was of its guess. If it was wrong but also had low confidence, that's actually a good thing. Being right and 100% confident is perfect and contributes to a higher score.
While there are improvements I could make:
LCD screen performance. It often freezes, glitches, misses clicks, or doesn't load the screen properly. The user experience is not good.
Accuracy is still not perfect. It still struggles on some digits, and the output of the pre-processing could be inspected to see what needs to change.
Response time is high.
I feel that I have grasped the essence of the problem and where I need to focus my efforts for improvement. I feel that I have also cleared up some misconceptions about ML inference that I had when I started this project.
I thought that models were a lot smarter than they actually are. I thought that with a large enough data set it would be able to guess any number regardless of its characteristics as long as the shape matched.
I miscalculated the importance of the quality of the data. I didn't think the data would have to match the source so exactly. Domain adaptation was very important.
I miscalculated the importance of pre-processing. I thought the model would be able to figure out the noise and extraneous features and just recognize the important parts.
I overall didn't understand the degree to which the input data needed to match the training data, especially for small models.
Of the AI triangle, I feel like data is the most important. Not only just collecting it, but making sure that it's good. The infrastructure and model algorithms follow once the data has been curated.
I feel as though I learned the things I set out to learn when I started this project, and I will be able to take what I've learned here and apply it to my next ML project.
Post processing - What to do when the inference is either incorrect or does not have enough confidence.
Data collection - I did not have to collect, create, or curate the dataset.
Variations in data entry size. The 120x120 drawing space does not stop someone from drawing a digit that doesn't maximally take up the space, and the model does not perform well when the digit doesn't fill the space. This would have to be detected somehow and the digit enlarged.
Reinforcement learning techniques, where incorrect input is used to inform a better decision later.
As I continue to read my textbook, I will revisit this project to try and make improvements and address incorrect assumptions.