Axon.predict throwing error: expected a %Nx.Tensor{} or a number, got: nil

I’m experimenting with Axon and can’t get predict to succeed. I’ve trained a model, but calling predict raises an error. Here is the code, followed by the error:

trained_params =
  model
  |> Axon.Training.step(:mean_squared_error, Axon.Optimizers.adamw(0.005))
  |> Axon.Training.train(inputs, targets, epochs: 10, compiler: EXLA)

tensor = Nx.tensor([[1, 1, 1, 1, 0.5]])
Axon.predict(model, trained_params, tensor, compiler: EXLA)
** (ArgumentError) expected a %Nx.Tensor{} or a number, got: nil
    (nx 0.1.0-dev) lib/nx.ex:1228: Nx.to_tensor/1
    (nx 0.1.0-dev) lib/nx.ex:1494: Nx.as_type/2
    (axon 0.1.0-dev) lib/axon/compiler.ex:355: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:352: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:316: Axon.Compiler.recur_predict_fun/6
    (axon 0.1.0-dev) lib/axon/compiler.ex:149: anonymous fn/5 in Axon.Compiler.compile_predict/2
    (nx 0.1.0-dev) lib/nx/defn/compiler.ex:311: Nx.Defn.Compiler.runtime_fun/4
    (nx 0.1.0-dev) lib/nx/defn/evaluator.ex:27: Nx.Defn.Evaluator.__jit__/4

I’m trying to develop an intuitive sense of how I can ultimately take parameters passed to a Phoenix endpoint and use them to predict. I’ve looked over the documentation examples, notebooks, and articles, but I couldn’t find an example that uses simple inputs.

I’m thinking the problem has something to do with how I’m structuring the input to predict or perhaps some sort of mismatch with the model/params and the predict input, but I’m not sure quite where to start tracking this down. Any help would be greatly appreciated!

Can you provide your model? I think one of the layers is missing a configuration parameter.

@josevalim here’s the model:

model =
  Axon.input({nil, 5})
  |> Axon.flatten()
  |> Axon.dense(32)
  |> Axon.dense(2, activation: :softmax)

Hey, the first issue is that the output of train is the final training state, not the trained params. When Axon compiles your predict function, it looks up each layer’s params by layer name, and it can’t find them in the training state, which is why you end up with nil. If you inspect the final training state you’ll see a params field; that is what should be passed to predict.
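
For anyone landing here later, a minimal sketch of that fix, reusing the code from the question. It assumes the Axon 0.1.0-dev training API shown in this thread and that the final training state exposes its parameters as `final_state.params`, as described above; `final_state` is just an illustrative name, and `model`, `inputs`, and `targets` are the original poster's bindings. Newer Axon releases may organize training differently, so check the docs for your version.

```elixir
# `train` returns the final training state, not the bare parameters.
final_state =
  model
  |> Axon.Training.step(:mean_squared_error, Axon.Optimizers.adamw(0.005))
  |> Axon.Training.train(inputs, targets, epochs: 10, compiler: EXLA)

# Inspect the state to confirm which fields it actually carries.
IO.inspect(final_state, label: "final training state")

# Assumption: the trained parameters live under the `params` field.
trained_params = final_state.params

# One row of five features, matching Axon.input({nil, 5}).
tensor = Nx.tensor([[1, 1, 1, 1, 0.5]])

Axon.predict(model, trained_params, tensor, compiler: EXLA)
```

The input itself was fine: a `{1, 5}` tensor is a batch of one example for an `Axon.input({nil, 5})` model, so only the second argument to predict needed to change.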

That was it! Thank you @seanmor5!