I thought that I could use `Nx` for computations with complex numbers. A complex number z = x + iy can be viewed as the real 2x2 matrix [[x, -y], [y, x]], so complex calculations boil down to matrix manipulation.
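A minimal pure-Elixir sketch (no `Nx`, no `complex` library) that checks this claim on one example: multiplying the two matrices gives the matrix of the complex product. The module `MatrixComplexCheck` and its functions are hypothetical, made up for illustration. Integers are used so the comparison is exact.

```elixir
defmodule MatrixComplexCheck do
  # Hypothetical helper module for illustration only (not part of Nx or complex).

  # Complex product of {a, b} = a + ib and {c, d} = c + id.
  def mul({a, b}, {c, d}), do: {a * c - b * d, a * d + b * c}

  # Embeds x + iy as the 2x2 matrix [[x, -y], [y, x]].
  def to_matrix({x, y}), do: [[x, -y], [y, x]]

  # Plain 2x2 matrix product on nested lists.
  def matmul([[a11, a12], [a21, a22]], [[b11, b12], [b21, b22]]) do
    [
      [a11 * b11 + a12 * b21, a11 * b12 + a12 * b22],
      [a21 * b11 + a22 * b21, a21 * b12 + a22 * b22]
    ]
  end
end

z = {3, 2}
w = {1, -4}

# The matrix of z times the matrix of w equals the matrix of z * w.
MatrixComplexCheck.matmul(MatrixComplexCheck.to_matrix(z), MatrixComplexCheck.to_matrix(w)) ==
  MatrixComplexCheck.to_matrix(MatrixComplexCheck.mul(z, w))
# => true
```

Here (3 + 2i)(1 - 4i) = 11 - 10i, and both sides evaluate to [[11, 10], [-10, 11]].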
I wanted to compute the terms defined by the recurrence relation

z(n+1) = z(n)^2 + c

for some constant c. It's just a multiplication and an addition applied recursively.
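Writing z_n = x_n + i y_n and c = a + ib, one step of the recurrence expands in real components as below; the last two lines are the same step in the 2x2 matrix representation (the matrix square is computed entry by entry).

$$
\begin{aligned}
x_{n+1} &= x_n^2 - y_n^2 + a \\
y_{n+1} &= 2\,x_n y_n + b \\
\begin{pmatrix} x_n & -y_n \\ y_n & x_n \end{pmatrix}^2 + \begin{pmatrix} a & -b \\ b & a \end{pmatrix}
&= \begin{pmatrix} x_n^2 - y_n^2 + a & -(2\,x_n y_n + b) \\ 2\,x_n y_n + b & x_n^2 - y_n^2 + a \end{pmatrix}
= \begin{pmatrix} x_{n+1} & -y_{n+1} \\ y_{n+1} & x_{n+1} \end{pmatrix}
\end{aligned}
$$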
When I compare an implementation using `Nx` with one that uses the `complex` library, I get a huge difference in performance, as seen below. I don't understand why, because `complex` does not use any special backend. I naively thought `Nx` would be a great fit, also because of the “special” `while` for recursion, as shown in the docs with `factorial`.
```elixir
Mix.install(
  [
    {:nx, "~> 0.9.1"},
    {:complex, "~> 0.5.0"},
    {:benchee, "~> 1.3"},
    {:exla, "~> 0.9.1"}
  ],
  config: [nx: [default_backend: EXLA.Backend]]
)

Benchee.run(
  %{
    "nx_cx" => fn -> CNx.iterate(300, CNx.new(0.2, 0.2)) end,
    "complex" => fn -> Cx.iterate(300, Complex.new(0.2, 0.2)) end
  },
  time: 2
)
```
gives:
```
Comparison:
complex        24.52 K
nx_cx         0.0281 K - 872.74x slower +35.55 ms
```
The code is:
```elixir
defmodule CNx do
  import Nx.Defn

  # Builds the 2x2 real matrix [[x, -y], [y, x]] that represents x + iy.
  def new(x, y) do
    Nx.tensor([[x, -y], [y, x]], type: :f64)
  end

  # One step of the recurrence: z^2 + c, with z^2 computed as a matrix product.
  defn p(z, c), do: Nx.dot(z, z) |> Nx.add(c)

  # Applies p n - 1 times starting from z = c, so iterate(1, c) == c.
  defn iterate(n, c) do
    while {z = c, c, n}, Nx.greater(n, 1) do
      {p(z, c), c, n - 1}
    end
    |> elem(0)
  end
end
```
compared to:
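```elixir
defmodule Cx do
  # One step of the recurrence: z^2 + c, using the complex library.
  def p(z, c), do: Complex.multiply(z, z) |> Complex.add(c)

  # iterate(1, c) == c; iterate(n, c) == p(iterate(n - 1, c), c)
  def iterate(1, c), do: c
  def iterate(n, c), do: p(iterate(n - 1, c), c)
end
```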