Strange behaviour with nested while loops in Nx

So I was trying out some stuff and ran into this. It’s not an issue per se, just unexpected.

Consider this function for computing the gradient:

```elixir
{dj_dw, dj_db, _, _, _, _} =
  while {dj_dw, dj_db, x, y, w, b}, i <- range_m, unroll: true do
    f_wb_i = Nx.sigmoid(Nx.dot(x[i], w) + b)
    err_i = f_wb_i - y[i]

    # i is used inside the inner loop below, but it is not part of that loop's state
    {dj_dw, _x, _err_i} =
      while {dj_dw, x, err_i}, j <- range_n, unroll: true do
        dj_dw_j = dj_dw[j] + err_i * x[[i, j]]

        dj_dw = Nx.indexed_put(dj_dw, Nx.stack([j]), dj_dw_j)

        {dj_dw, x, err_i}
      end

    {dj_dw, dj_db + err_i, x, y, w, b}
  end
```

The example above works, but technically I don’t think it should, because I didn’t pass the i variable into the nested while loop. It only works because I set unroll: true.

If I do the same thing without unroll: true, however, Nx correctly tells me that I didn’t pass in a variable I’m using inside the nested while loop. That is the behavior I’d expect, and it makes sense to me.

My understanding is that unroll literally flattens out the loop, so i is already known by the time the nested while loop is unrolled and can be referenced inside it; the nested loop is then unrolled as well before execution.

I wanted to ask: is this the expected behavior?

It works because, when the loop is unrolled, i is not really a variable or expression but a constant, and constants can be used directly inside a while loop without being passed in through its state.
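For comparison, here is a minimal sketch of the same gradient loop written without unroll: true, assuming a recent Nx version (0.7 or later). The module name Gradient, the function compute_gradient, the counter-based loop conditions, and the use of Nx.put_slice in place of Nx.indexed_put are illustrative choices, not the original poster’s code. Because nothing is unrolled here, i is a runtime tensor, so it has to be carried in the inner loop’s state tuple. The integers m and n, on the other hand, come from Nx.shape at compile time and are plain constants, so the loop conditions can reference them directly.

```elixir
# Assumes network access to fetch Nx when run as a standalone script.
Mix.install([{:nx, "~> 0.7"}])

defmodule Gradient do
  import Nx.Defn

  # Hypothetical helper: same accumulation as the unrolled version,
  # but with regular (non-unrolled) while loops.
  defn compute_gradient(x, y, w, b) do
    # m and n are Elixir integers known at compile time, i.e. constants.
    {m, n} = Nx.shape(x)
    dj_dw = Nx.broadcast(0.0, {n})

    {_i, dj_dw, dj_db, _x, _y, _w, _b} =
      while {i = 0, dj_dw, dj_db = 0.0, x, y, w, b}, i < m do
        f_wb_i = Nx.sigmoid(Nx.dot(x[i], w) + b)
        err_i = f_wb_i - y[i]

        # i is a runtime tensor here, so it must be part of the inner state.
        {_j, dj_dw, _x, _err_i, _i} =
          while {j = 0, dj_dw, x, err_i, i}, j < n do
            dj_dw_j = dj_dw[j] + err_i * x[[i, j]]
            # Write the updated element back at position j.
            dj_dw = Nx.put_slice(dj_dw, [j], Nx.reshape(dj_dw_j, {1}))
            {j + 1, dj_dw, x, err_i, i}
          end

        {i + 1, dj_dw, dj_db + err_i, x, y, w, b}
      end

    {dj_dw, dj_db}
  end
end
```

Like the original snippet, this sums the per-example terms without dividing by m. With x = [[1.0, 2.0], [3.0, 4.0]], y = [0.0, 1.0], and all-zero weights and bias, every prediction is sigmoid(0) = 0.5, so the call should return dj_dw = [-1.0, -1.0] and dj_db = 0.0.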


Got it, thank you. That’s what my intuition was telling me; I just didn’t expect it. But it makes sense.