Help with updating every element of a tensor in a loop

Beginner question here:
I want to process all elements of a given tensor in a loop (or in one go). As far as I understand, in Elixir you cannot update a variable from an outer scope inside an anonymous function; rebinding it only creates a new local variable. So the following code does nothing:

```elixir
  def update_all_cells(world) do
    shape = Nx.shape(world)
    rows = elem(shape, 0)
    columns = elem(shape, 1)
    # we need a temp_world because we cannot update the world
    # as we need it to calculate the next state
    temp_world = fn (world, rows, columns) ->
      temp_world = Nx.tensor(world, names: [:x, :y, :cell_info], type: {:s, 64})
      Enum.each(0..rows-1, fn x ->
        Enum.each(0..columns-1, fn y ->
          updated_cell = update_cell(world, x, y)
          temp_world = update_world(world, x, y, updated_cell)
        end)
      end)
      temp_world
    end
    temp_world.(world, rows, columns)
  end
```
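For comparison, here is a minimal plain-Elixir sketch (no Nx; a nested list stands in for the tensor) of why a loop like the one above has no effect, and how the usual `Enum.reduce` accumulator pattern fixes it. `update_cell/3` here is a hypothetical stand-in that just adds the row and column index to the cell value; the real update rule from the thread is not shown.

```elixir
defmodule ReduceDemo do
  # Hypothetical stand-in for the thread's update_cell/3: the new value of
  # cell {x, y} is its old value plus x + y.
  def update_cell(world, x, y), do: Enum.at(Enum.at(world, x), y) + x + y

  # Mirrors the original code: the rebinding inside the fn is local to it.
  def each_version(world) do
    temp_world = world

    Enum.each(0..(length(world) - 1), fn x ->
      # Binds a NEW temp_world that only exists inside this fn.
      temp_world = List.replace_at(temp_world, x, [])
      temp_world
    end)

    # The outer temp_world was never changed, so the input comes back as-is.
    temp_world
  end

  # Threads the updated grid through Enum.reduce as an accumulator instead.
  def reduce_version(world) do
    rows = length(world)
    columns = length(hd(world))

    Enum.reduce(0..(rows - 1), world, fn x, outer_acc ->
      Enum.reduce(0..(columns - 1), outer_acc, fn y, acc ->
        # Read from the ORIGINAL world, write into the accumulator.
        new_value = update_cell(world, x, y)
        List.update_at(acc, x, &List.replace_at(&1, y, new_value))
      end)
    end)
  end
end
```

`update_cell/3` always reads from the original `world` while the writes go into the accumulator, which is the same "read old state, write new state" idea the `temp_world` comment describes.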

I understand that I need something like `Enum.map`, but that only works on enumerables, so it does not work on Nx tensors. I saw that there is `Nx.map`…
I can call something like `Nx.map(world[cell_info: 0..-1//1], fn x -> some_function(x) end)`, but I still need to define my variables and update them somehow. I am kind of lost here :smiley:
Can somebody help me?

I have never used or even looked at Nx, but from the docs it seems that `map` is discouraged and `while` is preferred, and `while` appears to let you pass variables in and update them.
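A minimal sketch of that idea, assuming Nx 0.6 or later fetched with `Mix.install`: the module and function names are made up for illustration. Values go into `while` only through its state tuple, and the final state is what `while` returns, so you destructure it to get the values back out. It assumes an integer input tensor, because the initial `total = 0` is an integer and `while` requires the state to keep the same types on every iteration.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule WhileDemo do
  import Nx.Defn

  # Sums a 1-D integer tensor with an explicit while loop.
  defn sum_with_while(t) do
    # Nx.size/1 returns a plain integer here, so it can be used in the condition.
    n = Nx.size(t)

    # Everything the body needs (t, the index i, the running total) is in the
    # state tuple; while returns the final state, which we destructure.
    {_t, _i, total} =
      while {t, i = 0, total = 0}, i < n do
        {t, i + 1, total + t[i]}
      end

    total
  end
end
```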

@Hermanverschooten, thanks for your answer. I have looked into `while` (from `Nx.Defn.Kernel`).

I am still missing something and cannot wrap my head around this: how do you pass variables from the outer scope (the `defn`) into an inner scope like `while`, and get the updated values back out again?

I have the following code:

```elixir
defn update_all_cells(world) do
  shape = Nx.shape(world)
  rows = elem(shape, 0)
  columns = elem(shape, 1)
  # we need a temp_world because we cannot update the world
  # as we need it to calculate the next state
  temp_world = Nx.tensor(world, names: [:x, :y, :cell_info], type: {:s, 64})
  result =
    while {world, temp_world, x = 0}, x < rows do
      while {world, temp_world, y = 0}, y < columns do
        {world, update_world(world, x, y, update_cell(world, x, y)), y + 1}
      end
      {world, temp_world, x + 1}
    end
  temp_world = elem(result, 1)
end
```

It throws this error:

```
** (RuntimeError) cannot build defn because expressions come from different contexts: {:while, #Reference<0.3059512554.628097030.13356>} and {:while, #Reference<0.3059512554.628097030.13361>}.

This typically happens on "while" and inside anonymous functions when you try to access an external variable. All variables you intend to use inside "while" or anonymous functions in defn must be explicitly given as arguments.
For example, this is not valid:

    defn increment_by_y_while_less_than_10(y) do
      while x = 0, Nx.less(x, 10) do
        x + y
      end
    end

In the example above, we want to increment "x" by "y" while it is less than 10. However, the code won't compile because "y" is used inside "while" but not explicitly defined as part of "while". You must fix it like so:

    defn increment_by_y_while_less_than_10(y) do
      while {x = 0, y}, Nx.less(x, 10) do
        {x + y, y}
      end
    end


    (nx 0.6.2) lib/nx/defn/expr.ex:1431: Nx.Defn.Expr.merge_context!/2
    (nx 0.6.2) lib/nx/defn/expr.ex:1353: anonymous fn/2 in Nx.Defn.Expr.to_exprs/1
    (elixir 1.15.7) lib/enum.ex:1819: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (elixir 1.15.7) lib/enum.ex:1819: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.6.2) lib/nx/defn/expr.ex:1140: Nx.Defn.Expr.slice/5
    (nx 0.6.2) lib/nx.ex:13611: Nx.slice/4
    (nx 0.6.2) lib/nx/tensor.ex:92: Nx.Tensor.fetch_axes/2
    iex:24: (file)
```
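Comparing the error with the code, three things stand out: the inner `while` uses `x` without listing it in its state tuple (that is what the "different contexts" message is about), the result of the inner `while` is discarded so the outer loop never sees the updated `temp_world`, and the inner body writes into `world` instead of `temp_world`. Below is a minimal sketch of a corrected version, assuming Nx 0.6 or later and a plain 2-D world tensor (no `:cell_info` axis, no names). `update_cell` and `update_world` are hypothetical stand-ins, since the thread does not show them; `update_world` uses `Nx.put_slice/3` with dynamic start indices.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule GridDemo do
  import Nx.Defn

  # Hypothetical stand-in for update_cell: new value = old value + x + y.
  # It reads from the ORIGINAL world.
  defnp update_cell(world, x, y) do
    world[x][y] + x + y
  end

  # Hypothetical stand-in for update_world: writes `value` into cell {x, y}.
  defnp update_world(temp_world, x, y, value) do
    Nx.put_slice(temp_world, [x, y], Nx.reshape(value, {1, 1}))
  end

  defn update_all_cells(world) do
    # Shapes are known at compile time, so rows/columns are plain integers
    # and can be used inside while without being part of the state.
    {rows, columns} = Nx.shape(world)
    temp_world = world

    {_world, temp_world, _x} =
      while {world, temp_world, x = 0}, x < rows do
        # x comes from the outer loop, so it must be listed in the inner state.
        {world, temp_world, x, _y} =
          while {world, temp_world, x, y = 0}, y < columns do
            {world, update_world(temp_world, x, y, update_cell(world, x, y)), x, y + 1}
          end

        # Hand the inner loop's updated temp_world back to the outer loop.
        {world, temp_world, x + 1}
      end

    temp_world
  end
end
```

The same rule answers the "how do I get values back out" question for both loops: whatever a `while` body needs must be in its state tuple, and whatever you need afterwards must be pattern-matched out of the tuple it returns.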