Howdy all! I’m a former web dev who used Elixir at work and loved it. I decided to try to learn some machine learning basics and started building my own micro models that I can train and test locally on my PC. I’m using Livebook on an Ubuntu partition, and after a little finagling with the Bazel version, I was able to get my models to compile and train on my RX 7600 with EXLA and ROCm.
Generally things run smoothly, but I’ve run into an issue that can make my training take 10x longer, and I can’t figure out why. I’ve been experimenting with various state space model architectures. In some versions I use a pre-trained, ‘frozen’ embedding matrix to map my input tokens at the beginning of my forward pass; in other versions the embedding matrix is part of the learned parameters. The massive slowdown only occurs when I use a pre-trained embedding matrix that isn’t part of the learned parameters. I keep everything else the same (hidden dimensions, batch size, sequence length, vocab size, training data, etc.), but when I swap the learned embedding matrix for a pre-trained one (again, same size and shape), training suddenly takes ten times as long.
I’ve tried both a pre-trained frozen embedding matrix and a randomly initialized but still frozen one (it doesn’t get passed to the optimizer or through value_and_grad, and it’s initialized exactly the same way as the learned version), and both cause the slowdown. I load/initialize the frozen weights outside my training loop, just like the learned parameters, then transfer them to the GPU before passing them into the training loop. I’ve tried combining the frozen weights with my trained parameters into one larger tuple before passing it to the training loop. I’ve also tried passing the frozen weights through value_and_grad and my optimizer, but skipping the optimizer update for them and passing the old versions back along with the updated learned parameters. Finally, I’ve tried stop_grad in various places: mostly on the embedding lookup at the start of my forward pass, but also on the loaded/initialized frozen embedding matrix itself and in the training loop right before passing it to my forward pass function. Nothing brings my training speed back to what I see with a learned embedding matrix.
My basic training flow: I load my saved weights from a checkpoint .bin or initialize fresh weights (including the embedding matrix), then transfer them to the GPU. I pass those references to an Enum.reduce that loops over my pre-batched training sequences via Stream.zip for a fixed number of steps. On each iteration I call my compute_grad function, passing my current parameters (including the learned/frozen embedding matrix) and the batched sequences. That function uses Nx’s while loop to iterate over the sequence length, accumulating a loss value for the step. After that loop finishes, I pass the accumulated loss to value_and_grad with the parameters, which returns the raw gradients and the loss tensor to my Enum.reduce. I then pass the raw gradients and the current params to my optimizer function, which returns the updated parameters for the next iteration of the reduce.
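To make the setup concrete, here is a stripped-down sketch of the frozen-embedding pattern. It is not my actual notebook code: the module and function names, the single `w_out` projection standing in for the state space layers, the plain SGD update and the toy shapes are all hypothetical placeholders, and it skips the per-timestep while loop. What it shows is the frozen table passed to the `defn` as an ordinary argument, only `params` handed to `value_and_grad`, and `stop_grad` wrapped around the lookup. It assumes Nx 0.7 or later and runs on the default Nx backend rather than EXLA.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule FrozenEmbedSketch do
  import Nx.Defn

  # Hypothetical toy model: frozen embedding lookup -> one learned projection.
  defn loss(params, frozen_embed, tokens, targets) do
    # {batch, seq} token ids -> {batch, seq, dim}. Gradients are only taken
    # w.r.t. `params` anyway; stop_grad just makes the intent explicit.
    x = frozen_embed |> stop_grad() |> Nx.take(tokens)
    # {batch, seq, dim} . {dim, vocab} -> {batch, seq, vocab}
    logits = Nx.dot(x, params.w_out)

    # Numerically stable log-softmax over the vocab axis.
    row_max = Nx.reduce_max(logits, axes: [2], keep_axes: true)
    shifted = logits - row_max
    log_probs = shifted - Nx.log(Nx.sum(Nx.exp(shifted), axes: [2], keep_axes: true))

    # One-hot targets by comparing against an iota along the vocab axis.
    one_hot = Nx.equal(Nx.new_axis(targets, 2), Nx.iota(Nx.shape(logits), axis: 2))
    -Nx.mean(Nx.sum(log_probs * one_hot, axes: [2]))
  end

  # One plain SGD step (placeholder for a real optimizer).
  # The frozen table is only threaded through, never differentiated or updated.
  defn train_step(params, frozen_embed, tokens, targets, lr) do
    {loss_value, grads} =
      value_and_grad(params, fn p -> loss(p, frozen_embed, tokens, targets) end)

    {%{w_out: params.w_out - lr * grads.w_out}, loss_value}
  end
end

# Toy data: vocab of 100, embedding dim 16, batch of 4 sequences of length 8.
key = Nx.Random.key(42)
{frozen_embed, key} = Nx.Random.normal(key, shape: {100, 16})
{w_out, _key} = Nx.Random.normal(key, shape: {16, 100})
tokens = Nx.iota({4, 8}) |> Nx.remainder(100)
targets = tokens |> Nx.add(1) |> Nx.remainder(100)

{params, final_loss} =
  Enum.reduce(1..5, {%{w_out: w_out}, nil}, fn _step, {p, _loss} ->
    FrozenEmbedSketch.train_step(p, frozen_embed, tokens, targets, 0.1)
  end)
```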
The only thing I can think of that I haven’t tried is loading a fresh copy of the frozen matrices inside my compute_grad function, but I feel like that would just make things worse by forcing a new copy onto the GPU at every step of the training loop. I suspect the backpropagation is getting bogged down by those frozen matrices and I’m just not using stop_grad correctly, but that’s only a guess.
I didn’t want to paste my whole .livemd file, but I’m happy to share any code snippets that would help provide more context. For scale: a ~3M parameter model with a learned embedding matrix trains ten times as fast as a 300k parameter model with a frozen embedding matrix. Clearly I’m missing something; any insight or advice would be much appreciated. Thank you!