Axon.predict/3 throwing error: ** (Axon.CompileError) exception found when compiling layer Axon.Layers.dense/4 named dense_0:

Whenever I try to run predict with my module, Axon throws the following error:

** (Axon.CompileError) exception found when compiling layer Axon.Layers.dense/4 named dense_0:
     
         ** (ArgumentError) Axon.Layers.dense: expected input shape to have at least rank 2, got rank 1
             (axon 0.3.0) lib/axon/shared.ex:114: anonymous fn/3 in Axon.Shared."__defn:assert_min_rank!__"/4
             (nx 0.4.0) lib/nx/defn/compiler.ex:179: Nx.Defn.Compiler.__remote__/4
             (axon 0.3.0) lib/axon/layers.ex:125: Axon.Layers."__defn:dense_impl__"/4
         
     The layer was defined at:
     
         (axon 0.3.0) lib/axon.ex:276: Axon.layer/3
         (axon 0.3.0) lib/axon.ex:655: Axon.dense/3
         (axon_onnx 0.3.0) lib/axon_onnx/shared.ex:220: AxonOnnx.Shared.dense_with_bias/5
         (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:757: AxonOnnx.Deserialize.recur_nodes/2
         (elixir 1.14.1) lib/enum.ex:2468: Enum."-reduce/3-lists^foldl/2-0-"/3
         (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:44: AxonOnnx.Deserialize.graph_to_axon/2
         (axon_onnx 0.3.0) lib/axon_onnx/deserialize.ex:27: AxonOnnx.Deserialize.to_axon/2
         test/app/models/mnist_test.exs:101: App.Models.MNISTTest."test predict/3 correctly predicts"/1
         (ex_unit 1.14.1) lib/ex_unit/runner.ex:512: ExUnit.Runner.exec_test/1
         (stdlib 3.17.2.1) timer.erl:166: :timer.tc/1
         (ex_unit 1.14.1) lib/ex_unit/runner.ex:463: anonymous fn/4 in ExUnit.Runner.spawn_test_monitor/4
     
     Compiling of the model was initiated at:
     
     code: predict.(params, MNIST.transform_image(image))
     stacktrace:
       (nx 0.4.0) lib/nx/defn/compiler.ex:138: Nx.Defn.Compiler.runtime_fun/3
       (exla 0.4.0) lib/exla/defn.ex:368: anonymous fn/2 in EXLA.Defn.compile/7
       (exla 0.4.0) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
       (stdlib 3.17.2.1) timer.erl:166: :timer.tc/1
       (exla 0.4.0) lib/exla/defn.ex:366: EXLA.Defn.compile/7
       (exla 0.4.0) lib/exla/defn.ex:262: EXLA.Defn.__compile__/4
       (exla 0.4.0) lib/exla/defn.ex:248: EXLA.Defn.__jit__/5
       (nx 0.4.0) lib/nx/defn.ex:442: Nx.Defn.do_jit_apply/3
       test/app/models/mnist_test.exs:124: (test)

Despite spending a lot of time debugging, I can't figure out why.

Here’s my model:

defmodule App.Models.MNIST do
  def build(input_shape) do
    Axon.input("input", shape: input_shape)
    |> Axon.dense(128, activation: :relu)
    |> Axon.dropout()
    |> Axon.dense(10, activation: :softmax)
  end

  def train(model, train_images, train_labels, opts) do
    model
    |> Axon.Loop.trainer(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005), :identity,
      log: Application.fetch_env!(:app, :axon_log)
    )
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(Stream.zip(train_images, train_labels), %{}, opts)
  end

  def test(model, model_state, test_images, test_labels, opts) do
    model
    |> Axon.Loop.evaluator()
    |> Axon.Loop.metric(:accuracy, "Accuracy")
    |> Axon.Loop.run(Stream.zip(test_images, test_labels), model_state, opts)
  end
end
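For reference, here is a minimal sketch (the flattened 784-feature input shape is my assumption) of what the rank assertion in the error is checking: `Axon.Layers.dense/4` expects at least a `{batch, features}` tensor, so a flat image tensor is rank 1 until a leading batch axis is added.

```elixir
# Illustrative only: an MNIST image flattened to 784 features (assumed shape).
image = Nx.iota({784}, type: :f32)

# Rank 1 is what the error reports; Nx.new_axis/2 adds a leading
# batch dimension, giving the {1, 784} shape that dense accepts.
Nx.rank(image)                  # 1
Nx.rank(Nx.new_axis(image, 0))  # 2
```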

I’ve been able to successfully train and test my model. Although I’m saving my model as an ONNX file here, the same error occurs when I train and predict directly.

This is likely a bug in AxonOnnx rather than Axon. Do you mind opening an issue there?

So you are able to train and then save to ONNX and after you convert back to Axon later it fails?

I’m able to save to ONNX. I’m still getting the error when I use a regular Axon model directly, without saving to ONNX and importing, so I don’t think it’s an AxonOnnx bug.

@seanmor5, you can reproduce this bug by using the MNIST example in the Axon examples with the latest Axon, Nx, and EXLA releases, and then defining a predict function.
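A sketch of the kind of predict call that triggers it (the input shape and the `Axon.build/1` usage are assumptions on my part, not the exact example code):

```elixir
# Assumed reproduction: build the model, init params, then predict.
model = App.Models.MNIST.build({nil, 784})
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 784}, :f32), %{})

# A rank-1 tensor raises the Axon.CompileError above:
# predict_fn.(params, Nx.iota({784}, type: :f32))

# A rank-2 {1, 784} tensor compiles and runs:
predict_fn.(params, Nx.iota({1, 784}, type: :f32))
```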

For now, saving a model loaded from an ONNX file as an Axon model file is problematic. Sean mentioned the challenge in his ElixirConf presentation. Just load and use the model from the ONNX file.

@meanderingstream, I am able to reproduce this problem without using AxonOnnx at all. This is not a bug in AxonOnnx.

@seanmor5 Should I file a bug report on the GitHub repo?

Yes please