Nx for computations with complex numbers

I thought that I could use Nx for computations with complex numbers.
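A complex number z = x + iy can be represented as the real 2×2 matrix [[x, -y], [y, x]]: adding and multiplying complex numbers corresponds to adding and multiplying these matrices, so complex arithmetic boils down to matrix manipulation.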
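A quick sanity check of this representation in plain Elixir, with no dependencies. The module `ComplexMatrix` and its helpers are hypothetical, written only for illustration: `to_matrix/2` builds the matrix for x + iy, `mat_mul/2` multiplies two 2×2 matrices, and `from_matrix/1` reads x + iy back from the first column.

```elixir
defmodule ComplexMatrix do
  # Hypothetical helper: represent x + iy as the real 2x2 matrix [[x, -y], [y, x]]
  def to_matrix(x, y), do: [[x, -y], [y, x]]

  # Plain 2x2 matrix product, written out explicitly
  def mat_mul([[a, b], [c, d]], [[e, f], [g, h]]) do
    [[a * e + b * g, a * f + b * h], [c * e + d * g, c * f + d * h]]
  end

  # Read x + iy back from the first column of the matrix
  def from_matrix([[x, _], [y, _]]), do: {x, y}
end

# (1 + 2i) * (3 + 4i) = 3 + 4i + 6i + 8i^2 = -5 + 10i
ComplexMatrix.mat_mul(ComplexMatrix.to_matrix(1, 2), ComplexMatrix.to_matrix(3, 4))
|> ComplexMatrix.from_matrix()
|> IO.inspect()
# prints {-5, 10}
```

Note that the encoding is redundant: four reals store two, and a 2×2 matrix product costs 8 real multiplications where a complex product needs 4.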

I wanted to compute the terms defined by the recurrence relation:
z(n+1) = z(n)^2 + c

for some complex number c. It's just a multiplication and an addition, applied recursively.
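Written out in the matrix form, one step of this recurrence looks as follows. Here z_n stands for z(n), and the symbols x_n, y_n (with z_n = x_n + i y_n) and a, b (with c = a + ib) are introduced only for this sketch:

```latex
\begin{gathered}
M(x + iy) = \begin{pmatrix} x & -y \\ y & x \end{pmatrix}, \qquad
M(z_n)^2 = \begin{pmatrix} x_n^2 - y_n^2 & -2 x_n y_n \\ 2 x_n y_n & x_n^2 - y_n^2 \end{pmatrix} = M(z_n^2), \\
M(z_{n+1}) = M(z_n)^2 + M(c) \iff x_{n+1} = x_n^2 - y_n^2 + a, \quad y_{n+1} = 2 x_n y_n + b.
\end{gathered}
```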

When I compare the implementation using Nx with one that uses the complex library, I get a huge difference in performance, as shown below.
I don't understand why, because complex does not use any special backend. I naively thought Nx would be a great fit, also because of the "special" `while` construct for recursion shown in the docs with the factorial example.

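Setup:

```elixir
Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:complex, "~> 0.5.0"},
    {:benchee, "~> 1.3"},
    {:exla, "~> 0.9.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)
```

Benchmark:

```elixir
Benchee.run(
  %{
    "nx_cx" => fn -> CNx.iterate(300, CNx.new(0.2, 0.2)) end,
    "complex" => fn -> Cx.iterate(300, Complex.new(0.2, 0.2)) end
  },
  time: 2
)
```

Results (iterations per second, i.e. about 41 µs per run for complex versus about 36 ms for nx_cx):

```
complex       24.52 K
nx_cx        0.0281 K - 872.74x slower +35.55 ms
```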

The code is:

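```elixir
defmodule CNx do
  import Nx.Defn

  # Build the 2x2 real matrix [[x, -y], [y, x]] that represents x + iy
  def new(x, y) do
    Nx.tensor([[x, -y], [y, x]], type: :f64)
  end

  # One step of the recurrence, z^2 + c: the matrix product plays the role of complex multiplication
  defn p(z, c), do: Nx.dot(z, z) |> Nx.add(c)

  # Applies p n - 1 times starting from z = c, so iterate(1, c) returns c (same convention as Cx)
  defn iterate(n, c) do
    {z, _c, _n} =
      while {z = c, c, n}, Nx.greater(n, 1) do
        {p(z, c), c, n - 1}
      end

    z
  end
end
```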

compared to:

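```elixir
defmodule Cx do
  # One step of the recurrence using the complex library: z^2 + c
  def p(z, c), do: Complex.multiply(z, z) |> Complex.add(c)

  # iterate(1, c) == c; iterate(n, c) == p(iterate(n - 1, c), c)
  def iterate(1, c), do: c
  def iterate(n, c), do: p(iterate(n - 1, c), c)
end
```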

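So, thanks to Paulo Valente, here is the explanation for this strange result: setting the default backend to EXLA is not enough. When you use `defn` (numerical definitions), you also need to set the default `defn` compiler:

```elixir
Nx.Defn.global_default_options(compiler: EXLA, client: :host)
```

The default backend only affects eager tensor operations. Without a `defn` compiler, `defn` functions run through the default `Nx.Defn.Evaluator`, which executes each operation of the `while` loop one at a time; with `compiler: EXLA`, the whole loop is compiled once and runs as a single computation.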
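If you prefer to set this once at install time rather than calling `Nx.Defn.global_default_options/1`, my understanding is that the same options can go under Nx's `:default_defn_options` application key. A config sketch for a script or notebook setup, assuming the same Nx/EXLA versions as in the benchmark:

```elixir
Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:exla, "~> 0.9.1"}
  ],
  config: [
    nx: [
      # Eager tensor operations run on EXLA
      default_backend: EXLA.Backend,
      # defn functions are JIT-compiled by EXLA instead of being evaluated op by op
      default_defn_options: [compiler: EXLA, client: :host]
    ]
  ]
)
```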

Nx also natively supports the complex types c64 and c128.

So you can use the Complex library together with Nx; there is no need to encode complex numbers as matrices yourself.
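A minimal sketch of the same recurrence on a native complex tensor, assuming Nx ~> 0.9 (Nx depends on the complex library, so `Complex` is available). The module name `CNxNative` is hypothetical. Inside `defn`, `z * z + c` is complex multiplication and addition, so no matrix encoding is needed. This sketch runs on Nx's default backend and evaluator; for speed you would still configure EXLA as described above.

```elixir
Mix.install([{:nx, "~> 0.9.1"}])

defmodule CNxNative do
  import Nx.Defn

  # One step of the recurrence on a complex scalar tensor: z^2 + c
  defn p(z, c), do: z * z + c

  # Same convention as Cx.iterate: iterate(1, c) returns c
  defn iterate(n, c) do
    {z, _c, _n} =
      while {z = c, c, n}, n > 1 do
        {p(z, c), c, n - 1}
      end

    z
  end
end

# A c128 scalar tensor built directly from a Complex struct
c = Nx.tensor(Complex.new(0.2, 0.2), type: :c128)

CNxNative.iterate(300, c)
|> Nx.to_number()
|> IO.inspect()
```

`Nx.to_number/1` returns a `%Complex{}` for a complex scalar, so the result can be compared directly with `Cx.iterate/2`.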

