Following this tutorial for building a semantic search tool, I'm running into an issue that I'm having trouble tracing. The problematic module is posted below. When I start the app with `iex -S mix phx.server` and run `:ok = Pento.Model.load()`, I get:
```
iex(2)> :ok = Pento.Model.load()
** (ArgumentError) invalid input given to model, expected input expected input to be a tensor or a map corresponding to correct input names
    (axon 0.3.1) lib/axon/compiler.ex:558: Axon.Compiler.get_input/3
    (axon 0.3.1) lib/axon/compiler.ex:285: anonymous fn/7 in Axon.Compiler.recur_model_funs/5
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:781: anonymous fn/6 in Axon.Compiler.layer_init_fun/11
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (axon 0.3.1) lib/axon/compiler.ex:778: Axon.Compiler.layer_init_fun/11
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:370: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.4.1) lib/nx/container.ex:79: Nx.Container.Tuple.traverse/3
    (axon 0.3.1) lib/axon/compiler.ex:367: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:781: anonymous fn/6 in Axon.Compiler.layer_init_fun/11
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (axon 0.3.1) lib/axon/compiler.ex:778: Axon.Compiler.layer_init_fun/11
    (axon 0.3.1) lib/axon/compiler.ex:220: Axon.Compiler.call_init_cache/7
    (axon 0.3.1) lib/axon/compiler.ex:370: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
    (elixir 1.14.2) lib/enum.ex:1780: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.4.1) lib/nx/container.ex:79: Nx.Container.Tuple.traverse/3
    (axon 0.3.1) lib/axon/compiler.ex:367: anonymous fn/6 in Axon.Compiler.recur_model_funs/5
```
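The module in question:

```elixir
defmodule Pento.Model do
  @max_length 120

  def load() do
    # Import the ONNX model with a fixed batch size and sequence length
    {model, params} =
      AxonOnnx.import("priv/models/model.onnx", batch: 1, sequence: @max_length)

    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained("bert-base-uncased")

    # Compile the Axon model and keep only the predict function
    {_, predict_fn} = Axon.compile(model, compiler: EXLA)

    # Wrap predict_fn so it returns only the squeezed pooled output,
    # then compile the wrapper with EXLA for the given params/input templates
    predict_fn =
      EXLA.compile(
        fn params, inps ->
          {_, pooled} = predict_fn.(params, inps)
          Nx.squeeze(pooled)
        end,
        [params, inputs()]
      )

    # Cache the compiled function and params for later calls
    :persistent_term.put({__MODULE__, :model}, {predict_fn, params})
    # load the tokenizer as well
    :persistent_term.put({__MODULE__, :tokenizer}, tokenizer)
    :ok
  end

  def max_length(), do: @max_length

  # Input templates of shape {1, @max_length} with signed 64-bit integers,
  # used only to trace/compile the wrapped predict function
  defp inputs() do
    %{
      "input_ids" => Nx.template({1, @max_length}, {:s, 64}),
      "token_type_ids" => Nx.template({1, @max_length}, {:s, 64}),
      "attention_mask" => Nx.template({1, @max_length}, {:s, 64})
    }
  end
end
```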