Hi, I am trying to upgrade nx, axon and exla from 0.1.0 to the 0.2 versions, and I get the following error when I start training my model:
```
** (Axon.CompileError) exception found when compiling layer Axon.Layers.embedding/3 named embedding_0:

    ** (ArgumentError) expected a %Nx.Tensor{} or a number, got: {#Nx.Tensor<
      f32[4][40][2]
      Nx.Defn.Expr
      parameter a:0 s64[4][40][2]
      b = as_type a f32[4][40][2]
    >, #Nx.Tensor<
      f32[4][40][1]
      Nx.Defn.Expr
      parameter a:1 s64[4][40][1]
      b = as_type a f32[4][40][1]
    >}
        (nx 0.4.0) lib/nx.ex:1859: Nx.to_tensor/1
        (nx 0.4.0) lib/nx.ex:2185: Nx.as_type/2
        (axon 0.3.0) lib/axon/layers.ex:2043: Axon.Layers."__defn:embedding__"/3

The layer was defined at:

    (axon 0.3.0) lib/axon.ex:276: Axon.layer/3
    (grapex 0.1.0) lib/grapex/models/transe.ex:8: Grapex.Model.Transe.model/1
    (grapex 0.1.0) lib/grapex/models/trainers/margin_based.ex:258: Grapex.Model.Trainers.MarginBasedTrainer.train/1
    main.exs:132: (file)
    (elixir 1.14.1) src/elixir_compiler.erl:65: :elixir_compiler.dispatch/4
    (elixir 1.14.1) src/elixir_compiler.erl:50: :elixir_compiler.compile/3

Compiling of the model was initiated at:

    (axon 0.3.0) lib/axon/loop.ex:315: anonymous fn/5 in Axon.Loop.train_step/4
    (nx 0.4.0) lib/nx/defn/compiler.ex:138: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.4.0) lib/exla/defn.ex:368: anonymous fn/2 in EXLA.Defn.compile/7
    (exla 0.4.0) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 3.17) timer.erl:166: :timer.tc/1
    (exla 0.4.0) lib/exla/defn.ex:366: EXLA.Defn.compile/7
```
This is how I create the model:
```elixir
entity_embeddings_ =
  Axon.input({nil, batch_size, 2}, "entity-embedding")
  |> Axon.embedding(Grapex.Meager.n_entities, hidden_size)

relation_embeddings_ =
  Axon.input({nil, batch_size, 1}, "relation-embedding")
  |> Axon.embedding(Grapex.Meager.n_relations, hidden_size)

Axon.concatenate([entity_embeddings_, relation_embeddings_], axis: 2, name: "transe")
```
What am I doing wrong?