Nx.indexed_put inside nested while loop


I’m trying to implement gradient descent for logistic regression and ran into a warning that I’m not sure how to resolve.

Here is my code:

defn compute_gradient_logistic(x, y, w, b) do
    {m, n} = Nx.shape(x)

    range_m = 0..(m - 1)
    range_n = 0..(n - 1)

    dj_dw = Nx.broadcast(Nx.tensor(0.0, type: :f64), {n})
    dj_db = Nx.tensor(0.0, type: :f64)

    {dj_dw, dj_db, _, _, _, _} =
      while {dj_dw, dj_db, x, y, w, b}, i <- range_m, unroll: true do
        f_wb_i = Nx.sigmoid(Nx.dot(x[i], w) + b)
        err_i = f_wb_i - y[i]
        {dj_dw, _x, _err_i} =
          while {dj_dw, x, err_i}, j <- range_n, unroll: true do
            dj_dw_j = dj_dw[j] + err_i * x[[i, j]]
            # update dj_dw element
            dj_dw = Nx.indexed_put(dj_dw, Nx.tensor([j]), dj_dw_j)
            {dj_dw, x, err_i}
          end

        {dj_dw, dj_db + err_i, x, y, w, b}
      end

    dj_dw = dj_dw / m
    dj_db = dj_db / m

    {dj_db, dj_dw}
end

When using indexed_put I have to pass a tensor as the index. In this case the index is [j], which I build with Nx.tensor([j]).

While this code works, I do get a warning:

warning: Nx.tensor/2 inside defn expects the first argument to be a literal (such as a list)

You must avoid code such as:


As that will JIT compile a different function for each different key.
Those values must be literals or be converted to tensors by explicitly calling Nx.tensor/2 outside of a defn

  /root/ml-specialized/logistic-regression-cost.livemd#cell:ebln7jut4x75cdw6:39: LogisticRegression.compute_gradient_descent/4

I’ve also tried dj_dw[j] = dj_dw_j but got an error.

The version with Nx.tensor([j]) works and returns the correct result, but I’m wondering what the recommended way is so that it compiles without warnings.

If I pass [j] directly as the index, it gives me an error:

** (Protocol.UndefinedError) protocol Nx.LazyContainer not implemented for [0] of type List, lists are not valid tensors (and therefore not supported as defn inputs). However, you can convert them to tensors using Nx.tensor/1. This protocol is implemented for the following type(s): Any, Atom, Complex, Float, Integer, List, Map, Nx.Batch, Nx.Tensor, Tuple
    (nx 0.7.2) lib/nx/lazy_container.ex:99: Nx.LazyContainer.List.traverse/3
    (nx 0.7.2) lib/nx.ex:2067: Nx.to_tensor/1
    (elixir 1.16.3) lib/enum.ex:1700: Enum."-map/2-lists^map/1-1-"/2
    (elixir 1.16.3) lib/enum.ex:1700: Enum."-map/2-lists^map/1-1-"/2
    (nx 0.7.2) lib/nx.ex:5269: Nx.do_reshape_vectors/2
    (nx 0.7.2) lib/nx.ex:5107: Nx.broadcast_vectors/2
    (nx 0.7.2) lib/nx.ex:7836: Nx.indexed_op/5
    #cell:t3f2sanmlygino45:4: (file)
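
For reference, here is a minimal sketch of the call shape I am after, run outside of defn where calling Nx.tensor/1 on a plain list is fine. The index and update shapes follow the documented form for Nx.indexed_put/3 (indices of shape {n, rank}, updates of shape {n}). The variable names are only illustrative, and the Nx version is pinned to the one shown in the trace above.

```elixir
Mix.install([{:nx, "~> 0.7.2"}])

# A rank-1 target tensor with three elements.
target = Nx.tensor([0.0, 0.0, 0.0])

# One index into a rank-1 tensor: shape {1, 1}.
indices = Nx.tensor([[1]])

# One update value: shape {1}.
updates = Nx.tensor([2.5])

# Returns a new tensor; `target` itself is not mutated.
result = Nx.indexed_put(target, indices, updates)
IO.inspect(result)
# values: [0.0, 2.5, 0.0]
```

Outside defn the index is an ordinary value, so there is nothing to warn about. The problem only appears once the index depends on the loop variable inside defn.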

I also tried Nx.stack([j]), and that works without any warnings. I’m still wondering whether this is the right approach.
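
To make the question concrete, here is a stripped-down sketch of the same pattern that builds the index from the loop variable with Nx.reshape/2 instead of Nx.tensor/2 or Nx.stack/1. The module IndexedPutSketch and the function add_each/2 are hypothetical names used only for illustration. I am assuming that reshaping j is an acceptable way to get an index tensor inside defn, not claiming it is the recommended one.

```elixir
Mix.install([{:nx, "~> 0.7.2"}])

defmodule IndexedPutSketch do
  import Nx.Defn

  # Hypothetical helper: adds `delta` to each element of `vec`,
  # one index at a time, to mirror the inner loop of the gradient code.
  defn add_each(vec, delta) do
    {n} = Nx.shape(vec)

    {vec, _delta} =
      while {vec, delta}, j <- 0..(n - 1) do
        # Build the index from the loop variable without Nx.tensor/2.
        index = Nx.reshape(j, {1, 1})
        # A single update value with shape {1}.
        update = Nx.reshape(vec[j] + delta, {1})
        {Nx.indexed_put(vec, index, update), delta}
      end

    vec
  end
end

result = IndexedPutSketch.add_each(Nx.tensor([1.0, 2.0, 3.0]), 0.5)
IO.inspect(result)
```

Here the index is derived from j with a tensor operation, so no list is turned into a tensor inside defn. In that respect it behaves like Nx.stack([j]).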