Making an Algorithm Generic Over the Outer Dimension

,

I built an algorithm to calculate the intersection over union (IOU) for two rectangles (x,y,w,h) in Nx:

# Intersection over union (IoU) of two axis-aligned rectangles.
#
# `expected` and `actual` are rank-1 tensors with 4 elements laid out as
# [x, y, w, h], where (x, y) is the top-left corner. Returns a scalar tensor:
# the intersection area divided by the union area.
#
# Note: if both rectangles have zero area the union is 0 and the final
# division is 0 / 0 (NaN).
intersection_over_union = fn expected, actual ->
  # bottom-right corner of the intersection, shape {2}: [x, y]
  i_br =
    Nx.min(
      # for both expected and actual:
      #   reshape: [x, y, w, h] -> [[x, y], [w, h]]
      #   sum over axis 0: [x + w, y + h]
      Nx.reshape(expected, {2, 2}) |> Nx.sum(axes: [0]),
      Nx.reshape(actual, {2, 2}) |> Nx.sum(axes: [0])
    )

  # top-left corner of the intersection, shape {2}: [x, y]
  i_tl =
    Nx.max(expected, actual)
    |> Nx.take(Nx.tensor([0, 1]))

  # area of intersection
  #   Clamp *before* subtracting: max(i_br, i_tl) - i_tl equals
  #   max(i_br - i_tl, 0) but never drops below zero on the way. With
  #   unsigned tensors (e.g. u8) a plain i_br - i_tl wraps around for
  #   non-overlapping rectangles, and a later max(0) cannot repair that.
  isec =
    Nx.max(i_br, i_tl)
    |> Nx.subtract(i_tl)
    |> Nx.product()

  # total area of both rectangles (w * h for each)
  union =
    Nx.add(
      Nx.take(expected, Nx.tensor([2, 3])) |> Nx.product(),
      Nx.take(actual, Nx.tensor([2, 3])) |> Nx.product()
    )
    # the sum above counts the intersection twice, so subtract it once
    |> Nx.subtract(isec)

  Nx.divide(isec, union)
end

It works with tensors of length 4 (e.g. Nx.Tensor<f32[4] [2.0, 3.0, 1.0, 2.0]>). How can I extend the algorithm to work with any batch size?
For example

#Nx.Tensor<
  u8[2][4]
  [
    [3, 3, 2, 2],
    [6, 1, 1, 1]
  ]
>

as one of the input tensors?

1 Like

Look for the :axes or :axis option for the functions you’re using. max, product, take and so on.
It might be easier to have your input be 3-dimensional with shape {batch_dimension, 2, 4}, where {2, 4} is the shape of the 2 rectangles you’re comparing.

2 Likes

Thanks, the hint about the axes helped a lot :slight_smile:
It was actually a lot easier than expected.

Here is the final result:

# Batched intersection over union (IoU) of axis-aligned rectangles.
#
# `expected` and `actual` are rank-2 tensors of shape {batch, 4}; each row is
# one rectangle laid out as [x, y, w, h], where (x, y) is the top-left corner.
# Returns a tensor of shape {batch} holding intersection area / union area
# for each pair of rows. Because every step is element-wise or reduces along
# axis 1 only, a {1, 4} tensor on one side broadcasts against a batch on the
# other.
#
# Note: a pair of rectangles that both have zero area has union 0, so its
# result is 0 / 0 (NaN).
intersection_over_union = fn expected, actual ->
  # bottom-right corner of each intersection, shape {batch, 2}: [x, y]
  i_br =
    Nx.min(
      # for both expected and actual:
      #   reshape each row: [x, y, w, h] -> [[x, y], [w, h]]
      #     (:auto lets Nx infer the batch dimension)
      #   sum over axis 1: [x + w, y + h]
      Nx.reshape(expected, {:auto, 2, 2}) |> Nx.sum(axes: [1]),
      Nx.reshape(actual, {:auto, 2, 2}) |> Nx.sum(axes: [1])
    )

  # top-left corner of each intersection, shape {batch, 2}: [x, y]
  i_tl =
    Nx.max(expected, actual)
    |> Nx.take(Nx.tensor([0, 1]), axis: 1)

  # area of each intersection, shape {batch}
  #   Clamp *before* subtracting: max(i_br, i_tl) - i_tl equals
  #   max(i_br - i_tl, 0) but never drops below zero on the way. With
  #   unsigned tensors (e.g. u8) a plain i_br - i_tl wraps around for
  #   non-overlapping rectangles, and a later max(0) cannot repair that.
  isec =
    Nx.max(i_br, i_tl)
    |> Nx.subtract(i_tl)
    |> Nx.product(axes: [1])

  # total area of both rectangles (w * h for each row), shape {batch}
  union =
    Nx.add(
      Nx.take(expected, Nx.tensor([2, 3]), axis: 1) |> Nx.product(axes: [1]),
      Nx.take(actual, Nx.tensor([2, 3]), axis: 1) |> Nx.product(axes: [1])
    )
    # the sum above counts the intersection twice, so subtract it once
    |> Nx.subtract(isec)

  Nx.divide(isec, union)
end

Feel free to criticize and optimize