Hi everybody, I'm using Nx not for AI but for some maths: I need to optimize some functions and I'm using Polaris for that.
The problem is that I can't find much information on how to use Polaris. Searching through its tests, I have come up with this loop:
Polaris error
```elixir
Mix.install([
  {:vega_lite, "~> 0.1.8"},
  {:kino_vega_lite, "~> 0.1.8"},
  {:nx, "~> 0.7"},
  {:polaris, "~> 0.1"},
  {:exla, "~> 0.7"}
])

Nx.global_default_backend(EXLA.Backend)
```
```elixir
defmodule ExlaTest do
  import Nx.Defn

  defn system(alpha, beta, f) do
    2 * alpha ** beta * f + beta * f - Nx.Constants.pi()
  end

  defn error(alpha, beta, f, e) do
    s = system(alpha, beta, f)
    Nx.sum((s - e) ** 2)
  end

  def optimize(f, e) do
    learning_rate = 1.0e-1
    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: learning_rate)

    loss = fn x ->
      error(x["alpha"], x["beta"], f, e)
    end

    compiled_k = compile_k(update_fn, loss)
    create_x() |> prepare_optimization(init_fn, compiled_k) |> optimize_x0()
  end

  def compile_k(update_fn, loss) do
    step_fn = fn state ->
      {params, opt_state} = state
      gradients = Nx.Defn.grad(params, loss)
      {updates, new_state} = update_fn.(gradients, opt_state, params)
      {Polaris.Updates.apply_updates(updates, params), new_state}
    end

    Nx.Defn.jit(step_fn, compiler: EXLA)
  end

  def prepare_optimization(x0, init_fn, compiled_k) do
    init_state = init_fn.(x0)
    {{x0, init_state}, compiled_k}
  end

  def optimize_x0({state, compiled_k}) do
    num_steps = 30_000

    {result, _} =
      for _ <- 1..num_steps, reduce: state do
        state -> apply(compiled_k, [state])
      end

    result
  end

  def create_x() do
    %{
      "alpha" => Nx.tensor(5.0, type: {:f, 64}),
      "beta" => Nx.tensor(4.0, type: {:f, 64})
    }
  end
end
```
Execution:

```elixir
sol = ExlaTest.optimize(Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))
{sol, ExlaTest.error(sol["alpha"], sol["beta"], Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))}
```
The `error` function is a `defn` and works fine if I use the default backend, but as soon as I set EXLA or Torchx as the global backend I get this horrible message:

```
** (RuntimeError) cannot invoke Nx function because it relies on two incompatible tensor implementations: Torchx.Backend and Nx.Defn.Expr. This may mean you are passing a tensor to defn/jit as an optional argument or as closure in an anonymous function. For efficiency, it is preferred to always pass tensors as required arguments instead. Alternatively, you could call Nx.backend_copy/1 on the tensor, however this will copy its value and inline it inside the defn expression
    (nx 0.7.0) lib/nx/shared.ex:475: Nx.Shared.pick_struct/2
    (nx 0.7.0) lib/nx.ex:5409: Nx.devectorized_element_wise_bin_op/4
```
What am I doing wrong? Any clues on how to use Polaris the right way?
If I remove the `Nx.global_default_backend(EXLA.Backend)` line, it works.
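For reference, the error message itself suggests passing tensors as required arguments instead of capturing them in a closure. Here is a minimal, untested sketch of how `compile_k` *could* be reworked along those lines, building `loss` inside the step function so that `f` and `e` flow in as explicit arguments (this is my own guess at a workaround, not a confirmed fix):

```elixir
# Hypothetical rework: `f` and `e` become required arguments of the
# jitted step function instead of free variables captured by `loss`,
# which is what the RuntimeError recommends.
def compile_k(update_fn) do
  step_fn = fn {params, opt_state}, f, e ->
    loss = fn x -> error(x["alpha"], x["beta"], f, e) end
    gradients = Nx.Defn.grad(params, loss)
    {updates, new_state} = update_fn.(gradients, opt_state, params)
    {Polaris.Updates.apply_updates(updates, params), new_state}
  end

  Nx.Defn.jit(step_fn, compiler: EXLA)
end
```

The reduce loop would then call `compiled_k.(state, f, e)` each step, so `optimize/2` and `optimize_x0/1` would need to thread `f` and `e` through as well.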