Digit detection with Nx

Hi everyone.

José recently shared a talk about Nx, where he showed how to build a simple neural network to identify digits using the MNIST dataset. I know very little about machine learning, but I was excited about this and wanted to extend the example he shared in the talk so that it can identify digits drawn by the user.



In this demo you can see that I return the top 3 predictions made by the neural network, along with the probability of each prediction. As you can see, it fails sometimes.
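For readers who want to reproduce the "top 3 predictions" step, here is a minimal sketch of one way to do it in Nx. It is not the code from the demo: the `probs` tensor is a made-up softmax output for a single image (index = digit, value = probability), and the Nx version in `Mix.install` is an assumption.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical softmax output for one image: position i holds P(digit == i).
probs = Nx.tensor([0.01, 0.02, 0.05, 0.60, 0.02, 0.20, 0.03, 0.02, 0.04, 0.01])

# Sort the digit indices by probability (highest first) and keep the first three.
top3 =
  probs
  |> Nx.argsort(direction: :desc)
  |> Nx.slice([0], [3])
  |> Nx.to_flat_list()

# Pair each of the three digits with its probability.
predictions = for digit <- top3, do: {digit, Nx.to_number(probs[digit])}

IO.inspect(predictions, label: "top 3 (digit, probability)")
```

With the values above, the three digits come out as 3, 5 and 2, in that order.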

What I have done so far

I allow the user to draw an image on a browser canvas element which I then capture.

I process the captured image using ImageMagick: I make it monochrome (black and white), resize it to 28x28, extract the pixel information, and finally do some additional processing to convert it into the MNIST dataset format.

The complete code I have written to convert to MNIST format is here.

Now, the results:

The trained neural network has 95%+ accuracy on the MNIST test dataset, but in practice it often fails on the images I give it, mostly with digits like 5 and 9.

Problems that I think might be happening:

  • Maybe I need to train it for more epochs? Nope: I already tried training for 200 epochs, reducing the loss to less than 0.05, but still no luck.

  • I see that in many other places people use 2 hidden layers for the MNIST dataset, so should I try adding another hidden layer? If so, does the training process change? In the example José used simple gradient descent; can I continue to use that if I add more layers, and will it even make a difference, given that the network already performs well on the MNIST test dataset?

  • Maybe I need to train it with images captured from the browser canvas. If that is the case, then I will have to somehow prepare a dataset with lots of images drawn on the browser canvas.
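On the second bullet: a rough sketch, modelled loosely on the style of the talk's example but not taken from it, suggests that the training process does not need to change when a second hidden layer is added. The parameter tuple just grows and `grad` still differentiates through all layers, so plain gradient descent keeps working. The layer sizes (64 and 32), the step size, the random input batch, and the module name `TwoHidden` are all assumptions made up for illustration.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TwoHidden do
  import Nx.Defn

  # Numerically stable softmax over the last axis.
  defn softmax(logits) do
    shifted = logits - Nx.reduce_max(logits, axes: [-1], keep_axes: true)
    exp = Nx.exp(shifted)
    exp / Nx.sum(exp, axes: [-1], keep_axes: true)
  end

  # Forward pass: input -> hidden 1 -> hidden 2 -> 10 output probabilities.
  defn predict({w1, b1, w2, b2, w3, b3}, batch) do
    batch
    |> Nx.dot(w1) |> Nx.add(b1) |> Nx.sigmoid()
    |> Nx.dot(w2) |> Nx.add(b2) |> Nx.sigmoid()
    |> Nx.dot(w3) |> Nx.add(b3)
    |> softmax()
  end

  # Cross-entropy loss against one-hot labels.
  defn loss(params, batch, labels) do
    preds = predict(params, batch)
    -Nx.mean(Nx.sum(labels * Nx.log(preds + 1.0e-7), axes: [-1]))
  end

  # One step of plain gradient descent over all six parameters.
  defn update({w1, b1, w2, b2, w3, b3} = params, batch, labels, step) do
    {g1, gb1, g2, gb2, g3, gb3} = grad(params, &loss(&1, batch, labels))

    {w1 - g1 * step, b1 - gb1 * step, w2 - g2 * step,
     b2 - gb2 * step, w3 - g3 * step, b3 - gb3 * step}
  end
end

# Small random weights, zero biases (sizes are illustrative only).
key = Nx.Random.key(0)
{w1, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {784, 64})
{w2, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {64, 32})
{w3, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {32, 10})
params = {w1, Nx.broadcast(0.0, {64}), w2, Nx.broadcast(0.0, {32}), w3, Nx.broadcast(0.0, {10})}

# Made-up batch of 4 "images" with one-hot labels for the digits 3, 1, 4, 1.
{batch, _key} = Nx.Random.uniform(key, 0.0, 1.0, shape: {4, 784})
labels = Nx.equal(Nx.tensor([[3], [1], [4], [1]]), Nx.iota({1, 10}))

# A few gradient descent steps, exactly as with a single hidden layer.
new_params =
  Enum.reduce(1..5, params, fn _, p -> TwoHidden.update(p, batch, labels, 0.1) end)

IO.inspect(Nx.to_number(TwoHidden.loss(new_params, batch, labels)), label: "loss after 5 steps")
```

Whether the extra layer actually helps with the canvas images is a separate question that this sketch does not answer.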

I am not sure what is going wrong, and would love some help or advice. You can find the complete code here, it requires Nx setup with EXLA.

Thanks for reading, have a good day :grinning_face_with_smiling_eyes:


I won’t claim to be an expert on deep learning but I think one piece of advice I can offer is that the MNIST model you’ve trained is trained on images where the range of pixel values is between 0 and 1. If you see the following line (line 93):

the images are normalized to values between 0 and 1, instead of 0 and 255, by the division operation. However, the images that you feed into the model from your web app have values between 0 and 255. I recommend sticking to one convention for both your model and your web app images, so that the images from your web app also get normalized to a maximum pixel value of 1. This would involve making a change at line 29 here:

Another change I’d consider comes from the fact that the MNIST images are likely not binary; that is, their pixel values can fall anywhere in the range from 0 to 1, whereas your images have pixel values that are just 0 or 1. You can probably get better performance by training the MNIST model on such binary images. To do so, round the MNIST pixel values to either 0 or 1 after normalizing them by dividing by 255, and then train the model on that image set. Again, you just have to introduce an Nx.round() operation after the Nx.divide(255) here:

Your reasoning in your third bullet point is in the right direction I think: you want to make the images the model trains on and the images the model will work on as similar as possible (if not identical, with respect to data-type, shape, dimensions, domain/range, etc.), and I think implementing the above two steps will help with that, without having to create a training set of data from your web app.
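As a minimal sketch of the two steps (not the actual lines from either linked file), the same divide-then-round helper can be applied to both the MNIST pixels and the canvas pixels so that they end up in the same domain. The module name `Preprocess` and the sample pixel rows are hypothetical.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Preprocess do
  # Scale raw 0..255 pixel values into 0..1, then round each one to 0 or 1.
  def to_binary_input(pixels) do
    pixels
    |> Nx.divide(255)
    |> Nx.round()
  end
end

# Hypothetical MNIST-style row with intermediate gray values.
mnist_row = Nx.tensor([[0, 30, 127, 128, 200, 255]], type: {:u, 8})

# Hypothetical canvas row that is already pure black and white.
canvas_row = Nx.tensor([[0, 0, 0, 255, 255, 255]], type: {:u, 8})

mnist_input = Preprocess.to_binary_input(mnist_row)
canvas_input = Preprocess.to_binary_input(canvas_row)

# Both rows now contain only 0.0 and 1.0.
IO.inspect(Nx.to_flat_list(mnist_input), label: "mnist")
IO.inspect(Nx.to_flat_list(canvas_input), label: "canvas")
```

Using one shared function for training and inference data is one simple way to keep the two pipelines from drifting apart.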

In general, I’ve learned that before I augment my model, I should focus on the cleanliness and simplicity of my data and data structures first. That is why I’d avoid adding layers, increasing the width, or changing the topology of the neural network in the hope of better performance; any gain from those changes would likely be marginal. I’d also hold off on increasing the training time (epochs, batch sizes, etc.), because that will almost certainly improve performance too, but only marginally.


Awesome, thanks a lot for the help. You are absolutely right with

the images the model trains on and the images the model will work on as similar as possible

Thanks for the observation about Nx.divide(255); I did not pay attention to this detail.
I will definitely try this, and also Nx.round() to train the model on binary pixel values. I will post the results here. :smile:


Hi, sorry for the late reply.

I tried the two improvements that you mentioned and they have definitely helped. The predictions are much better than before; however, it’s still not perfect. For a few digits like 6, 7, and 9 it mostly predicts incorrectly.

I think the way to solve this might be to train it with images captured from the browser canvas, but for that, I will need to create a dataset. Nevertheless, thanks for the help.


Around 95% accuracy is roughly what a simple fully connected network like this one tends to reach. If you want better accuracy, you will need a different architecture: you could consider a convnet or, if you feel like a real challenge, something like a ResNet, which can get well above 99% on MNIST.
