Accessing current tensor value indexes in defn


I’ve implemented a very basic “ray tracing” algorithm in Elixir. Since it involves float arithmetic over large two-dimensional data, I thought I’d try porting it to Nx to give it a performance boost. However, being new to this kind of thing, I’m not sure it’s actually feasible.

The algorithm is the following: I have a 2D height map, which is a two-dimensional list of float values (for example `[[2.0, 7.2], [1.3, 2.4]]`). I also have a ray tracing direction vector `{dx, dy, dz}`. For every point in my height map, I trace a ray starting from that point and moving along this vector until it either leaves the map or bumps into an obstacle (meaning its height z is below the value of the height map at that point).
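Moving along the direction vector amounts to repeatedly adding `{dx, dy, dz}` to the current position. A minimal sketch of the sequence of points a single ray visits, written as a lazy `Stream`; the `Ray` module and `points/2` function are hypothetical names, and the bounds and obstacle checks are deliberately left to the caller:

```elixir
defmodule Ray do
  # Hypothetical helper: lazily yields {x, y, z}, {x + dx, y + dy, z + dz}, ...
  # It never stops on its own; the caller decides when the ray has left the
  # map or hit an obstacle.
  def points({x, y, z}, {dx, dy, dz}) do
    Stream.iterate({x, y, z}, fn {px, py, pz} -> {px + dx, py + dy, pz + dz} end)
  end
end

# First three points of the ray starting on cell {0, 0} of the example map
Ray.points({0, 0, 2.0}, {0, 1, 0.5}) |> Enum.take(3)
# => [{0, 0, 2.0}, {0, 1, 2.5}, {0, 2, 3.0}]
```

With the example map `[[2.0, 7.2], [1.3, 2.4]]`, the second point lands on cell `{0, 1}`, whose height 7.2 is above the ray's 2.5, so the starting point `{0, 0}` would be in the shade.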

The Elixir implementation looks like this:

```elixir
map = to_map(height_map) # build the lookup map once instead of once per point

for {row, i} <- Enum.with_index(height_map) do
  for {z, j} <- Enum.with_index(row) do
    is_in_the_shade?(map, i, j, z, sun_x, sun_y, sun_z)
  end
end

def is_in_the_shade?(height_map, x, y, z, sun_x, sun_y, sun_z) do
  case Map.fetch(height_map, {trunc(x), trunc(y)}) do
    :error -> false # we've reached the end of the height_map without meeting any obstacle
    {:ok, height} when height > z -> true # we've met an obstacle
    {:ok, _height} -> # continue ray tracing along sun_v
      is_in_the_shade?(height_map, x + sun_x, y + sun_y, z + sun_z, sun_x, sun_y, sun_z)
  end
end
```
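The snippet calls a `to_map/1` helper whose definition isn’t shown. A minimal sketch of what it might look like, assuming it keys each height by its `{row, column}` index so that `Map.fetch(height_map, {trunc(x), trunc(y)})` finds it; the `HeightMap` module name is hypothetical:

```elixir
defmodule HeightMap do
  # Hypothetical helper: converts a nested list of heights into a map
  # keyed by {row_index, column_index}
  def to_map(height_map) do
    for {row, i} <- Enum.with_index(height_map),
        {z, j} <- Enum.with_index(row),
        into: %{},
        do: {{i, j}, z}
  end
end

map = HeightMap.to_map([[2.0, 7.2], [1.3, 2.4]])
# => %{{0, 0} => 2.0, {0, 1} => 7.2, {1, 0} => 1.3, {1, 1} => 2.4}

Map.fetch(map, {0, 1}) # => {:ok, 7.2}
Map.fetch(map, {0, 2}) # => :error, i.e. the ray has left the map
```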

I’ve tried to port it to Nx like this:

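```elixir
shade(hm, sun_v: {sun_x, sun_y, sun_z}, height_map: hm)

# {x, y, z} should be the current point's row index, column index and height
defn shade({x, y, z}, opts \\ []) do
  opts = keyword!(opts, [:sun_v, :height_map])
  {sun_x, sun_y, sun_z} = opts[:sun_v]
  height_map = opts[:height_map]

  # march along sun_v while no obstacle is met and the ray stays inside the 256x256 map
  while {res = 0, height_map, x, y, z, sun_x, sun_y, sun_z},
        res == 0 and x >= 0 and x < 256 and y >= 0 and y < 256 do
    if height_map[x][y] > z do
      {1, height_map, x, y, z, sun_x, sun_y, sun_z}
    else
      {0, height_map, x + sun_x, y + sun_y, z + sun_z, sun_x, sun_y, sun_z}
    end
  end
end
```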

The problem is that I can’t figure out how to get access to `x` and `y` here. I’ve tried to define my tensor as a list of tuples, like `[{0, 0, 2.0}, {0, 1, 3.2}, {1, 0, 3.1}, {1, 1, 3.0}]`, but that doesn’t seem possible: I get the error `invalid value given to Nx.tensor/1, got: {0, 0, 38.75}`, which makes sense, since tensor values need to be integers or floats.
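Since tensor entries must be plain numbers, the `{x, y, z}` triples could instead live in an extra trailing dimension of size 3. A minimal sketch, assuming Nx is available (installed here with `Mix.install`; the `~> 0.7` version constraint is an assumption) and using a hypothetical `HeightIndex.with_indexes/1` helper. It builds the row and column indexes with `Nx.iota/2` rather than writing them out by hand. The example heights differ from the post’s so that they are exactly representable in 32-bit floats:

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule HeightIndex do
  # Hypothetical helper: for a {rows, cols} height map, returns a {rows, cols, 3}
  # tensor whose entry [i][j] is [i, j, height], i.e. the tensor form of the
  # [{0, 0, 2.0}, {0, 1, 3.2}, ...] list of tuples
  def with_indexes(heights) do
    shape = Nx.shape(heights)
    # row index of each cell, e.g. [[0, 0], [1, 1]] for a 2x2 map
    rows = Nx.iota(shape, axis: 0, type: {:f, 32})
    # column index of each cell, e.g. [[0, 1], [0, 1]] for a 2x2 map
    cols = Nx.iota(shape, axis: 1, type: {:f, 32})
    Nx.stack([rows, cols, Nx.as_type(heights, {:f, 32})], axis: 2)
  end
end

heights = Nx.tensor([[2.0, 7.5], [1.25, 2.5]])
HeightIndex.with_indexes(heights) |> Nx.to_list()
# => [[[0.0, 0.0, 2.0], [0.0, 1.0, 7.5]], [[1.0, 0.0, 1.25], [1.0, 1.0, 2.5]]]
```

Passing `type: {:f, 32}` to `Nx.iota/2` keeps all three channels the same type. An alternative is to keep the row and column indexes as two separate integer tensors with the same shape as the height map.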

Is there a way to make this work with Nx, or should I be looking at a C port to improve performance?