Incompatible defn and exla/torchx backends

Hi everybody, Im using Nx not for AI but do some maths. I need to optimize some functions and Im using Polaris.

The problem is that I don’t find much information on how to use Polaris. Searching in the tests I have come up with this loop:

Polaris error

Mix.install([
  {:vega_lite, "~> 0.1.8"},
  {:kino_vega_lite, "~> 0.1.8"},
  {:nx, "~> 0.7"},
  {:polaris, "~> 0.1"},
  {:exla, "~> 0.7"}
])

Nx.global_default_backend(EXLA.Backend)

Section

defmodule ExlaTest do
  import Nx.Defn

  defn system(alpha, beta, f) do
    2 * alpha ** beta * f + beta * f - Nx.Constants.pi()
  end

  defn error(alpha, beta, f, e) do
    s = system(alpha, beta, f)
    Nx.sum((s - e) ** 2)
  end

  def optimize(f, e) do
    learning_rate = 1.0e-1
    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: learning_rate)

    loss = fn x ->
      error(
        x["alpha"],
        x["beta"],
        f,
        e
      )
    end

    compiled_k = compile_k(update_fn, loss)

    create_x() |> prepare_optimization(init_fn, compiled_k) |> optimize_x0()
  end

  def compile_k(update_fn, loss) do
    step_fn =
      fn state ->
        {params, opt_state} = state
        gradients = Nx.Defn.grad(params, loss)
        {updates, new_state} = update_fn.(gradients, opt_state, params)
        {Polaris.Updates.apply_updates(updates, params), new_state}
      end

    Nx.Defn.jit(step_fn, compiler: EXLA)
  end

  def prepare_optimization(x0, init_fn, compiled_k) do
    init_state = init_fn.(x0)
    {{x0, init_state}, compiled_k}
  end

  def optimize_x0({state, compiled_k}) do
    num_steps = 30000

    {result, _} =
      for _ <- 1..num_steps, reduce: state do
        state ->
          apply(compiled_k, [state])
      end

    result
  end

  def create_x() do
    %{
      "alpha" => Nx.tensor(5.0, type: {:f, 64}),
      "beta" => Nx.tensor(4.0, type: {:f, 64})
    }
  end
end

Execution:

sol = ExlaTest.optimize(Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))
{sol, ExlaTest.error(sol["alpha"], sol["beta"], Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))}

The error function is a defn, and works fine if I use the default backend, but as soon as I put Exla or xTorch I get this horrible message:

** (RuntimeError) cannot invoke Nx function because it relies on two incompatible tensor implementations: Torchx.Backend and Nx.Defn.Expr. This may mean you are passing a tensor to defn/jit as an optional argument or as closure in an anonymous function. For efficiency, it is preferred to always pass tensors as required arguments instead. Alternatively, you could call Nx.backend_copy/1 on the tensor, however this will copy its value and inline it inside the defn expression
    (nx 0.7.0) lib/nx/shared.ex:475: Nx.Shared.pick_struct/2
    (nx 0.7.0) lib/nx.ex:5409: Nx.devectorized_element_wise_bin_op/4

What I’m doing wrong, any clues on how to use Polaris in the right way??

If a remove the Nx.global_default_backend(EXLA.Backend) line it works.

Is that stacktrace the complete error?

The first problem I see, which could result in your error, is that you’re mixing compilers. When you call ExlaTest.error at the last line, you’re using, implicitly, the Nx.Defn.Evaluator as your compiler, while your jit calls are using EXLA. Try setting Nx.Defn.default_options(compiler: EXLA) as a first test.

This is because the Nx functions you call outside of defn will be implicitly jitted by EXLA.Backend.

Although I think it’s odd to see that error with Torchx.

Please try setting the EXLA compiler as a first try, and if that breaks, send the full stacktrace

Hi, Thanks so much for the reply.

I have put the line of code you told me, also have removed the explicit reference to exla in the JIT and the result is the same.

With torchx i cannot compile, but if a remove the jit and set torch as the backend the result is the same
This is the stack trace:

  • (RuntimeError) cannot invoke Nx function because it relies on two incompatible tensor implementations: Nx.Defn.Expr and EXLA.Backend. This may mean you are passing a tensor to defn/jit as an optional argument or as closure in an anonymous function. For efficiency, it is preferred to always pass tensors as required arguments instead. Alternatively, you could call Nx.backend_copy/1 on the tensor, however this will copy its value and inline it inside the defn expression
    (nx 0.7.0) lib/nx/shared.ex:475: Nx.Shared.pick_struct/2
    (nx 0.7.0) lib/nx.ex:5409: Nx.devectorized_element_wise_bin_op/4
    /home/gabrielete/Documentos/exla_test.livemd#cell:rvavdyhix6u6yu2zldvpagj2sna47mr2:5: ExlaTest.“defn:system”/3
    /home/gabrielete/Documentos/exla_test.livemd#cell:rvavdyhix6u6yu2zldvpagj2sna47mr2:9: ExlaTest.“defn:error”/4
    (nx 0.7.0) lib/nx/defn/grad.ex:23: Nx.Defn.Grad.transform/3
    (nx 0.7.0) lib/nx/defn.ex:639: anonymous fn/2 in Nx.Defn.grad/2
    (nx 0.7.0) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
    /home/gabrielete/Documentos/exla_test.livemd#cell:bodcjbjsdscae7shfotthxzw64i7efy5:1: (file)

