Troubleshooting Axon: expected input to be a tensor or a map corresponding to correct input names

Following this tutorial for building a semantic search tool, I'm running into an issue that I'm having trouble tracing. The problematic module is posted below. When I run :ok = Pento.Model.load() in iex -S mix phx.server, I get:

iex(2)> :ok = Pento.Model.load()                                                                                                           
** (ArgumentError) invalid input given to model, expected input expected input to be a tensor or a map corresponding to correct input names
    (axon 0.3.1) lib/axon/compiler.ex:558: Axon.Compiler.get_input/3                                                                       
    (axon 0.3.1) lib/axon/compiler.ex:285: anonymous fn/7 in Axon.Compiler.recur_model_funs/5
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:781: anonymous fn/6 in Axon.Compiler.layer_init_fun/11
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (axon 0.3.1) lib/axon/compiler.ex:778: Axon.Compiler.layer_init_fun/11
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:370: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.4.1) lib/nx/container.ex:79: Nx.Container.Tuple.traverse/3
    (axon 0.3.1) lib/axon/compiler.ex:367: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:781: anonymous fn/6 in Axon.Compiler.layer_init_fun/11
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (axon 0.3.1) lib/axon/compiler.ex:778: Axon.Compiler.layer_init_fun/11
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:370: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.4.1) lib/nx/container.ex:79: Nx.Container.Tuple.traverse/3
    (axon 0.3.1) lib/axon/compiler.ex:367: anonymous fn/6 in Axon.Compiler.recur_model_funs/5

The module in question

defmodule Pento.Model do
  @max_length 120

  def load() do
    {model, params} =
      AxonOnnx.import("priv/models/model.onnx", batch: 1, sequence: @max_length)

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-uncased")
    {_, predict_fn} = Axon.compile(model, compiler: EXLA)

    predict_fn =
      EXLA.compile(
        fn params, inps ->
          {_, pooled} = predict_fn.(params, inps)
          Nx.squeeze(pooled)
        end,
        [params, inputs()]
      )

    :persistent_term.put({__MODULE__, :model}, {predict_fn, params})
    # load the tokenizer as well
    :persistent_term.put({__MODULE__, :tokenizer}, tokenizer)

    :ok
  end

  def max_length(), do: @max_length

  defp inputs() do
    %{
      "input_ids" => Nx.template({1, @max_length}, {:s, 64}),
      "token_type_ids" => Nx.template({1, @max_length}, {:s, 64}),
      "attention_mask" => Nx.template({1, @max_length}, {:s, 64})
    }
  end
end

When the article was written, the compile API was different; it now expects you to pass the input templates. So it should be:

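# The first element returned is the compiled init function, so keep using the params from AxonOnnx.import
{_init_fn, predict_fn} = Axon.compile(model, inputs(), params, compiler: EXLA)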


Also, you can drop the separate EXLA.compile step, since Axon.compile now does that for you. I plan to update the article with everything we've added in recent months!


Thank you, this fixed the problem.