Nx while loop training slows down when passing in frozen embeddings

Howdy all! I’m a former web dev who used Elixir at work and loved it. I decided to learn some machine learning basics and started building my own micro models that I can train and test locally on my PC. I’m using Livebook on an Ubuntu partition, and after a little finagling with the Bazel version, I was able to get my models to compile and train on my RX 7600 with EXLA and ROCm.

Generally things run smoothly, but I’ve run into an issue that can make my training take 10x as long, and I can’t figure out why. I’ve been experimenting with various state space model architectures. In some versions I use a pre-trained, ‘frozen’ embedding matrix to map my input tokens at the beginning of the forward pass; in other versions the embedding matrix is part of the learned parameters. The massive slowdown only occurs when I use a pre-trained embedding matrix that isn’t part of the learned parameters. Everything else stays the same (hidden dimensions, batch size, sequence length, vocab size, training data, etc.), but when I swap the learned embedding matrix for a pre-trained one (again, same size and shape), training suddenly takes ten times as long.

I’ve tried both a pre-trained frozen embedding matrix and a randomly initialized one that is still frozen (it doesn’t get passed to the optimizer or through value_and_grad, and it’s initialized exactly the same way as the learned version), and both cause the slowdown. I load/initialize the frozen weights outside my training loop, just like the learned parameters, then transfer them to the GPU before passing them into the loop. I’ve tried combining the frozen weights with my trained parameters into a larger tuple before passing it to the training loop. I’ve also tried passing the frozen weights through value_and_grad and my optimizer, skipping the optimizer update for them, and passing the old versions back alongside the updated learned parameters. And I’ve used stop_grad in various places: mostly on the embedding lookup at the start of the forward pass, but also on the loaded/initialized frozen embedding matrix and in the training loop before passing it to the forward pass function. Nothing restores the training speed I see with a learned embedding matrix.

My basic training flow: load my saved weights from a checkpoint .bin or initialize fresh weights (including the embedding matrix), then transfer them to the GPU. Pass those references to an Enum.reduce that loops over my pre-batched training sequences via Stream.zip for a fixed number of steps. On each iteration, call my compute_grad function with the current parameters (including the learned/frozen embedding matrix) and the batched sequences. That function uses Nx’s while loop to iterate over the sequence length, accumulating a loss value for the step. After that loop finishes, the accumulated loss goes through value_and_grad with the parameters, which returns the raw gradients and the loss tensor to my Enum.reduce. I pass the raw gradients and the current params to my optimizer function, which returns the updated parameters for the next iteration of the reduce.
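For readers following along, here is a minimal sketch of that flow, assuming a time-major batch of token ids shaped {seq_len, batch} and a toy tanh recurrence standing in for the SSM. The module and function names (FrozenEmbSketch, seq_loss), the MSE loss against the next token’s embedding, the sizes, and the plain SGD update are all hypothetical; defn, while, stop_grad, value_and_grad and the Nx/Nx.Random calls are real Nx APIs. The frozen matrix is passed as an ordinary argument, so value_and_grad only differentiates the learned params map.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule FrozenEmbSketch do
  import Nx.Defn

  # Hypothetical toy recurrence standing in for the SSM:
  #   h_t = tanh(x_{t-1} . w + h_{t-1} . u)
  # Per-step loss: MSE between h_t and the frozen embedding of token t.
  # tokens are assumed time-major: {seq_len, batch}.
  defn seq_loss(params, emb, tokens) do
    %{w: w, u: u} = params
    # Frozen lookup. emb is not a differentiated argument; stop_grad only makes that explicit.
    x = Nx.take(stop_grad(emb), tokens)
    # x is {seq_len, batch, hidden}; axis 0 is time, so the loop length comes from axis 0.
    seq_len = Nx.axis_size(x, 0)
    h0 = Nx.broadcast(0.0, {Nx.axis_size(x, 1), Nx.axis_size(x, 2)})

    # Everything the body reads (x, w, u) travels in the loop state.
    {_i, _h, loss, _x, _w, _u} =
      while {i = 1, h = h0, loss = 0.0, x, w, u}, i < seq_len do
        h = Nx.tanh(Nx.dot(x[i - 1], w) + Nx.dot(h, u))
        {i + 1, h, loss + Nx.mean((h - x[i]) ** 2), x, w, u}
      end

    loss / (seq_len - 1)
  end

  # Differentiate with respect to the learned params only; emb and tokens are plain inputs.
  defn compute_grad(params, emb, tokens) do
    value_and_grad(params, fn p -> seq_loss(p, emb, tokens) end)
  end
end

# Hypothetical sizes. This runs on Nx's default evaluator; with EXLA you would
# call Nx.Defn.default_options(compiler: EXLA) first.
{vocab, hidden, seq_len, batch} = {32, 8, 12, 4}
key = Nx.Random.key(0)
{emb, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {vocab, hidden})
{w, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {hidden, hidden})
{u, key} = Nx.Random.normal(key, 0.0, 0.1, shape: {hidden, hidden})
{tokens, _key} = Nx.Random.randint(key, 0, vocab, shape: {seq_len, batch})
lr = 0.05

params =
  Enum.reduce(1..20, %{w: w, u: u}, fn _step, params ->
    {_loss, grads} = FrozenEmbSketch.compute_grad(params, emb, tokens)
    # Plain SGD on the learned params only; emb is never updated.
    Map.new(params, fn {k, v} -> {k, Nx.subtract(v, Nx.multiply(lr, grads[k]))} end)
  end)
```

Note that x, w and u ride along in the while state because the body reads them. Per the reply later in this thread, invariant loop state like x is the kind of thing that can make EXLA while loops slow.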

The only thing I haven’t tried is loading a fresh copy of the frozen matrices within my compute_grad function, but I feel like that would just make things worse by forcing the GPU to make a new copy on every step of the training loop. I have a feeling the backpropagation math is getting bogged down by those frozen matrices and I’m just not using stop_grad right, but that’s just a guess.

I didn’t want to just copy my whole .livemd file, but I’m happy to share any code snippets that would provide more context. A ~3M-parameter model with a learned embedding matrix trains ten times as fast as a 300k-parameter model with a frozen embedding matrix. Clearly I’m missing something; any insight or advice would be much appreciated. Thank you!


Which version of EXLA are you using?
What do you have configured as the compiler and the backend?

If you could share the .livemd as a gist, that would help with analyzing it.

I’m using EXLA version 0.11.0, with ROCm set as the backend and EXLA set as the compiler. Here’s a gist I made from the .livemd file: https://gist.github.com/Brady-Conn/810f8a9325526b6ecae2b86fbaf7e393.js

The link is for a .js file.

Sorry, I hadn’t used Gist before and didn’t notice that the ‘share’ option defaulted to a JS script for embedding the gist in an HTML page. Here’s the actual shareable link: Elixir/Nx SSM · GitHub

I think I figured it out: a silly typo. I had accidentally changed the axis I was using to measure the length of the Nx while loop from 0 to 1. Changing it back to 0 restored my training speed.
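One hedged way to guard against this class of typo (a swapped-in suggestion, not something proposed in the thread) is Nx’s named axes: name the time axis once and read the loop length by name instead of by index. The module name LoopLength and the {16, 4} time-major shape are illustrative assumptions.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule LoopLength do
  import Nx.Defn

  # Hypothetical helper: runs one iteration per time step and returns the count.
  defn count_steps(tokens) do
    # Looked up by name, so there is no 0-vs-1 axis index to mistype.
    seq_len = Nx.axis_size(tokens, :time)

    # The body never reads tokens, so tokens is kept out of the loop state.
    {steps} =
      while {i = 0}, i < seq_len do
        {i + 1}
      end

    steps
  end
end

# Time-major batch of token ids: 16 time steps, batch of 4.
tokens = Nx.iota({16, 4}, names: [:time, :batch])
LoopLength.count_steps(tokens)
# returns a scalar tensor equal to 16
```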

Perfect!

However, for future reference, I think the root cause has to do with the while loop carrying data that doesn’t change as part of the loop state.
This is tracked as an improvement to EXLA in the issue tracker, I think.

Until that lands, you could either take that part of training out of defn or, if the length of the while loop is known, use the unroll option. Or use vectorization to run a parallel version of it.

Thanks! Yeah, I’ve been meaning to dig into vectorization. The while loop implementation was what I felt most comfortable with initially, but I should probably start branching out. Thanks again!