Why isn't the MNIST example faster on the GPU than on the CPU?

When running the axon/examples/vision/mnist example (after updating plugins and setting the EXLA backend), the 5 training epochs take ~30 seconds on two different computers: one CPU-only, the other with an NVIDIA RTX 3060 GPU. The nvidia-smi monitor shows 40% GPU utilization during training. I expected the GPU to be faster. Do I have something configured wrong? Thanks!

WSL2 Ubuntu 22.04
OTP 24 / Elixir 1.14.3
Compute Capability 8.6
XLA_TARGET=cuda111

Are you sure CUDA is being used?

You have to make sure that XLA_TARGET is set to the value that matches your CUDA version.
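
For a script like mnist.exs that installs its dependencies with Mix.install, one way to pin the target is the `:system_env` option, which sets the variable before the dependencies are compiled. This is a minimal sketch assuming it sits at the top of a standalone .exs script: the dependency list and version requirements are placeholders, and `cuda111` is only the value used in this thread, so substitute the target that matches your installed CUDA.

```elixir
# Placeholder dependency list and versions; keep whatever your script already lists.
Mix.install(
  [
    {:nx, "~> 0.4"},
    {:exla, "~> 0.4"}
  ],
  # XLA_TARGET must be set before :xla/:exla are compiled so the matching
  # precompiled XLA build is used. "cuda111" is the value from this thread.
  system_env: %{"XLA_TARGET" => "cuda111"}
)
```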

Thanks for the quick reply. No, I’m not sure CUDA is being used.

CUDA driver version: 12.0
CUDA runtime version: 11.2

I have XLA_TARGET set to cuda111 and EXLA_TARGET set to cuda, and I added:

```elixir
Nx.Defn.default_options(compiler: EXLA, client: :cuda)
Nx.global_default_backend({EXLA.Backend, client: :cuda})
```

The last line (`Nx.global_default_backend`) seems to make no difference.
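
One way to see whether the backend setting took effect is to inspect a freshly created tensor. This sketch assumes Nx and EXLA are already installed in the same script and that EXLA reports a `:cuda` platform; it only reads state and prints it.

```elixir
# Assumes Nx and EXLA are installed and EXLA lists a :cuda platform.
Nx.global_default_backend({EXLA.Backend, client: :cuda})

# Backend used for tensors created without an explicit :backend option
IO.inspect(Nx.default_backend(), label: "default backend")

t = Nx.tensor([1.0, 2.0, 3.0])

# The tensor's data field holds the struct of the backend that owns its memory
IO.inspect(t.data.__struct__, label: "tensor backend")

# Inspecting the tensor itself also shows which backend holds it
IO.inspect(t)
```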

`EXLA.Client.get_supported_platforms` returns `%{cuda: 1, host: 20, interpreter: 1}`.

This is the output when running mnist.exs:

```
[info] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
[info] XLA service 0x7f7c68297ba0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
[info] StreamExecutor device (0): NVIDIA GeForce RTX 3060, Compute Capability 8.6
[info] Using BFC allocator.
[info] XLA backend allocating 10627212902 bytes on device 0 for BFCAllocator.
[info] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
```

Is there something else to check?

You are indeed using CUDA.

My guess is that the cost of transferring the images onto the GPU outweighs the speedup the GPU provides.

Put differently, because your MNIST model is so small, the CPU processes the input in about the same time it takes to copy the data to the GPU and coordinate the computation there.
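
The argument can be made concrete with a toy cost model: on the GPU each batch pays a fixed overhead (copying inputs, launching kernels, waiting for results) on top of its compute time, while the CPU pays only compute. Every figure in the usage lines is hypothetical, chosen to show the shape of the trade-off rather than measured on either machine, and `EpochCost` is an illustrative name.

```elixir
defmodule EpochCost do
  # Toy model, not a measurement: one epoch costs the number of batches times
  # a fixed per-batch overhead, plus the total compute at a given throughput.
  def epoch_seconds(samples, batch_size, flops_per_sample, flops_per_second, overhead_per_batch) do
    # Round up so a final partial batch still pays the overhead
    batches = div(samples + batch_size - 1, batch_size)
    compute = samples * flops_per_sample / flops_per_second
    batches * overhead_per_batch + compute
  end
end

# Hypothetical figures: 60_000 training images, 6.0e5 FLOPs per image per step,
# CPU at 5.0e10 FLOP/s with no per-batch overhead,
# GPU at 5.0e12 FLOP/s with 0.4 ms of overhead per batch.
for batch_size <- [32, 1024] do
  cpu = EpochCost.epoch_seconds(60_000, batch_size, 6.0e5, 5.0e10, 0.0)
  gpu = EpochCost.epoch_seconds(60_000, batch_size, 6.0e5, 5.0e12, 4.0e-4)
  IO.puts("batch #{batch_size}: cpu #{Float.round(cpu, 3)} s, gpu #{Float.round(gpu, 3)} s")
end
```

With these made-up numbers the two devices come out roughly even at batch size 32, while at 1024 the GPU pays its per-batch overhead about 30 times less often and pulls well ahead.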

You can try increasing the batch size, or try a more complex example (e.g. horses_vs_humans or cifar10).
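
To check the batch-size effect on your own machines, a rough micro-benchmark can time one dense-layer step at a small and a large batch on each EXLA client. This is a sketch assuming Nx and EXLA are installed as in mnist.exs; `DenseBench`, the 784x128 layer shape, the batch sizes and the run count are illustrative choices, not taken from the Axon example. Inputs are allocated directly on the device being timed, so the loop measures compute plus per-call overhead rather than the initial host-to-device copy.

```elixir
# Assumes Nx and EXLA are installed; names and sizes here are hypothetical.
defmodule DenseBench do
  import Nx.Defn

  # One dense layer (no bias) reduced to a scalar, so only a scalar returns to the host
  defn step(x, w) do
    x
    |> Nx.dot(w)
    |> Nx.sum()
  end

  # Average milliseconds per call on `client`, after a warm-up call that
  # triggers JIT compilation (so compile time is not counted)
  def time_ms(client, batch_size, runs \\ 20) do
    backend = {EXLA.Backend, client: client}
    x = Nx.iota({batch_size, 784}, type: :f32, backend: backend)
    w = Nx.iota({784, 128}, type: :f32, backend: backend)
    fun = Nx.Defn.jit(&step/2, compiler: EXLA, client: client)

    # Nx.to_number/1 copies the scalar result to the host, waiting for the computation
    Nx.to_number(fun.(x, w))

    {micros, :ok} =
      :timer.tc(fn ->
        Enum.each(1..runs, fn _ -> Nx.to_number(fun.(x, w)) end)
      end)

    micros / runs / 1000
  end
end

platforms = EXLA.Client.get_supported_platforms()

# Only time the clients EXLA reports on this machine
for client <- [:host, :cuda], Map.has_key?(platforms, client), batch_size <- [32, 8192] do
  ms = DenseBench.time_ms(client, batch_size)
  us_per_sample = ms * 1000 / batch_size
  IO.puts("#{client} batch=#{batch_size}: #{Float.round(ms, 3)} ms/step, #{Float.round(us_per_sample, 3)} us/sample")
end
```

If the transfer-and-overhead explanation holds, the per-sample time on `:cuda` should drop much more than on `:host` when going from 32 to 8192.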