Digit detection with Nx

Hi everyone.

José recently shared a talk about Nx, where he showed how to build a simple neural network to identify digits using the MNIST dataset. I know very little about machine learning, but I was excited about this and wanted to extend the example from the talk so that it can identify digits drawn by the user.



In this demo, I return the top 3 predictions made by the neural network along with the probability of each. As you can see, it sometimes fails.
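Reading the top 3 out of the network's output can be sketched with Nx like this. It assumes the last layer is a softmax over the 10 digits and that the output for one image is a 1-D tensor of 10 probabilities (for a `{1, 10}` batch output, take `probs[0]` first). `TopPredictions.top_k/2` is a hypothetical helper, not part of Nx.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TopPredictions do
  # `probs` is assumed to be a 1-D tensor of 10 softmax outputs, where the
  # index of each entry is the digit it stands for.
  def top_k(probs, k \\ 3) do
    probs
    # Indices ordered from most to least likely digit.
    |> Nx.argsort(direction: :desc)
    |> Nx.slice([0], [k])
    |> Nx.to_flat_list()
    |> Enum.map(fn digit -> {digit, Nx.to_number(probs[digit])} end)
  end
end

# Hypothetical softmax output for one drawn digit.
probs = Nx.tensor([0.01, 0.02, 0.6, 0.03, 0.04, 0.15, 0.05, 0.02, 0.07, 0.01])
top3 = TopPredictions.top_k(probs)
```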

What I have done so far

I allow the user to draw an image on a browser canvas element which I then capture.

I process the captured image with ImageMagick: I convert it to monochrome (black and white), resize it to 28x28, extract the pixel information, and finally do some additional processing to convert it into the MNIST dataset format.
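For anyone following along, the last step of such a pipeline might look roughly like this. It is a minimal sketch, not the conversion code linked below: it assumes you already have 784 grayscale values (0..255, row-major, resized to 28x28), and `CanvasInput.to_mnist_batch/2` and its `:invert` option are hypothetical. The MNIST files store ink as high values on a 0 background, so a canvas that draws dark strokes on a light background would need inverting first.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule CanvasInput do
  # `pixels`: 784 grayscale values (0..255), row-major, already 28x28.
  # Pass `invert: true` when the canvas has dark ink on a light background,
  # so that ink ends up as high values like in MNIST.
  def to_mnist_batch(pixels, opts \\ []) when length(pixels) == 784 do
    pixels =
      if Keyword.get(opts, :invert, false) do
        Enum.map(pixels, &(255 - &1))
      else
        pixels
      end

    # A batch of one flattened image, the same {n, 784} layout as the training set.
    pixels
    |> Nx.tensor(type: {:u, 8})
    |> Nx.reshape({1, 784})
  end
end

# A blank white canvas (all 255) becomes an all-zero MNIST-style input.
batch = CanvasInput.to_mnist_batch(List.duplicate(255, 784), invert: true)
```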

The complete code I have written to convert to MNIST format is here.

Now, the results:

Although the trained neural network reaches over 95% accuracy on the MNIST test dataset, in practice it often fails on the images I give it, mostly on digits like 5 and 9.

Problems that I think might be happening:

  • Maybe I need to train it for more epochs? Nope, I already tried training for 200 epochs, reducing the loss to less than 0.05, but still no luck.

  • I see that many other examples use 2 hidden layers for the MNIST dataset, so should I try adding another hidden layer? If so, does the training process change? In the example, José used simple gradient descent; can I keep using it if I add layers, and will that even make a difference given that the network already performs well on the MNIST test dataset?

  • Maybe I need to train it with images captured from the browser canvas. If that is the case, then I will have to somehow prepare a dataset with lots of images drawn on the browser canvas.
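On the second question: plain gradient descent does not depend on the number of layers, because `grad` differentiates through the whole forward pass. Below is a minimal sketch in the spirit of the talk's example (weights and biases in a tuple, sigmoid hidden layers, softmax output, cross-entropy loss), but it is not the talk's code: the module name, layer sizes, initialization scale and step size are all hypothetical.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TwoHiddenLayers do
  import Nx.Defn

  # Hypothetical sizes; 784 inputs and 10 outputs match MNIST.
  @hidden1 128
  @hidden2 64

  def init_params(seed \\ 42) do
    key = Nx.Random.key(seed)
    {w1, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {784, @hidden1})
    {w2, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {@hidden1, @hidden2})
    {w3, _key} = Nx.Random.normal(key, 0.0, 0.1, shape: {@hidden2, 10})

    {w1, Nx.broadcast(0.0, {@hidden1}), w2, Nx.broadcast(0.0, {@hidden2}), w3,
     Nx.broadcast(0.0, {10})}
  end

  defn softmax(logits) do
    # Subtracting the row max keeps exp from overflowing.
    exp = Nx.exp(logits - Nx.reduce_max(logits, axes: [-1], keep_axes: true))
    exp / Nx.sum(exp, axes: [-1], keep_axes: true)
  end

  defn predict({w1, b1, w2, b2, w3, b3}, images) do
    images
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> Nx.sigmoid()
    # The extra hidden layer is one more dot/add/activation step.
    |> Nx.dot(w2)
    |> Nx.add(b2)
    |> Nx.sigmoid()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> softmax()
  end

  # Cross-entropy against one-hot labels.
  defn loss(params, images, labels) do
    preds = predict(params, images)
    -Nx.sum(Nx.mean(Nx.log(preds) * labels, axes: [0]))
  end

  # Plain gradient descent: grad/2 backpropagates through every layer,
  # so the update rule is unchanged when a layer is added.
  defn update({w1, b1, w2, b2, w3, b3} = params, images, labels, step) do
    {gw1, gb1, gw2, gb2, gw3, gb3} = grad(params, &loss(&1, images, labels))

    {w1 - step * gw1, b1 - step * gb1, w2 - step * gw2, b2 - step * gb2,
     w3 - step * gw3, b3 - step * gb3}
  end
end
```

Whether the extra layer helps is a separate question; the update step itself stays the same shape as with one hidden layer.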

I am not sure what is going wrong and would love some help or advice. You can find the complete code here; it requires Nx set up with EXLA.

Thanks for reading, have a good day :grinning_face_with_smiling_eyes:


I won’t claim to be an expert on deep learning, but one piece of advice I can offer is that the MNIST model you’ve trained was trained on images whose pixel values range between 0 and 1. Look at the following line (line 93):

After the division operation, the images are normalized to values between 0 and 1 instead of 0 and 255. However, the images that you feed into the model from your web app have values between 0 and 255. I recommend sticking to one convention for both your model and your web app images, and normalizing the images from your web app so that their maximum pixel value is 1. This would involve a change at line 29 here:

Another change I’d consider: the MNIST images are likely not binary; that is, their pixel values range from 0 to 1, whereas your images have pixel values that are only 0 or 1. You can probably get better performance by training the MNIST model on binary images as well. To do so, add a rounding operation that rounds the MNIST pixel values to either 0 or 1 after normalizing them by dividing by 255, then train the model on that image set. Again, you just have to introduce an Nx.round() operation after the Nx.divide(255) here:

Your reasoning in your third bullet point is in the right direction, I think: you want the images the model trains on and the images the model will work on to be as similar as possible (if not identical, with respect to data type, shape, dimensions, domain/range, etc.). I think implementing the above two steps will help with that, without having to create a training set from your web app.
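One way to keep both sides identical is to send training images and canvas images through the same function. A minimal sketch with Nx, assuming both arrive as u8 tensors of 0..255 values; `Preprocess` and its functions are hypothetical names, and `binarize/1` is the Nx.divide(255) followed by Nx.round() step suggested above.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Preprocess do
  # Scale 0..255 pixel values to 0.0..1.0 (Nx.divide returns a float tensor).
  def normalize(images), do: Nx.divide(images, 255)

  # Normalize, then round so every pixel is exactly 0.0 or 1.0:
  # u8 values of 128 and above become 1.0, 127 and below become 0.0.
  def binarize(images) do
    images
    |> normalize()
    |> Nx.round()
  end
end

# Hypothetical data: the same function is applied on both sides.
train_batch = Nx.tensor([[0, 100, 127, 128, 255]], type: {:u, 8})
canvas_image = Nx.tensor([[255, 0, 0, 255, 255]], type: {:u, 8})

train_input = Preprocess.binarize(train_batch)
canvas_input = Preprocess.binarize(canvas_image)
```

Using `normalize/1` alone keeps MNIST's grey levels; `binarize/1` matches a canvas pipeline that only ever produces 0 or 1.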

In general, I’ve learned that before I augment my model, I should focus on the cleanliness and simplicity of my data and data structures first. That is why I’d avoid adding layers, increasing the width, or changing the topology of the neural network to improve performance (any gain from those changes would likely be marginal). I’d also hold off on increasing the training time (epochs, batch sizes, etc.), because while these will almost certainly improve performance, the gain will only be marginal.


Awesome, thanks a lot for the help. You are absolutely right about making

the images the model trains on and the images the model will work on as similar as possible

Thanks for pointing out the Nx.divide(255); I had not paid attention to this detail.
I will definitely try this and also Nx.round() to train the model on binary pixel values. I will post the results here. :smile:


Hi, sorry for the late reply.

I tried the two improvements you mentioned, and they definitely helped: the predictions are much better than before. However, it’s still not perfect; for a few digits like 6, 7 and 9 it mostly predicts incorrectly.

I think the way to solve this might be to train it with images captured from the browser canvas, but for that, I will need to create a dataset. Nevertheless, thanks for the help.


95% accuracy is about what you can expect from a simple fully connected network like this one. If you want better accuracy, you will need a different architecture: you could consider a convolutional network (convnet), or, if you feel like a real challenge, something like a ResNet, which can push accuracy well above 99%.
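A convnet along those lines might look like this in Axon. This is a minimal sketch, assuming Axon 0.6 and channels-last input, so the flattened `{1, 784}` images would be reshaped to `{1, 28, 28, 1}` (height, width, channels). The layer sizes are illustrative guesses rather than tuned values, and this is a plain convnet, not a ResNet.

```elixir
Mix.install([{:axon, "~> 0.6.0"}])

# Two conv/pool stages followed by a small dense classifier.
model =
  Axon.input("pixels", shape: {nil, 28, 28, 1})
  |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu, channels: :last)
  |> Axon.max_pool(kernel_size: {2, 2}, channels: :last)
  |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu, channels: :last)
  |> Axon.max_pool(kernel_size: {2, 2}, channels: :last)
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

# A flattened MNIST-style batch reshaped into the layout the model expects.
input = Nx.iota({1, 784}, type: {:f, 32}) |> Nx.reshape({1, 28, 28, 1})
```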


Hi @arpan, thank you for posting this. I recently got interested in Nx and Axon, so I set out to create an interactive demo very similar to yours, and I ran into some of the same issues you did.

I captured a video showing my observations (I also see frequent incorrect predictions for numbers like 7, 9 and 6). Video is here for anyone interested.

My entire codebase is open source (I’m trying to host it on my blog), and I have a PR open with just the Nx/Axon/LiveView changes related to this interactive demo: Mnist by mmmries · Pull Request #14 · mmmries/blog · GitHub

I have made my training dataset similar to the interactive images by using the Nx.round you mentioned above, but I am still getting much less than 99% accuracy in interactive mode. Did you ever find any other causes for this? I’m considering capturing additional training images, and I have also started looking into using a GAN (generative adversarial network) to automatically generate many more examples with slight variations. But I think there’s a good chance I’ve made a simple mistake somewhere along the line, so I’m hoping to get some advice before I invest too much time in those other areas.

One thing I noticed in my manual testing is that the predictions seem most accurate when I write my digits a bit smaller and keep them in the center of the box. Then I looked at the MNIST website, which says:

The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.

So I’m going to try to make my demo application take the user input, find the bounding box of the white pixels, and pull out only that part of the image to be resized and placed in the center of a 28x28 image. That should give me a setup similar to how the MNIST authors prepared the training data.
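That plan, plus the center-of-mass step from the MNIST description, can be sketched in plain Elixir on an image stored as a list of rows (0 = background). This is a sketch under assumptions: `DigitCentering` and its functions are hypothetical names, and it does no scaling, so the crop is assumed to have been resized (for example with ImageMagick, to fit 20x20) before `center_in_field/1` pads it into a 28x28 field.

```elixir
defmodule DigitCentering do
  # Images are lists of rows; each row is a list of intensities, 0 = background.
  @size 28

  # Inclusive {top, bottom, left, right} of pixels above `threshold`, or nil if blank.
  def bounding_box(rows, threshold \\ 0) do
    coords =
      for {row, r} <- Enum.with_index(rows),
          {v, c} <- Enum.with_index(row),
          v > threshold,
          do: {r, c}

    if coords == [] do
      nil
    else
      {rs, cs} = Enum.unzip(coords)
      {Enum.min(rs), Enum.max(rs), Enum.min(cs), Enum.max(cs)}
    end
  end

  # Keeps only the rows and columns inside the bounding box.
  def crop(rows, {top, bottom, left, right}) do
    rows
    |> Enum.slice(top..bottom)
    |> Enum.map(&Enum.slice(&1, left..right))
  end

  # Pads a crop of at most 28x28 with background so it sits in the middle.
  def center_in_field(crop) do
    h = length(crop)
    w = length(hd(crop))
    top = div(@size - h, 2)
    left = div(@size - w, 2)
    blank = List.duplicate(0, @size)

    List.duplicate(blank, top) ++
      Enum.map(crop, &(List.duplicate(0, left) ++ &1 ++ List.duplicate(0, @size - w - left))) ++
      List.duplicate(blank, @size - h - top)
  end

  # Intensity-weighted {row, col} center of mass; a blank image reports the field center.
  def center_of_mass(rows) do
    pixels = for {row, r} <- Enum.with_index(rows), {v, c} <- Enum.with_index(row), do: {v, r, c}
    total = pixels |> Enum.map(&elem(&1, 0)) |> Enum.sum()

    if total == 0 do
      {(@size - 1) / 2, (@size - 1) / 2}
    else
      {Enum.sum(for {v, r, _} <- pixels, do: v * r) / total,
       Enum.sum(for {v, _, c} <- pixels, do: v * c) / total}
    end
  end

  # Shifts a 28x28 image so its center of mass lands on the field center.
  def recenter(rows) do
    {cr, cc} = center_of_mass(rows)
    center = (@size - 1) / 2
    {dr, dc} = {round(center - cr), round(center - cc)}

    for r <- 0..(@size - 1) do
      for c <- 0..(@size - 1), do: pixel_at(rows, r - dr, c - dc)
    end
  end

  # Pixels shifted in from outside the original image are background.
  defp pixel_at(rows, r, c) when r in 0..27 and c in 0..27, do: rows |> Enum.at(r) |> Enum.at(c)
  defp pixel_at(_rows, _r, _c), do: 0
end
```

`recenter/1` follows the MNIST description more closely than `center_in_field/1`, since that text positions the center of mass, not the bounding box, at the center of the field.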