So I was trying out some stuff and ran into this. It's not an issue per se, just unexpected.
Consider this function for computing the gradient:
```elixir
{dj_dw, dj_db, _, _, _, _} =
  while {dj_dw, dj_db, x, y, w, b}, i <- range_m, unroll: true do
    f_wb_i = Nx.sigmoid(Nx.dot(x[i], w) + b)
    err_i = f_wb_i - y[i]

    # `i` is read below, but it is NOT part of this inner loop's state
    {dj_dw, _x, _err_i} =
      while {dj_dw, x, err_i}, j <- range_n, unroll: true do
        dj_dw_j = dj_dw[j] + err_i * x[[i, j]]
        dj_dw = Nx.indexed_put(dj_dw, Nx.stack([j]), dj_dw_j)
        {dj_dw, x, err_i}
      end

    {dj_dw, dj_db + err_i, x, y, w, b}
  end
```
The example above works, but technically I don't think it should: I never passed the `i` variable into the nested `while` loop's state. It only works because I set `unroll: true`.
If I do the same thing without `unroll: true`, it correctly tells me that I'm using a variable inside the nested `while` loop that I didn't pass in. I think that's the expected behavior, and it makes sense to me.
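For comparison, here is a minimal sketch of the same computation with `i` threaded through the inner loop's state, which is what the non-unrolled error is asking for. It rests on my own assumptions: the module and function names (`GradientSketch`, `compute_gradient`) are hypothetical, the ranges are derived from `Nx.shape(x)` in place of the `range_m`/`range_n` used above, and `unroll: true` is dropped to show the variant that has to compile without unrolling. Like the original, it does not divide by `m`.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule GradientSketch do
  import Nx.Defn

  # Hypothetical module; mirrors the loop structure from the post.
  defn compute_gradient(x, y, w, b) do
    # m rows (examples), n columns (features); plain integers at compile time
    {m, n} = Nx.shape(x)
    dj_dw = Nx.broadcast(0.0, Nx.shape(w))
    dj_db = Nx.broadcast(0.0, Nx.shape(b))

    {dj_dw, dj_db, _x, _y, _w, _b} =
      while {dj_dw, dj_db, x, y, w, b}, i <- 0..(m - 1) do
        f_wb_i = Nx.sigmoid(Nx.dot(x[i], w) + b)
        err_i = f_wb_i - y[i]

        # `i` is carried in the inner loop's state instead of being captured
        {dj_dw, _x, _err_i, _i} =
          while {dj_dw, x, err_i, i}, j <- 0..(n - 1) do
            dj_dw_j = dj_dw[j] + err_i * x[[i, j]]
            dj_dw = Nx.indexed_put(dj_dw, Nx.stack([j]), dj_dw_j)
            {dj_dw, x, err_i, i}
          end

        {dj_dw, dj_db + err_i, x, y, w, b}
      end

    {dj_dw, dj_db}
  end
end
```

With `w` and `b` all zeros every prediction is 0.5, so for `x = [[1.0, 2.0], [3.0, 4.0]]` and `y = [1.0, 0.0]` the errors are -0.5 and 0.5, which works out to `dj_dw = [1.0, 1.0]` and `dj_db = 0.0`.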
My understanding is that `unroll` literally flattens out the loop, so `i` is already computed by the time the nested `while` loop is unrolled. That lets it be referenced inside the second loop, which is itself unrolled before anything executes.
So I wanted to ask: is this the expected behavior?