Getting Batches To Work With Axon


I have been struggling with this problem for the past hour, so hopefully someone here can help me. I am trying to pass batched data to Axon, but I can’t get it working. I get the following error:

** (CompileError) /home/djaouen/.cache/mix/installs/elixir-1.15.4-erts-14.0.2/27fa3a4853b40b31af06054a5826b5ad/deps/axon/lib/axon/loop.ex:468: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

Body matches template:
<snip>

But for the life of me, I can’t figure out why. I suspect my batches don't have the correct shape, but I really don’t know whether that’s the problem. You can find the Livebook (and the associated stocks.zip file) in my repository here: GitHub - danieljaouen/stocks

I know this is probably a lot of debugging for someone to take on, so if you are busy, please don’t waste your time, but if you have a spare hour or two to help me out with this, I would greatly appreciate it. Thanks! :slight_smile:

Could it be that my data goes from :f64 to :f32 at some point in the model? How can I force it to use :f64 at every stage of the pipeline?

+-----------------------------------------------------------------------------------------------------------------+
|                                                      Model                                                      |
+=====================================+===============+==============+=====================+======================+
| Layer                               | Input Shape   | Output Shape | Options             | Parameters           |
+=====================================+===============+==============+=====================+======================+
| predictors ( input )                | []            | {32, 30, 3}  | shape: {nil, 30, 3} |                      |
|                                     |               |              | optional: false     |                      |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| flatten_0 ( flatten["predictors"] ) | [{32, 30, 3}] | {32, 90}     |                     |                      |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| dense_0 ( dense["flatten_0"] )      | [{32, 90}]    | {32, 128}    |                     | kernel: f32[90][128] |
|                                     |               |              |                     | bias: f32[128]       |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| relu_0 ( relu["dense_0"] )          | [{32, 128}]   | {32, 128}    |                     |                      |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| dense_1 ( dense["relu_0"] )         | [{32, 128}]   | {32, 1}      |                     | kernel: f32[128][1]  |
|                                     |               |              |                     | bias: f32[1]         |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
Total Parameters: 11777
Total Parameters Memory: 47108 bytes
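One possible mechanism, offered as a hedged sketch rather than a confirmed diagnosis: the summary shows the parameters initialized as f32, and when Nx combines an f32 tensor with an f64 tensor it promotes the result to f64. If the training data is f64, the gradients and updated parameters could come out as f64 while the loop's initial state is f32, which is exactly the kind of mismatch `while` rejects. The snippet below uses plain Nx with made-up values (`kernel` and `batch` are hypothetical stand-ins) to show the promotion and the `Nx.as_type/2` cast that avoids it.

```elixir
Mix.install([{:nx, "~> 0.5"}])

# Hypothetical stand-ins: a parameter tensor (Axon initializes these as f32)
# and a batch of training data loaded as f64. Values are made up.
kernel = Nx.tensor([[0.1, 0.2], [0.3, 0.4]], type: :f32)
batch = Nx.tensor([[1.0, 2.0]], type: :f64)

# Mixing f32 and f64 operands promotes the result to the wider type, f64.
promoted = Nx.dot(batch, kernel)
IO.inspect(Nx.type(promoted), label: "f64 batch x f32 kernel")

# Casting the data to f32 first keeps the whole computation in f32.
cast_batch = Nx.as_type(batch, :f32)
same_type = Nx.dot(cast_batch, kernel)
IO.inspect(Nx.type(same_type), label: "f32 batch x f32 kernel")
```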

Here is the full error, in case it helps. As far as I can tell, the data starts out as :f64 but becomes :f32 at the dense_0 layer, although I'm not sure, since there is also a dense_0 entry that shows :f64:

** (CompileError) /home/djaouen/.cache/mix/installs/elixir-1.15.4-erts-14.0.2/27fa3a4853b40b31af06054a5826b5ad/deps/axon/lib/axon/loop.ex:468: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.

Body matches template:

{%{"dense_0" => %{"bias" => #Nx.Tensor<
       f64[128]
     >, "kernel" => #Nx.Tensor<
       f64[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f64[1]
     >, "kernel" => #Nx.Tensor<
       f64[128][1]
     >}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
       f32[128]
     >, "kernel" => #Nx.Tensor<
       f32[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f32[1]
     >, "kernel" => #Nx.Tensor<
       f32[128][1]
     >}}, %{}, {%{scale: #Nx.Tensor<
      f32
    >}, %{count: #Nx.Tensor<
      s64
    >, nu: %{"dense_0" => %{"bias" => #Nx.Tensor<
          f64[128]
        >, "kernel" => #Nx.Tensor<
          f64[90][128]
        >}, "dense_1" => %{"bias" => #Nx.Tensor<
          f64[1]
        >, "kernel" => #Nx.Tensor<
          f64[128][1]
        >}}, mu: %{"dense_0" => %{"bias" => #Nx.Tensor<
          f64[128]
        >, "kernel" => #Nx.Tensor<
          f64[90][128]
        >}, "dense_1" => %{"bias" => #Nx.Tensor<
          f64[1]
        >, "kernel" => #Nx.Tensor<
          f64[128][1]
        >}}}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
       f64[128]
     >, "kernel" => #Nx.Tensor<
       f64[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f64[1]
     >, "kernel" => #Nx.Tensor<
       f64[128][1]
     >}}, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >}

and initial argument has template:

{%{"dense_0" => %{"bias" => #Nx.Tensor<
       f64[128]
     >, "kernel" => #Nx.Tensor<
       f64[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f64[1]
     >, "kernel" => #Nx.Tensor<
       f64[128][1]
     >}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
       f32[128]
     >, "kernel" => #Nx.Tensor<
       f32[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f32[1]
     >, "kernel" => #Nx.Tensor<
       f32[128][1]
     >}}, %{}, {%{scale: #Nx.Tensor<
      f32
    >}, %{count: #Nx.Tensor<
      s64
    >, nu: %{"dense_0" => %{"bias" => #Nx.Tensor<
          f32[128]
        >, "kernel" => #Nx.Tensor<
          f32[90][128]
        >}, "dense_1" => %{"bias" => #Nx.Tensor<
          f32[1]
        >, "kernel" => #Nx.Tensor<
          f32[128][1]
        >}}, mu: %{"dense_0" => %{"bias" => #Nx.Tensor<
          f32[128]
        >, "kernel" => #Nx.Tensor<
          f32[90][128]
        >}, "dense_1" => %{"bias" => #Nx.Tensor<
          f32[1]
        >, "kernel" => #Nx.Tensor<
          f32[128][1]
        >}}}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
       f32[128]
     >, "kernel" => #Nx.Tensor<
       f32[90][128]
     >}, "dense_1" => %{"bias" => #Nx.Tensor<
       f32[1]
     >, "kernel" => #Nx.Tensor<
       f32[128][1]
     >}}, #Nx.Tensor<
   s64
 >, #Nx.Tensor<
   s64
 >}

    (nx 0.5.3) lib/nx/defn/expr.ex:483: Nx.Defn.Expr.compatible_while!/4
    (nx 0.5.3) lib/nx/defn/expr.ex:354: Nx.Defn.Expr.defn_while/6
    (axon 0.5.1) lib/axon/loop.ex:468: Axon.Loop."__defn:accumulate_gradients__"/8
    (axon 0.5.1) lib/axon/loop.ex:420: anonymous fn/6 in Axon.Loop.train_step/4
    (axon 0.5.1) lib/axon/loop.ex:1925: anonymous fn/6 in Axon.Loop.build_batch_fn/2
    (nx 0.5.3) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.5.3) lib/exla/defn.ex:385: anonymous fn/4 in EXLA.Defn.compile/8
    (exla 0.5.3) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
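If the f32/f64 mismatch is indeed the cause, one workaround is to cast both inputs and targets to :f32 before batching, so they match the f32 parameters Axon initializes. This is a minimal sketch under that assumption, not a verified fix for the Livebook: it rebuilds the architecture from the model summary above, and the training data (`x`, `y`) is a hypothetical placeholder generated with `Nx.iota` rather than the real stocks data.

```elixir
Mix.install([{:nx, "~> 0.5"}, {:axon, "~> 0.5"}])

# Hypothetical placeholder data with the shapes from the model summary:
# 64 windows of 30 timesteps x 3 features, one target per window, as f64.
x = Nx.iota({64, 30, 3}, type: :f64) |> Nx.divide(64 * 30 * 3)
y = Nx.iota({64, 1}, type: :f64) |> Nx.divide(64)

# Cast to f32 so the data matches the f32 parameters, split into batches
# of 32, and pair each input batch with its target batch as {x, y}.
train_data =
  Stream.zip(
    x |> Nx.as_type(:f32) |> Nx.to_batched(32),
    y |> Nx.as_type(:f32) |> Nx.to_batched(32)
  )

# Same architecture as the summary: flatten -> dense(128) -> relu -> dense(1).
model =
  Axon.input("predictors", shape: {nil, 30, 3})
  |> Axon.flatten()
  |> Axon.dense(128)
  |> Axon.relu()
  |> Axon.dense(1)

# The trainer loop returns the trained parameter map.
model_state =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :adam)
  |> Axon.Loop.run(train_data, %{}, epochs: 1)
```

Going the other way (keeping everything in f64, as asked above) would also require the parameters and optimizer state to be f64; casting the data down is simply the smaller change.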

Could this be related to this bug? Passing f32 data into LSTM with Axon.Loop trainer+run causes while shape mismatch error · Issue #490 · elixir-nx/axon · GitHub. @seanmor5