Help with "Semantic Search with Phoenix, Axon, and Elastic": getting an `Nx.Defn.compile/3` RuntimeError


Hello, everyone. I will apologise in advance as I am a complete beginner when it comes to the Nx ecosystem.

I was following Sean Moriarity’s article Semantic Search with Phoenix, Axon, and Elastic pretty much step by step until I reached the point, about halfway through, where he defines the Wine.Model module:

defmodule Wine.Model do
  @max_sequence_length 120

  def load() do
    {model, params} =
      AxonOnnx.import("priv/models/model.onnx", batch: 1, sequence: max_sequence_length())

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-uncased")

    {_, predict_fn} = Axon.compile(model, compiler: EXLA)

    predict_fn =
      EXLA.compile(
        fn params, inputs ->
          {_, pooled} = predict_fn.(params, inputs)
          Nx.squeeze(pooled)
        end,
        [params, inputs()]
      )

    :persistent_term.put({__MODULE__, :model}, {predict_fn, params})
    # Load the tokenizer as well
    :persistent_term.put({__MODULE__, :tokenizer}, tokenizer)

    :ok
  end

  def max_sequence_length(), do: @max_sequence_length

  defp inputs() do
    %{
      "input_ids" => Nx.template({1, 120}, {:s, 64}),
      "token_type_ids" => Nx.template({1, 120}, {:s, 64}),
      "attention_mask" => Nx.template({1, 120}, {:s, 64})
    }
  end
end

The article is around 6 months old, and with how fast the Nx ecosystem has been evolving, I figured I would have to update the code if any of the APIs had changed. Luckily only one seems to have changed: Axon.compile/4. This is my code:

defmodule Wine.Model do
  @max_sequence_length 120

  def load() do
    {model, params} =
      AxonOnnx.import("priv/models/model.onnx", batch: 1, sequence: max_sequence_length())

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-uncased")

    {_, predict_fn} = Axon.compile(model, inputs(), %{}, compiler: EXLA)

    predict_fn =
      EXLA.compile(
        fn params, inputs ->
          {_, pooled} = predict_fn.(params, inputs)
          Nx.squeeze(pooled)
        end,
        [params, inputs()]
      )

    :persistent_term.put({__MODULE__, :model}, {predict_fn, params})
    # Load the tokenizer as well
    :persistent_term.put({__MODULE__, :tokenizer}, tokenizer)

    :ok
  end

  def max_sequence_length(), do: @max_sequence_length

  defp inputs() do
    %{
      "input_ids" => Nx.template({1, 120}, {:s, 64}),
      "token_type_ids" => Nx.template({1, 120}, {:s, 64}),
      "attention_mask" => Nx.template({1, 120}, {:s, 64})
    }
  end
end

This got me past the earlier compilation errors from Axon.compile/4. However, I then ran into this error:

Erlang/OTP 25 [erts-13.2] [source] [64-bit] [smp:16:16] [ds:16:16:10] [async-threads:1] [jit:ns]

Compiling 1 file (.ex)
[info] TfrtCpuClient created.
[notice] Application wine exited: exited in: Wine.Application.start(:normal, [])
    ** (EXIT) an exception was raised:
        ** (RuntimeError) cannot invoke compiled function when there is a JIT compilation happening
            (nx 0.5.2) lib/nx/defn.ex:309: anonymous fn/4 in Nx.Defn.compile/3
            (wine 0.1.0) lib/wine/model.ex:15: anonymous fn/3 in Wine.Model.load/0
            (nx 0.5.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
            (exla 0.5.2) lib/exla/defn.ex:385: anonymous fn/4 in EXLA.Defn.compile/8
            (exla 0.5.2) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
            (stdlib 4.3) timer.erl:235: :timer.tc/1
            (exla 0.5.2) lib/exla/defn.ex:383: EXLA.Defn.compile/8
            (exla 0.5.2) lib/exla/defn.ex:270: EXLA.Defn.__compile__/4
            (nx 0.5.2) lib/nx/defn.ex:305: Nx.Defn.compile/3
            (wine 0.1.0) lib/wine/model.ex:13: Wine.Model.load/0
            (wine 0.1.0) lib/wine/application.ex:10: Wine.Application.start/2
            (kernel 8.5.4) application_master.erl:293: :application_master.start_it_old/4
[notice] Application rustler exited: :stopped
[notice] Application tokenizers exited: :stopped
[notice] Application rustler_precompiled exited: :stopped
[notice] Application exla exited: :stopped
[notice] Application axon_onnx exited: :stopped
[notice] Application protox exited: :stopped
[notice] Application decimal exited: :stopped
[notice] Application axon exited: :stopped
[notice] Application nx exited: :stopped
[notice] Application complex exited: :stopped
[notice] Application httpoison exited: :stopped
[notice] Application hackney exited: :stopped
[notice] Application metrics exited: :stopped
[notice] Application ssl_verify_fun exited: :stopped
[notice] Application parse_trans exited: :stopped
[notice] Application syntax_tools exited: :stopped
[notice] Application certifi exited: :stopped
[notice] Application mimerl exited: :stopped
[notice] Application idna exited: :stopped
[notice] Application unicode_util_compat exited: :stopped
[notice] Application plug_cowboy exited: :stopped
[notice] Application cowboy_telemetry exited: :stopped
[notice] Application cowboy exited: :stopped
[notice] Application ranch exited: :stopped
[notice] Application cowlib exited: :stopped
[notice] Application gettext exited: :stopped
[notice] Application expo exited: :stopped
[notice] Application telemetry_poller exited: :stopped
[notice] Application finch exited: :stopped
[notice] Application nimble_options exited: :stopped
[notice] Application nimble_pool exited: :stopped
[notice] Application mint exited: :stopped
[notice] Application hpax exited: :stopped
[notice] Application swoosh exited: :stopped
[notice] Application jason exited: :stopped
[notice] Application xmerl exited: :stopped
[notice] Application tailwind exited: :stopped
[notice] Application esbuild exited: :stopped
[notice] Application phoenix_live_dashboard exited: :stopped
[notice] Application telemetry_metrics exited: :stopped
[notice] Application floki exited: :stopped
[notice] Application phoenix_live_view exited: :stopped
[notice] Application phoenix_live_reload exited: :stopped
[notice] Application file_system exited: :stopped
[notice] Application phoenix_html exited: :stopped
[notice] Application phoenix exited: :stopped
[notice] Application castore exited: :stopped
[notice] Application websock_adapter exited: :stopped
[notice] Application websock exited: :stopped
[notice] Application phoenix_template exited: :stopped
[notice] Application phoenix_pubsub exited: :stopped
[notice] Application plug exited: :stopped
[notice] Application telemetry exited: :stopped
[notice] Application plug_crypto exited: :stopped
[notice] Application mime exited: :stopped
[notice] Application eex exited: :stopped
[notice] Application runtime_tools exited: :stopped
** (Mix) Could not start application wine: exited in: Wine.Application.start(:normal, [])
    ** (EXIT) an exception was raised:
        ** (RuntimeError) cannot invoke compiled function when there is a JIT compilation happening
            (nx 0.5.2) lib/nx/defn.ex:309: anonymous fn/4 in Nx.Defn.compile/3
            (wine 0.1.0) lib/wine/model.ex:15: anonymous fn/3 in Wine.Model.load/0
            (nx 0.5.2) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
            (exla 0.5.2) lib/exla/defn.ex:385: anonymous fn/4 in EXLA.Defn.compile/8
            (exla 0.5.2) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
            (stdlib 4.3) timer.erl:235: :timer.tc/1
            (exla 0.5.2) lib/exla/defn.ex:383: EXLA.Defn.compile/8
            (exla 0.5.2) lib/exla/defn.ex:270: EXLA.Defn.__compile__/4
            (nx 0.5.2) lib/nx/defn.ex:305: Nx.Defn.compile/3
            (wine 0.1.0) lib/wine/model.ex:13: Wine.Model.load/0
            (wine 0.1.0) lib/wine/application.ex:10: Wine.Application.start/2
            (kernel 8.5.4) application_master.erl:293: :application_master.start_it_old/4

Line 15 is this: {_, pooled} = predict_fn.(params, inputs) inside the EXLA.compile call.

At first I thought the note in the EXLA documentation might have something to do with it:

Note that the EXLA.Backend is asynchronous: operations on its tensors may return immediately, before the tensor data is available. The backend will then block only when trying to read the data or when passing it to another operation.

I tried tinkering with debug and cache options, but to no avail. My second thought was to run the load/0 function under a Task Supervisor from a GenServer, retrying until the model was compiled, but I can’t seem to refactor correctly as it keeps compiling the model over and over and never the predict function.

I’m not sure how to proceed. Any help is much appreciated :pray:

Try using Axon.build instead, and let EXLA be the one that compiles your extended predict_fn.

With Axon.build you’ll also need to use the returned init_fn to obtain the params for the model.
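
For readers following along, the suggestion above might look roughly like the sketch below. It is not the poster's final code, which the thread does not show. It assumes Axon 0.5, where Axon.build/2 returns an {init_fn, predict_fn} pair of uncompiled functions, and it assumes the params returned by AxonOnnx.import can be reused directly, so the init function is ignored here. The tokenizer setup is left out for brevity.

```elixir
defmodule Wine.Model do
  @max_sequence_length 120

  def load() do
    {model, params} =
      AxonOnnx.import("priv/models/model.onnx", batch: 1, sequence: @max_sequence_length)

    # Axon.build/2 returns plain, not-yet-compiled init and predict functions.
    # Assumption: the params from AxonOnnx.import are used as-is, so the
    # init function is discarded.
    {_init_fn, predict_fn} = Axon.build(model)

    # EXLA compiles the wrapper and the model's forward pass together,
    # so no already-compiled function is invoked during compilation.
    compiled_fn =
      EXLA.compile(
        fn params, inputs ->
          {_, pooled} = predict_fn.(params, inputs)
          Nx.squeeze(pooled)
        end,
        [params, inputs()]
      )

    :persistent_term.put({__MODULE__, :model}, {compiled_fn, params})
    :ok
  end

  # Input templates matching the batch and sequence sizes used at import time.
  defp inputs() do
    %{
      "input_ids" => Nx.template({1, @max_sequence_length}, {:s, 64}),
      "token_type_ids" => Nx.template({1, @max_sequence_length}, {:s, 64}),
      "attention_mask" => Nx.template({1, @max_sequence_length}, {:s, 64})
    }
  end
end
```

The only change from the failing version is replacing the Axon.compile call with Axon.build, so that the single compilation step happens in EXLA.compile.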

Gotcha
I’ll try it out in the morning and post the results.
Thanks a mil.

Yup, everything works and I can validate the results. The initial model compilation is even faster. I’m not sure I understand why though :sweat_smile: . I’ll be spending some time to understand this.

Thank you so much :pray:

Awesome! If you look at the definition of Axon.compile, you’ll see that what you’re doing is basically inserting a few lines of code between the inner call to Axon.build (which yields non-compiled functions) and the actual Nx.Defn.compile call.

Before, you were trying to feed an already-compiled function to the compiler yet again. Nx refuses to invoke a compiled function while another compilation is in progress, which is why compiling your outer function raised the error.
