Help figuring out "Reshape operation has mismatched element counts: from=0 (f32[32,0,8,64]) to=16384 (f32[32,512])."

Hey folks,

I’m quite new to the ML world, as well as to Nx and Axon (though not to Elixir itself). I’ve been trying to follow some articles, as well as Sean Moriarty’s beta book, but I often get stuck with no real idea of what’s happening, so it’d be great if someone could help.

While running the code from How to Create Neural Network in Elixir Using Nx and Axon | Curiosum (with some adaptations, since the article is about a year old), I get the following error when calling Axon.Loop.run:

Reshape operation has mismatched element counts: from=0 (f32[32,0,8,64]) to=16384 (f32[32,512])

For reference, this is what I have so far:


{images, labels} = Scidata.CIFAR10.download()

{img_data, img_type, img_shape} = images
{label_data, label_type, label_shape} = labels

batch_size = 32
# Reshape the raw image bytes and scale pixel values to [0, 1]
images =
  img_data
  |> Nx.from_binary(img_type)
  |> Nx.reshape(img_shape)
  |> Nx.divide(255.0)

# One-hot encode the labels into a {n, 10} tensor
labels =
  label_data
  |> Nx.from_binary(label_type)
  |> Nx.reshape(label_shape)
  |> Nx.new_axis(-1)
  |> Nx.equal(Nx.iota({1, 10}))

(...)

train_data = 
  train_images
  |> Nx.to_batched(batch_size)
  |> Stream.zip(Nx.to_batched(train_labels, batch_size))

(...)

model =
  Axon.input("images", shape: {nil, 3, 32, 32})
  |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.flatten()
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

trained_model_state = 
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, %{}, epochs: 10, compiler: EXLA)
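
To sanity-check what's actually going into the training loop, one quick option is to peek at the first batch (just a sketch, reusing the variable names from the snippet above):

# Grab the first {images, labels} batch from the stream
{batch_images, batch_labels} = Enum.at(train_data, 0)

IO.inspect(Nx.shape(batch_images), label: "images batch")
IO.inspect(Nx.shape(batch_labels), label: "labels batch")

With the reshape above, this should print {32, 3, 32, 32} for the images and {32, 10} for the labels.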

Based on axon/examples/vision/cifar10.exs at main · elixir-nx/axon · GitHub, I think you actually need to change the image reshape above to

|> Nx.reshape({elem(img_shape, 0), 32, 32, 3})

because, unfortunately, the position of the channels dimension in the data coming from Scidata doesn't work as-is with that model.
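
For completeness, here's a minimal sketch of what that looks like in context, following the channels-last layout the Axon example uses (the {nil, 32, 32, 3} input shape is my assumption based on that example, so double-check it against your model):

# Reshape the images to channels-last ({n, 32, 32, 3}), as in the Axon example
images =
  img_data
  |> Nx.from_binary(img_type)
  |> Nx.reshape({elem(img_shape, 0), 32, 32, 3})
  |> Nx.divide(255.0)

# If the data is channels-last, the input layer should declare a matching shape
model =
  Axon.input("images", shape: {nil, 32, 32, 3})
  |> Axon.conv(32, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.conv(64, kernel_size: {3, 3}, activation: :relu, padding: :same)
  |> Axon.max_pool(kernel_size: {2, 2}, strides: [2, 2])
  |> Axon.flatten()
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(10, activation: :softmax)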

It works, thanks!

Cool.

Opened a pull request to update the Axon example with a few minor tweaks that might help clarify where those numbers come from and what they mean: refactor(examples): attempt to improve legibility of CIFAR-10 by grzuy · Pull Request #534 · elixir-nx/axon · GitHub.
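
In the same spirit, here's a rough walkthrough of how the shapes flow through a model like the one above, assuming a channels-last {nil, 32, 32, 3} input as in the example and a batch size of 32:

# input                            {32, 32, 32, 3}
# conv(32, 3x3, padding: :same) -> {32, 32, 32, 32}
# max_pool(2x2, stride 2)       -> {32, 16, 16, 32}
# conv(64, 3x3, padding: :same) -> {32, 16, 16, 64}
# max_pool(2x2, stride 2)       -> {32, 8, 8, 64}
# flatten                       -> {32, 8 * 8 * 64} = {32, 4096}
# dense(64)                     -> {32, 64}
# dense(10)                     -> {32, 10}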

Hope it helps too.

Note that the example in Axon has since been refactored in refactor(examples): consistent CIFAR-10 shapes by grzuy · Pull Request #534 · elixir-nx/axon · GitHub, in case you'd benefit from updating on your side as well :slight_smile: