Hey, I’m trying to make a model like this:
```elixir
Axon.input("features", shape: {nil, 784})
|> Axon.reshape({:auto, 28, 28, 1})
|> Axon.max_pool(kernel_size: 2)
|> Axon.conv(32, kernel_size: 3)
|> Axon.max_pool(kernel_size: 2)
|> Axon.conv(3, kernel_size: 3)
|> Axon.flatten()
|> Axon.dense(128)
|> Axon.relu()
|> Axon.dense(10)
|> Axon.softmax(name: "labels")
```
However, when I try to train it, the following error appears:
```
** (ArgumentError) indices must be an integer tensor, got type: {:f, 32}
    (nx 0.7.2) lib/nx.ex:7839: Nx.indexed_op/5
    (torchx 0.7.2) lib/torchx/backend.ex:1580: Torchx.Backend.window_scatter_function/7
    (nx 0.7.2) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.7.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    (elixir 1.16.2) lib/enum.ex:1826: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.7.2) lib/nx/defn/evaluator.ex:419: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.7.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    iex:27: (file)
```
This only happens when the two `max_pool` layers are in the model; without them, training works correctly. How do I fix this?