Vectorising over the arguments of a function that has a while statement

Hey,

I have some code that follows the same semantics as the snippet below:

```elixir
defmodule Test do
  import Nx.Defn

  defnp test(x, y) do
    while {x, y}, Nx.greater_equal(x, 0) do
      z = x * y
      x = x - 1
      {x, z}
    end
  end

  defn call_test(x, y) do
    test(Nx.vectorize(x, :first), Nx.vectorize(y, :first)) |> Nx.devectorize(keep_names: false)
  end
end

Test.call_test(
  Nx.broadcast(1, {100000}),
  Nx.broadcast(1, {100000, 1})
)
```

It seems to be at least an order of magnitude slower on the EXLA backend than on the Binary backend. Why is that? I must be missing something. Is there a better way to map a function that contains a `while` loop over vectorised arguments?

Do you have EXLA configured as the defn compiler too?
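For reference, here is a minimal sketch of the two usual ways to make `defn` code run through EXLA, assuming the `:exla` dependency is already installed and the `Test` module from the question is compiled. Setting only `EXLA.Backend` as the default backend is probably not enough on its own: `defn` functions are then still executed by the default `Nx.Defn.Evaluator`, which (as far as I understand it) dispatches each operation, including every iteration of the `while` body, to the backend one at a time instead of compiling the whole loop into a single XLA computation.

```elixir
# Option 1: configure EXLA globally, e.g. at the top of a script or in IEx/Livebook.
# Eager tensor operations (outside defn) run on the EXLA backend:
Nx.global_default_backend(EXLA.Backend)
# defn functions such as Test.call_test/2 are compiled as a whole by EXLA:
Nx.Defn.global_default_options(compiler: EXLA)

# The equivalent application config (config/config.exs) would be:
#   config :nx, :default_backend, EXLA.Backend
#   config :nx, :default_defn_options, compiler: EXLA

# Option 2: JIT-compile a single function with EXLA, leaving the global settings alone.
call_test = Nx.Defn.jit(&Test.call_test/2, compiler: EXLA)

call_test.(
  Nx.broadcast(1, {100000}),
  Nx.broadcast(1, {100000, 1})
)
```

When benchmarking against `Nx.BinaryBackend`, keep in mind that the first call to a jitted function includes XLA compilation time, so timing a second call with the same input shapes gives a fairer comparison.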

Thanks! It wasn't even being compiled :joy:
