I think your question reveals a common misconception which is worth clearing up, so I’ll expand a bit on my answer.
> wouldn’t it be possible to use pure Nx for this vectorized hard part?
First, the main benefit of Nx here is not necessarily vectorization, but automatic differentiation. When you’re implementing Hamiltonian Monte Carlo using NUTS (the No-U-Turn Sampler), you need to work with the first-order (multidimensional) gradient of the function that defines the likelihood. A gradient is a multidimensional derivative made up of N components, where N is the number of parameters of the model. In very simple terms, you can think of NUTS as a fancy way of exploring an arbitrary log-likelihood function from R^N to R. This means that any “framework” you use to implement NUTS needs to be able to evaluate gradients of arbitrary functions, which requires some form of automatic differentiation. Nx provides automatic differentiation, which would make it useful here.

However, the gradient alone isn’t enough: you also need the whole NUTS algorithm, not just the function and its gradient. Maybe I could implement NUTS on top of Nx (I don’t know, because I have no knowledge of the internals of the NUTS algorithm), but even if that’s possible, it’s not something I’m eager to do.
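To make the gradient part concrete, here is a toy example (not one of the models discussed in this thread): for m observations y_1, ..., y_m assumed to come from a Normal(μ, σ) distribution, the log-likelihood and its two-component gradient look like this:

```latex
\ell(\mu, \sigma) = -m \log \sigma - \frac{1}{2\sigma^2} \sum_{i=1}^{m} (y_i - \mu)^2 + \mathrm{const}

\nabla \ell(\mu, \sigma)
  = \left( \frac{\partial \ell}{\partial \mu},\ \frac{\partial \ell}{\partial \sigma} \right)
  = \left( \frac{1}{\sigma^2} \sum_{i=1}^{m} (y_i - \mu),\
           -\frac{m}{\sigma} + \frac{1}{\sigma^3} \sum_{i=1}^{m} (y_i - \mu)^2 \right)
```

NUTS evaluates both the log-likelihood and its gradient at many different parameter values while it explores the posterior, which is why whatever framework you use has to be able to produce gradients, either by hand or through automatic differentiation.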
Regarding vectorization: although vectorization may speed up the kinds of numerical computing you need for a probabilistic model, what you actually need is a fast, strongly typed language such as C++ or Rust. If the programming language you’re using is fast, you don’t actually need to vectorize anything, although vectorization may make some operations somewhat faster.
If that is the case, why do we talk so much about vectorization in Elixir (or Python)? Because Elixir is terribly slow, way too slow for real-world numerical computing. This is mostly because the BEAM (the virtual machine Elixir runs on) has to check the type of every argument to every function call, every single time. If you’re using a fast low-level language such as C, Rust or C++, the runtime doesn’t need to do that, and everything is orders of magnitude faster.
In the Elixir world, vectorization is just a way to replace slow Elixir code with fast C code, and it should be thought of in that way. This also means that the common advice you hear, such as “don’t use array indexing, use multidimensional array operations”, is only true because such operations would be slow in Elixir, whereas the standard array operations hand everything over to fast C code. When you’re already programming in C or Rust, using array indexing is often the fastest way to do whatever you want to do!
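As a plain-Rust illustration (not tied to any library mentioned in this thread), a hand-written indexed loop like the one below is already fast, and the compiler will often auto-vectorize it with SIMD instructions on its own:

```rust
/// Weighted sum via plain array indexing. In Elixir this loop would be slow
/// and you would reach for Nx; in Rust it compiles to tight machine code.
fn weighted_sum(values: &[f64], weights: &[f64]) -> f64 {
    assert_eq!(values.len(), weights.len());
    let mut total = 0.0;
    for i in 0..values.len() {
        total += values[i] * weights[i];
    }
    total
}

fn main() {
    let v = [1.0, 2.0, 3.0];
    let w = [0.5, 0.25, 0.25];
    println!("{}", weighted_sum(&v, &w)); // prints 1.75
}
```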
Why doesn’t vectorization help much in many real-world probabilistic models? Vectorization shines when you’re doing things such as multiplying large matrices, summing large arrays or computing vector cross products. But most of the time your models won’t be doing that, unless you’re fitting some really basic models such as generalized linear models or something equivalent. As soon as you need if statements or array-indexing operations (which are pretty much required to handle missing data), vectorization won’t lead to a meaningful speedup (you can read about this in the Stan docs).
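For instance, handling missing observations naturally pulls in exactly the kind of branching and indexing that doesn’t vectorize well, and in a compiled language you just write the obvious loop. A sketch with a made-up toy model, not code from any real package:

```rust
/// Gaussian log-likelihood over data where some observations are missing.
/// The `if let` branch is the kind of control flow that breaks "pure"
/// vectorization, but it costs essentially nothing in a compiled language.
fn log_likelihood(data: &[Option<f64>], mu: f64, sigma: f64) -> f64 {
    let mut ll = 0.0;
    for &obs in data {
        if let Some(y) = obs {
            let z = (y - mu) / sigma;
            // log pdf of Normal(mu, sigma), up to the usual constant
            ll += -0.5 * z * z - sigma.ln();
        }
        // missing observations simply contribute nothing
    }
    ll
}

fn main() {
    let data = [Some(1.2), None, Some(0.7), Some(-0.3), None];
    println!("{}", log_likelihood(&data, 0.5, 1.0));
}
```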
In summary:
- If your model is implemented in Rust, vectorization isn’t such a big deal for real-world models
- Nx isn’t a magic way to speed up your code, although it can help with vectorization
- Probabilistic programming with NUTS is not easy, and you probably don’t want to implement the NUTS algorithm yourself
The only reason I’m considering doing it in Rust is that someone else has already implemented the NUTS algorithm in Rust, and as a user you only need to write the model. The model is simply a function which returns something like {log_likelihood_value, gradient = [g1, g2, g3, ..., gn]}, where log_likelihood_value is the value of the log-likelihood function and gradient is an n-dimensional vector representing the gradient at that (n-dimensional) point.
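To give a feel for the shape of this, here is a minimal sketch of such a model function for a toy Normal-mean model, with the gradient written out by hand. Every name here is made up for illustration; this is not the actual interface of the Rust NUTS library:

```rust
/// A toy "model" in the shape described above: given the current parameter
/// vector theta, return the log-likelihood value together with its gradient.
/// Here theta = [mu], and the data are observations from Normal(mu, 1).
fn model(theta: &[f64], data: &[f64]) -> (f64, Vec<f64>) {
    let mu = theta[0];
    let mut log_likelihood = 0.0;
    let mut d_mu = 0.0;
    for &y in data {
        let r = y - mu;
        log_likelihood += -0.5 * r * r; // Normal(mu, 1) log pdf, up to a constant
        d_mu += r;                      // derivative of the term above w.r.t. mu
    }
    (log_likelihood, vec![d_mu])
}

fn main() {
    let data = [1.0, 2.0, 0.5];
    let (ll, grad) = model(&[0.8], &data);
    println!("log-likelihood = {ll}, gradient = {grad:?}");
}
```

The sampler only ever calls this function; the trajectories, step-size adaptation and the U-turn criterion all live inside the NUTS implementation.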
How do I get the gradient value? I use a Rust automatic differentiation library (again, written by someone else, although I have customized it a bit) that evaluates the gradient from almost normal Rust code.
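The general idea behind such libraries can be shown with forward-mode “dual numbers”: you carry a derivative alongside every value, and ordinary-looking arithmetic propagates it. This is only a toy sketch of the technique, not the library I’m actually using:

```rust
use std::ops::{Add, Mul};

/// A value together with its derivative with respect to one chosen parameter.
#[derive(Clone, Copy, Debug)]
struct Dual {
    val: f64, // the value itself
    der: f64, // d(value)/d(parameter)
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        Dual { val: self.val + rhs.val, der: self.der + rhs.der }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // product rule: (f * g)' = f' * g + f * g'
        Dual { val: self.val * rhs.val, der: self.der * rhs.val + self.val * rhs.der }
    }
}

fn constant(x: f64) -> Dual {
    Dual { val: x, der: 0.0 }
}

fn main() {
    // f(x) = x * x + 3 * x, evaluated at x = 2 while tracking df/dx
    let x = Dual { val: 2.0, der: 1.0 };
    let f = x * x + constant(3.0) * x;
    println!("f(2) = {}, f'(2) = {}", f.val, f.der); // prints 10 and 7
}
```

Real libraries generalize this (or use reverse mode) so that you get the whole n-dimensional gradient, but the user experience is the same: you write what looks like normal Rust arithmetic and the derivative comes out alongside the value.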
The Elixir side only has to generate the Rust code for the likelihood function.
The main idea here is that vectorization is not that important, and the main reason people think it is important is that they use dynamic languages such as Elixir, in which everything vectorized is fast and everything non-vectorized is slow. I suggest everyone reading this think of vectorization as “running fast code written in C” instead of as some vaguely defined procedure that makes Elixir code faster.