Axon embedding layer compilation error

Hi, I am trying to upgrade to the 0.2 versions of nx, axon and exla from 0.1.0, and I get the following error when I start training my model:

** (Axon.CompileError) exception found when compiling layer Axon.Layers.embedding/3 named embedding_0:

    ** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {#Nx.Tensor<
       f32[4][40][2]
       
       Nx.Defn.Expr
       parameter a:0   s64[4][40][2]
       b = as_type a   f32[4][40][2]
     >, #Nx.Tensor<
       f32[4][40][1]
       
       Nx.Defn.Expr
       parameter a:1   s64[4][40][1]
       b = as_type a   f32[4][40][1]
     >}
        (nx 0.4.0) lib/nx.ex:1859: Nx.to_tensor/1
        (nx 0.4.0) lib/nx.ex:2185: Nx.as_type/2
        (axon 0.3.0) lib/axon/layers.ex:2043: Axon.Layers."__defn:embedding__"/3
    
The layer was defined at:

    (axon 0.3.0) lib/axon.ex:276: Axon.layer/3
    (grapex 0.1.0) lib/grapex/models/transe.ex:8: Grapex.Model.Transe.model/1
    (grapex 0.1.0) lib/grapex/models/trainers/margin_based.ex:258: Grapex.Model.Trainers.MarginBasedTrainer.train/1
    main.exs:132: (file)
    (elixir 1.14.1) src/elixir_compiler.erl:65: :elixir_compiler.dispatch/4
    (elixir 1.14.1) src/elixir_compiler.erl:50: :elixir_compiler.compile/3

Compiling of the model was initiated at:

    (axon 0.3.0) lib/axon/loop.ex:315: anonymous fn/5 in Axon.Loop.train_step/4
    (nx 0.4.0) lib/nx/defn/compiler.ex:138: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.4.0) lib/exla/defn.ex:368: anonymous fn/2 in EXLA.Defn.compile/7
    (exla 0.4.0) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 3.17) timer.erl:166: :timer.tc/1
    (exla 0.4.0) lib/exla/defn.ex:366: EXLA.Defn.compile/7

This is how I create the model:

    entity_embeddings_ = Axon.input({nil, batch_size, 2}, "entity-embedding")
                         |> Axon.embedding(Grapex.Meager.n_entities, hidden_size)

    relation_embeddings_ = Axon.input({nil, batch_size, 1}, "relation-embedding")
                         |> Axon.embedding(Grapex.Meager.n_relations, hidden_size)

    Axon.concatenate([entity_embeddings_, relation_embeddings_], axis: 2, name: "transe")

What am I doing wrong?

The problem was solved after I changed the format of the data I pass to the model from

    {{batch.entities, batch.relations}, batch.true_labels}

to

    {%{"entities" => batch.entities, "relations" => batch.relations}, batch.true_labels}
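Some context on why this works: an Axon model with more than one `Axon.input` takes its inputs as a map keyed by input name. Judging by the error, the tuple was not split across the two inputs; the whole `{entities, relations}` tuple reached `Axon.Layers.embedding/3`. The map keys have to equal the names given to `Axon.input`. The model above names its inputs "entity-embedding" and "relation-embedding", so with that exact definition the keys would have to be those names, or the inputs renamed to "entities" and "relations". Below is a minimal sketch of a conversion step. It assumes each batch exposes `entities`, `relations` and `true_labels` fields as in the post. The module `BatchInputs` and its functions are hypothetical and are not part of Axon or Grapex.

```elixir
defmodule BatchInputs do
  # Hypothetical helper (not part of Axon or Grapex): turns one batch with
  # :entities, :relations and :true_labels fields into an {inputs, labels}
  # pair, where inputs is a map keyed by Axon input name.

  # Assumption: these keys equal the names passed to Axon.input in the model.
  @entity_input "entities"
  @relation_input "relations"

  # Matches plain maps as well as structs, since structs are maps.
  def to_example(%{entities: entities, relations: relations, true_labels: labels}) do
    {%{@entity_input => entities, @relation_input => relations}, labels}
  end

  # Lazily converts an enumerable of batches into {inputs, labels} pairs.
  def to_examples(batches), do: Stream.map(batches, &to_example/1)
end
```

The resulting stream could then be passed as the training data to `Axon.Loop.run`. Keeping the input names in one place (the module attributes here) makes a mismatch between the data and the `Axon.input` names easier to spot.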