Problem while trying to call a function inside the same function using Nx

Hey there, I’m running into a problem that I don’t know how to solve. This is my code, which is supposed to compute the gamma function:

defnp gamma_nx(x) do
  # Lanczos approximation with g = 7 and nine coefficients
  g = 7

  p = Nx.tensor([
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7
  ])

  if x < 0.5 do
    # Reflection formula for x < 0.5
    pi() / (Nx.sin(pi() * x) * gamma_nx(1 - x))
  else
    z = x - 1
    xs = while acc = 0.0, i <- Nx.linspace(1, 8, n: 8, type: :u8), unroll: true do
      acc + p[i] / (z + i)
    end
    x = p[0] + xs
    t = z + g + 0.5
    Nx.sqrt(2 * pi()) * Nx.pow(t, z + 0.5) * Nx.exp(-t) * x
  end
end

It works as expected until it reaches the recursive call to gamma_nx. From there it never returns anything, not even an error message. Can someone help me?

The way defn works is that it doesn’t execute the code as written. It executes a “symbolic” version of the code, which builds a graph that is then compiled, for example for the GPU. While we are building the graph, we don’t actually know the result of x < 0.5, so we always explore both branches. The first branch calls gamma_nx again, which explores both of its branches again, and so on, which leads to an infinite loop.
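To see the "both branches" behaviour for yourself, here is a minimal sketch. The module and function names (TraceSketch, branchy) are made up for illustration, and the Nx version in Mix.install is an assumption. It uses print_expr from Nx.Defn.Kernel to print the expression that defn builds, which should contain both arms of the if rather than just the one that matches the input.

```elixir
# Assumed dependency version; adjust to whatever you have installed.
Mix.install([{:nx, "~> 0.7"}])

defmodule TraceSketch do
  import Nx.Defn

  # Hypothetical example function: a scalar conditional inside defn.
  defn branchy(x) do
    result =
      if x < 0.5 do
        x * 2
      else
        x + 1
      end

    # Prints the symbolic expression built while tracing,
    # then returns it unchanged.
    print_expr(result)
  end
end

# Only one arm contributes to the value, but both were traced.
IO.inspect(TraceSketch.branchy(Nx.tensor(0.2)))
IO.inspect(TraceSketch.branchy(Nx.tensor(2.0)))
```

If one of the arms called branchy itself, tracing that arm would start the same process over again, which is the loop described above.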

A shorter way to put it: there is no tail recursion in defn, so use while instead. I will add this to the docs. :)
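For this particular function, one possible way around the recursion is sketched below. It is not taken from the thread: it relies on the observation that 1 - x is greater than 0.5 whenever x < 0.5, so the "recursive" call never needs to recurse again and can go to a separate, non-recursive helper. The names GammaSketch, lanczos and gamma are hypothetical, the Nx version is an assumption, and the series is written as a vectorized sum instead of a loop. It assumes a scalar input.

```elixir
# Assumed dependency version; adjust to whatever you have installed.
Mix.install([{:nx, "~> 0.7"}])

defmodule GammaSketch do
  import Nx.Defn

  # Lanczos approximation (g = 7, nine coefficients), meant for x >= 0.5.
  # This helper never calls gamma/1, so tracing it always terminates.
  defnp lanczos(x) do
    g = 7

    p =
      Nx.tensor([
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7
      ])

    z = x - 1
    # Indices 1..8 as a tensor, so the series becomes a single sum.
    i = Nx.iota({8}) + 1
    series = p[0] + Nx.sum(Nx.slice(p, [1], [8]) / (z + i))
    t = z + g + 0.5
    Nx.sqrt(2 * Nx.Constants.pi()) * Nx.pow(t, z + 0.5) * Nx.exp(-t) * series
  end

  # Both arms are traced, but neither one calls gamma/1 again.
  defn gamma(x) do
    if x < 0.5 do
      # Reflection formula: 1 - x > 0.5 here, so the helper is enough.
      Nx.Constants.pi() / (Nx.sin(Nx.Constants.pi() * x) * lanczos(1 - x))
    else
      lanczos(x)
    end
  end
end

# gamma(5) is 4! = 24, so this should print a value close to 24
# (computed in 32-bit floats by default).
IO.inspect(GammaSketch.gamma(Nx.tensor(5.0)))
```

When the recursion depth really does depend on the input, a while loop that carries the state explicitly is the tool to reach for, as suggested above.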


Thanks @josevalim!