NX - TorchX - Unable to get tensor param in NIF

Hi!

I am struggling a bit with Nx and Torchx and got a bit stuck!

I have created a model and state and I am trying to run this function:

The function `load!` reads a model and state that were saved to disk with `:erlang.term_to_binary/1`, and when I inspect the model and state they look OK.

I call the function with the path to a PNG file.

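```elixir
def predict(path) do
  # Load the PNG as a single-channel (grayscale) image and scale it to 28x28.
  mat = Evision.imread(path, flags: Evision.Constant.cv_IMREAD_GRAYSCALE())
  mat = Evision.resize(mat, {28, 28})

  # Reshape to {1, 28, 28}, stack into a batch of one ({1, 1, 28, 28})
  # and move the input to Torchx.
  data =
    Evision.Mat.to_nx(mat)
    |> Nx.reshape({1, 28, 28})
    |> List.wrap()
    |> Nx.stack()
    |> Nx.backend_transfer(Torchx.Backend)

  # load!/0 restores {model, state} from a file written with term_to_binary.
  {model, state} = load!()

  # The index of the highest-scoring class is the predicted digit.
  model
  |> Axon.predict(state, data, debug: true)
  |> Nx.argmax()
  |> Nx.to_number()
end
```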

Calling `predict` currently fails with this error:

```
** (RuntimeError) Torchx: Unable to get tensor param in NIF
    (torchx 0.9.2) lib/torchx.ex:447: Torchx.unwrap!/1
    (torchx 0.9.2) lib/torchx.ex:450: Torchx.unwrap_tensor!/2
    (torchx 0.9.2) lib/torchx/backend.ex:992: Torchx.Backend.dot/7
    (nx 0.9.2) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.9.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    (elixir 1.17.1) lib/enum.ex:1829: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.9.2) lib/nx/defn/evaluator.ex:419: Nx.Defn.Evaluator.eval_apply/4
    #cell:kkr4tfn454cgrk4x:3: (file)
```

I have tried to debug this and searched the internet for similar errors with no luck, so I hoped to find some help here! :smiley:

The code is based on this: Digits

I only changed it a bit to work in Livebook.

I have a production use for the techniques and just need to get the POC working before moving on :slight_smile:

Apparently the problem was how I saved and restored the model and state.
When I ran the code directly on the in-memory variables, without saving and loading, it worked.
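A likely explanation, which the thread itself does not spell out: a Torchx tensor does not keep its data in the Elixir term. Its backend struct holds a reference to memory owned by the NIF, so `:erlang.term_to_binary/1` saves only that reference, and after `:erlang.binary_to_term/1` it most likely no longer resolves to a live native tensor; the first native call on it (`Torchx.Backend.dot/7` in the trace above) then fails. One workaround, different from the fix the thread settles on, is to copy the tensors to `Nx.BinaryBackend` (which stores the raw bytes in an Elixir binary) before calling `term_to_binary`. A minimal sketch, assuming only Nx is installed and using a hypothetical `params` map in place of the real trained state:

```elixir
# When running as a script; in Livebook, add :nx to the setup cell instead.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical stand-in for trained parameters. In the original Livebook
# these tensors would live on Torchx.Backend.
params = %{
  "dense_0" => %{
    "kernel" => Nx.iota({2, 3}, type: :f32),
    "bias" => Nx.tensor([0.5, -1.0, 2.0])
  }
}

# Copy every tensor in the container to Nx.BinaryBackend, so the bytes live
# inside the Elixir term instead of behind a NIF reference.
portable = Nx.backend_copy(params, Nx.BinaryBackend)

# Round-trip through the Erlang external term format, as the original
# save/load code does.
restored =
  portable
  |> :erlang.term_to_binary()
  |> :erlang.binary_to_term()

IO.inspect(Nx.to_flat_list(restored["dense_0"]["bias"]))
```

The restored tensors are on `Nx.BinaryBackend`; `Nx.backend_transfer(restored, Torchx.Backend)` would move them back before prediction.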

Now I guess I have to figure out a way to serialize the model and the state that actually works.

Axon used to have a serialize function, but it is deprecated now…
What is the replacement? :slight_smile:

I figured it out by reading the deprecation message on `Axon.serialize`.

You should keep the model as code and store only the trained state (the parameters) using `Nx.serialize` and `Nx.deserialize`.
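A minimal sketch of that approach, not the Digits project's code: `StateStore`, `save!/2`, `load!/1` and the file name are hypothetical, and a plain map of tensors stands in for the trained state. Recent Axon versions wrap the state in an `Axon.ModelState` struct, which is assumed here to be traversable by `Nx.serialize` as an Nx container. Because `Nx.serialize` writes out the bytes of each tensor rather than the term itself, it should not depend on NIF references the way `term_to_binary` on Torchx tensors does.

```elixir
# When running as a script; in Livebook, add :nx to the setup cell instead.
Mix.install([{:nx, "~> 0.9"}])

defmodule StateStore do
  # Hypothetical helper module: only the trained parameters go to disk;
  # the Axon model itself stays in code and is rebuilt on every run.

  # Nx.serialize/1 returns iodata containing the tensor bytes,
  # which File.write!/2 accepts directly.
  def save!(state, path) do
    File.write!(path, Nx.serialize(state))
  end

  # Nx.deserialize/1 rebuilds the same container (maps, tuples, tensors)
  # with tensors allocated on Nx's default backend.
  def load!(path) do
    path
    |> File.read!()
    |> Nx.deserialize()
  end
end

# Hypothetical parameters standing in for the trained state.
params = %{
  "dense_0" => %{
    "kernel" => Nx.iota({2, 3}, type: :f32),
    "bias" => Nx.tensor([0.5, -1.0, 2.0])
  }
}

path = Path.join(System.tmp_dir!(), "digits_state.nx")
:ok = StateStore.save!(params, path)
restored = StateStore.load!(path)
```

In `predict/1`, the `{model, state} = load!()` line would then become a model built from code (for example a hypothetical `build_model/0` containing the same Axon layers used for training) plus `state = StateStore.load!(path)`.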