If you run this code in livebook you can replicate the error:

Polaris error

Mix.install([
  {:vega_lite, "~> 0.1.8"},
  {:kino_vega_lite, "~> 0.1.8"},
  {:nx, "~> 0.7"},
  {:polaris, "~> 0.1"},
  {:exla, "~> 0.7"}
])

Nx.global_default_backend(EXLA.Backend)
Nx.Defn.default_options(compiler: EXLA)

Section

defmodule ExlaTest do
  import Nx.Defn

  defn system(alpha, beta, f) do
    2 * alpha ** beta * f + beta * f - Nx.Constants.pi()
  end

  defn error(alpha, beta, f, e) do
    s = system(alpha, beta, f)
    Nx.sum((s - e) ** 2)
  end

  def optimize(f, e) do
    learning_rate = 1.0e-1
    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: learning_rate)

    loss = fn x ->
      error(
        x["alpha"],
        x["beta"],
        f,
        e
      )
    end

    compiled_k = compile_k(update_fn, loss)

    create_x() |> prepare_optimization(init_fn, compiled_k) |> optimize_x0()
  end

  def compile_k(update_fn, loss) do
    # step_fn = 
    fn state ->
      {params, opt_state} = state
      gradients = Nx.Defn.grad(params, loss)
      {updates, new_state} = update_fn.(gradients, opt_state, params)
      {Polaris.Updates.apply_updates(updates, params), new_state}
    end

    # Nx.Defn.jit(step_fn)
  end

  def prepare_optimization(x0, init_fn, compiled_k) do
    init_state = init_fn.(x0)
    {{x0, init_state}, compiled_k}
  end

  def optimize_x0({state, compiled_k}) do
    num_steps = 30000

    {result, _} =
      for _ <- 1..num_steps, reduce: state do
        state ->
          apply(compiled_k, [state])
      end

    result
  end

  def create_x() do
    %{
      "alpha" => Nx.tensor(5.0, type: {:f, 64}),
      "beta" => Nx.tensor(4.0, type: {:f, 64})
    }
  end
end
sol = ExlaTest.optimize(Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))
{sol, ExlaTest.error(sol["alpha"], sol["beta"], Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))}

Ah, I found the issue!

In your definition of loss, you’re passing the parameters f and e as a closure. This means that the function, which is defined outside of defn, will always contain 2 tensors of whatever backend you’re using.

The proper way to go is to pass f and e down to compile_k and include them in the params container for the grad calculation.

Aside that, I believe all of your functions can be defined through defn if you use defn’s while instead of for

3 Likes

Thanks so much! You have help me a lot, now it looks like pretty obvious what was happening but I really needed to connect the dots. The leap in perform is huge. Thanks a lot!!

Polaris error

Mix.install([
  {:vega_lite, "~> 0.1.8"},
  {:kino_vega_lite, "~> 0.1.8"},
  {:nx, "~> 0.7"},
  {:polaris, "~> 0.1"},
  {:exla, "~> 0.7"}
])

Nx.global_default_backend(EXLA.Backend)
Nx.Defn.default_options(compiler: EXLA)

Section

defmodule ExlaTest do
  import Nx.Defn

  defn system(alpha, beta, f) do
    2 * alpha ** beta * f + beta * f - Nx.Constants.pi()
  end

  defn error(alpha, beta, f, e) do
    s = system(alpha, beta, f)
    Nx.sum((s - e) ** 2)
  end

  def optimize(f, e) do
    learning_rate = 1.0e-1
    {init_fn, update_fn} = Polaris.Optimizers.adam(learning_rate: learning_rate)

    loss = fn {x, f, e} ->
      error(
        x["alpha"],
        x["beta"],
        f,
        e
      )
    end

    k = fn {state, f, e} ->
      {params, opt_state} = state
      {gradients, _f, _e} = Nx.Defn.grad({params, f, e}, loss)
      {updates, new_state} = update_fn.(gradients, opt_state, params)
      {Polaris.Updates.apply_updates(updates, params), new_state}
    end

    create_x() |> prepare_optimization(init_fn) |> optimize_x0(k, f, e)
  end

  def prepare_optimization(x0, init_fn) do
    init_state = init_fn.(x0)
    {x0, init_state}
  end

  defn optimize_x0(state, compiled_k, f, e) do
    st = compiled_k.({state, f, e})

    {_i, {result, _}, _f, _e} =
      while {i = 0, st, f, e}, i < 30000 do
        {i + 1, compiled_k.({st, f, e}), f, e}
      end

    result
  end

  def create_x() do
    %{
      "alpha" => Nx.tensor(5.0, type: {:f, 64}),
      "beta" => Nx.tensor(4.0, type: {:f, 64})
    }
  end
end
sol = ExlaTest.optimize(Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))
{sol, ExlaTest.error(sol["alpha"], sol["beta"], Nx.tensor([1, 2, 3]), Nx.tensor([3, 4, 5]))}
2 Likes