Axon error when trying to use max_pool: "indices must be an integer tensor, got type: {:f, 32}"

Hey, I’m trying to make a model like this:

```elixir
Axon.input("features", shape: {nil, 784})
|> Axon.reshape({:auto, 28, 28, 1})
|> Axon.max_pool(kernel_size: 2)
|> Axon.conv(32, kernel_size: 3)
|> Axon.max_pool(kernel_size: 2)
|> Axon.conv(3, kernel_size: 3)
|> Axon.flatten()
|> Axon.dense(128)
|> Axon.relu()
|> Axon.dense(10)
|> Axon.softmax(name: "labels")
```

However, when I try to train it, the following error appears:

```
** (ArgumentError) indices must be an integer tensor, got type: {:f, 32}
    (nx 0.7.2) lib/nx.ex:7839: Nx.indexed_op/5
    (torchx 0.7.2) lib/torchx/backend.ex:1580: Torchx.Backend.window_scatter_function/7
    (nx 0.7.2) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.7.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    (elixir 1.16.2) lib/enum.ex:1826: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    (nx 0.7.2) lib/nx/defn/evaluator.ex:419: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.7.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    iex:27: (file)
```

This only happens when the two `max_pool` layers are in the model; without them, training works correctly. How do I fix this?
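To narrow it down, here is a minimal sketch (assuming Nx ~> 0.7.2) that takes the gradient of a single 2x2, stride-2 max-pool window outside of Axon. Differentiating `Nx.window_max` goes through `Nx.window_scatter_max`, which is presumably what shows up as `Torchx.Backend.window_scatter_function` in the trace. The module name `PoolGradCheck` is hypothetical, and the window and stride values are an assumption about what `Axon.max_pool(kernel_size: 2)` uses on a channels-last input. It runs on the default `Nx.BinaryBackend`; the last comment shows how to repeat it under Torchx.

```elixir
# Sketch assuming Nx ~> 0.7.2; runs on the default Nx.BinaryBackend.
Mix.install([{:nx, "~> 0.7.2"}])

defmodule PoolGradCheck do
  # Hypothetical helper module, used only for this check.
  import Nx.Defn

  # Gradient of a summed 2x2, stride-2 max pool over an NHWC tensor.
  # Differentiating Nx.window_max goes through Nx.window_scatter_max.
  defn pool_grad(x) do
    grad(x, fn t ->
      t
      |> Nx.window_max({1, 2, 2, 1}, strides: [1, 2, 2, 1])
      |> Nx.sum()
    end)
  end
end

# 1x4x4x1 input, channels last like the reshape in the model.
x = Nx.iota({1, 4, 4, 1}, type: :f32)

# Each 2x2 window routes a gradient of 1.0 to its maximum element.
grad = PoolGradCheck.pool_grad(x)
IO.inspect(Nx.squeeze(grad))

# To compare backends, add {:torchx, "~> 0.7.2"} to Mix.install and call
# Nx.default_backend(Torchx.Backend) before creating x.
```

If this runs on the default backend but raises the same `ArgumentError` once Torchx is the default backend, the problem is likely in the backend's `window_scatter_max` path rather than in the model definition.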