How to add an index with Nx.iota to a given tensor

Hello, first post here.

I want to create an n×n grid of cells (a {num, num, 3} tensor) and set the first value of each cell to an index.
What I have so far:

def create_world(num) do
  cell_info = Nx.tensor([0, 0, 50])
  world = Nx.broadcast(cell_info, {num, num, 3})

  world
end

I tried to add Nx.iota inside Nx.broadcast but that doesn’t work.

The cells should read like [0, 0, 50], [1, 0, 50], [2, 0, 50], etc.
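It helps to check what Nx.iota actually counts before wiring it into the world tensor. Without the :axis option it numbers every element in row-major order. With :axis it gives each element its position along that one axis. Here is a minimal check on a 3×3 shape (the Nx version pinned in Mix.install is an assumption):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# No :axis: a running count over all elements, row by row
running = Nx.iota({3, 3}) |> Nx.to_list()
IO.inspect(running)
# [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# axis: 0 gives each element its row position
rows = Nx.iota({3, 3}, axis: 0) |> Nx.to_list()
IO.inspect(rows)
# [[0, 0, 0], [1, 1, 1], [2, 2, 2]]

# axis: 1 gives each element its column position
cols = Nx.iota({3, 3}, axis: 1) |> Nx.to_list()
IO.inspect(cols)
# [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
```

With the plain running count, the index keeps increasing across rows (3, 4, 5, … on the second row). If each row should instead restart at 0, axis: 1 gives that.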

Try looking at Nx.put_slice/3 (or Nx.concatenate/2, depending on how your data is built).
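For the Nx.concatenate route, one option is to build the index column and the constant part of each cell separately and join them along the last axis. This is a sketch, assuming the other two values are always 0 and 50 as in the question. The module name WorldConcat is hypothetical:

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule WorldConcat do
  # Builds a {num, num, 3} tensor whose cells are [index, 0, 50]
  def create_world(num) do
    # Running index over all cells, with a trailing axis of size 1: {num, num, 1}
    indices = Nx.iota({num, num, 1}, type: :u32)

    # The remaining two values of every cell, broadcast to {num, num, 2}
    rest = Nx.broadcast(Nx.tensor([0, 50], type: :u32), {num, num, 2})

    # Join along the last axis to get {num, num, 3}
    Nx.concatenate([indices, rest], axis: 2)
  end
end

world = WorldConcat.create_world(3)
IO.inspect(hd(Nx.to_list(world)))
# [[0, 0, 50], [1, 0, 50], [2, 0, 50]]
```

Concatenation fits when the pieces of each cell naturally come from different sources. put_slice fits when you already have a filled default tensor and only want to overwrite part of it.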

Thanks! I solved it with Nx.put_slice/3:

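def create_world(num) do
  # Default cell [index, 0, 50] as unsigned 32-bit integers
  cell_info = Nx.tensor([0, 0, 50], names: [:cell_info], type: {:u, 32})

  # Repeat the cell across a num x num grid and name all three axes
  world =
    cell_info
    |> Nx.broadcast({num, num, 3})
    |> Nx.rename([:x, :y, :cell_info])

  # Running index 0..num*num-1, one per cell, reshaped to {num, num, 1}
  indices = Nx.iota({num, num}, names: [:x, :y], type: {:u, 32})
  indices = Nx.new_axis(indices, 2, :cell_info)

  # Overwrite the first value of every cell with its index
  Nx.put_slice(world, [0, 0, 0], indices)
end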
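If the axis names aren't needed elsewhere, the same put_slice idea can be written more compactly. You can ask Nx.iota for the trailing size-1 axis directly instead of calling Nx.new_axis. This is a sketch under that assumption, and WorldSlice is a hypothetical module name:

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule WorldSlice do
  # Builds a {num, num, 3} tensor whose cells are [index, 0, 50]
  def create_world(num) do
    # Every cell starts as [0, 0, 50]
    base = Nx.broadcast(Nx.tensor([0, 0, 50], type: :u32), {num, num, 3})

    # Running index with the trailing axis already in place: {num, num, 1}
    indices = Nx.iota({num, num, 1}, type: :u32)

    # Write the indices into position 0 of each cell, leaving the rest as-is
    Nx.put_slice(base, [0, 0, 0], indices)
  end
end

IO.inspect(Nx.to_list(WorldSlice.create_world(2)))
# [[[0, 0, 50], [1, 0, 50]], [[2, 0, 50], [3, 0, 50]]]
```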