Hello, first post here.
I want to create an n*n tensor of cells and put an index into the first value of each cell.
What I have so far:
```elixir
def create_world(num) do
  cell_info = Nx.tensor([0, 0, 50])
  world = Nx.broadcast(cell_info, {num, num, 3})
  world
end
```
I tried putting Nx.iota inside Nx.broadcast, but that doesn’t work.
The cells should read like [0, 0, 50], [1, 0, 50], [2, 0, 50], etc.
Try looking at Nx.put_slice (or Nx.concatenate, depending on how your data is built).
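A rough sketch of the Nx.concatenate route, assuming a recent Nx release (pulled in with Mix.install so it runs as a standalone script; the module name `World` is just for illustration): build the index column and the constant part of each cell separately, then join them along the last axis.

```elixir
# Assumption: any Nx 0.x release from 0.7 on; adjust to the version you use.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical module name, for illustration only.
defmodule World do
  # Returns a {num, num, 3} tensor whose cells read [index, 0, 50],
  # with index counting 0..num*num-1 in row-major order.
  def create_world(num) do
    # Index column, shape {num, num, 1}. With no :axis option,
    # Nx.iota counts across the whole tensor.
    indices = Nx.iota({num, num, 1}, type: {:u, 32})

    # Constant tail of every cell, shape {num, num, 2}.
    rest = Nx.broadcast(Nx.tensor([0, 50], type: {:u, 32}), {num, num, 2})

    # Join along the last axis to get shape {num, num, 3}.
    Nx.concatenate([indices, rest], axis: 2)
  end
end

World.create_world(2) |> Nx.to_list() |> IO.inspect()
```

If you want the index to restart on every row instead of running across the whole grid, passing `axis: 1` to Nx.iota should give that.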
Thanks! I solved it by using Nx.put_slice():
```elixir
def create_world(num) do
  cell_info = Nx.tensor([0, 0, 50], names: [:cell_info], type: {:u, 32})
  world = Nx.broadcast(cell_info, {num, num, 3})
  world = Nx.tensor(world, names: [:x, :y, :cell_info], type: {:u, 32})
  indices = Nx.iota({num, num}, names: [:x, :y], type: {:u, 32})
  indices = Nx.new_axis(indices, 2, :cell_info)
  world = Nx.put_slice(world, [0, 0, 0], indices)
  world
end
```
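A possibly tidier variant of the same put_slice approach, as a sketch assuming a recent Nx release (installed via Mix.install; `World` is a hypothetical module name): Nx.broadcast takes a `:names` option, so the axes can be named while broadcasting, and giving Nx.iota a trailing size-1 axis makes the Nx.new_axis step unnecessary.

```elixir
# Assumption: any Nx 0.x release from 0.7 on; adjust to the version you use.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical module name, for illustration only.
defmodule World do
  def create_world(num) do
    # Broadcast the template cell and name the axes in one step.
    world =
      Nx.tensor([0, 0, 50], type: {:u, 32})
      |> Nx.broadcast({num, num, 3}, names: [:x, :y, :cell_info])

    # Running index 0..num*num-1, already shaped {num, num, 1}.
    indices = Nx.iota({num, num, 1}, names: [:x, :y, :cell_info], type: {:u, 32})

    # Overwrite the first value of every cell with its index.
    Nx.put_slice(world, [0, 0, 0], indices)
  end
end

World.create_world(2) |> Nx.to_list() |> IO.inspect()
```

This should produce the same values and axis names as the version above, just with fewer intermediate steps.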