Could it be that my data goes from `:f64` to `:f32` at some point in the model? How can I force it to use `:f64` at every stage of the pipeline?
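For context, here is what I have been considering trying. This is only a sketch, assuming Axon ~> 0.5's `Axon.MixedPrecision` module and a data pipeline of `{x, y}` batches; the model and the dummy batch below are reconstructed from the table further down, not my exact code:

```elixir
# Reconstruction of the model from the summary table below.
model =
  Axon.input("predictors", shape: {nil, 30, 3})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(1)

# 1. A mixed-precision policy so params, compute, and output are all {:f, 64}.
policy =
  Axon.MixedPrecision.create_policy(
    params: {:f, 64},
    compute: {:f, 64},
    output: {:f, 64}
  )

model = Axon.MixedPrecision.apply_policy(model, policy)

# 2. Make sure every batch is also :f64, so data and params agree.
#    `data` here is a hypothetical stand-in for my real input stream.
x = Nx.iota({32, 30, 3}, type: {:f, 64})
y = Nx.iota({32, 1}, type: {:f, 64})
data = Stream.repeatedly(fn -> {x, y} end)

data =
  Stream.map(data, fn {x, y} ->
    {Nx.as_type(x, :f64), Nx.as_type(y, :f64)}
  end)
```

The opposite fix, casting the input data down to `:f32` with `Nx.as_type(x, :f32)`, would presumably also resolve the mismatch, and is cheaper if double precision isn't actually needed.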
+-----------------------------------------------------------------------------------------------------------------+
| Model |
+=====================================+===============+==============+=====================+======================+
| Layer | Input Shape | Output Shape | Options | Parameters |
+=====================================+===============+==============+=====================+======================+
| predictors ( input ) | [] | {32, 30, 3} | shape: {nil, 30, 3} | |
| | | | optional: false | |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| flatten_0 ( flatten["predictors"] ) | [{32, 30, 3}] | {32, 90} | | |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| dense_0 ( dense["flatten_0"] ) | [{32, 90}] | {32, 128} | | kernel: f32[90][128] |
| | | | | bias: f32[128] |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| relu_0 ( relu["dense_0"] ) | [{32, 128}] | {32, 128} | | |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
| dense_1 ( dense["relu_0"] ) | [{32, 128}] | {32, 1} | | kernel: f32[128][1] |
| | | | | bias: f32[1] |
+-------------------------------------+---------------+--------------+---------------------+----------------------+
Total Parameters: 11777
Total Parameters Memory: 47108 bytes
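Incidentally, the memory figure itself already suggests the parameters are stored as f32 (4 bytes each) rather than f64 (8 bytes), since the byte count is exactly four times the parameter count:

```elixir
# Kernels and biases from the table above.
params = 90 * 128 + 128 + 128 * 1 + 1
# params == 11777
params * 4
# 47108 bytes at 4 bytes/param (f32); f64 would be 94216
```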
Here is the full error, in case that helps. As you can see, the data starts out as `:f64` but becomes `:f32` at the `dense_0` layer in the pipeline (or at least, I think so; I'm not really sure, since there is also a `dense_0` entry that shows `:f64`):
** (CompileError) /home/djaouen/.cache/mix/installs/elixir-1.15.4-erts-14.0.2/27fa3a4853b40b31af06054a5826b5ad/deps/axon/lib/axon/loop.ex:468: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
Body matches template:
{%{"dense_0" => %{"bias" => #Nx.Tensor<
f64[128]
>, "kernel" => #Nx.Tensor<
f64[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f64[1]
>, "kernel" => #Nx.Tensor<
f64[128][1]
>}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
f32[128]
>, "kernel" => #Nx.Tensor<
f32[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f32[1]
>, "kernel" => #Nx.Tensor<
f32[128][1]
>}}, %{}, {%{scale: #Nx.Tensor<
f32
>}, %{count: #Nx.Tensor<
s64
>, nu: %{"dense_0" => %{"bias" => #Nx.Tensor<
f64[128]
>, "kernel" => #Nx.Tensor<
f64[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f64[1]
>, "kernel" => #Nx.Tensor<
f64[128][1]
>}}, mu: %{"dense_0" => %{"bias" => #Nx.Tensor<
f64[128]
>, "kernel" => #Nx.Tensor<
f64[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f64[1]
>, "kernel" => #Nx.Tensor<
f64[128][1]
>}}}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
f64[128]
>, "kernel" => #Nx.Tensor<
f64[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f64[1]
>, "kernel" => #Nx.Tensor<
f64[128][1]
>}}, #Nx.Tensor<
s64
>, #Nx.Tensor<
s64
>}
and initial argument has template:
{%{"dense_0" => %{"bias" => #Nx.Tensor<
f64[128]
>, "kernel" => #Nx.Tensor<
f64[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f64[1]
>, "kernel" => #Nx.Tensor<
f64[128][1]
>}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
f32[128]
>, "kernel" => #Nx.Tensor<
f32[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f32[1]
>, "kernel" => #Nx.Tensor<
f32[128][1]
>}}, %{}, {%{scale: #Nx.Tensor<
f32
>}, %{count: #Nx.Tensor<
s64
>, nu: %{"dense_0" => %{"bias" => #Nx.Tensor<
f32[128]
>, "kernel" => #Nx.Tensor<
f32[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f32[1]
>, "kernel" => #Nx.Tensor<
f32[128][1]
>}}, mu: %{"dense_0" => %{"bias" => #Nx.Tensor<
f32[128]
>, "kernel" => #Nx.Tensor<
f32[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f32[1]
>, "kernel" => #Nx.Tensor<
f32[128][1]
>}}}}, %{"dense_0" => %{"bias" => #Nx.Tensor<
f32[128]
>, "kernel" => #Nx.Tensor<
f32[90][128]
>}, "dense_1" => %{"bias" => #Nx.Tensor<
f32[1]
>, "kernel" => #Nx.Tensor<
f32[128][1]
>}}, #Nx.Tensor<
s64
>, #Nx.Tensor<
s64
>}
(nx 0.5.3) lib/nx/defn/expr.ex:483: Nx.Defn.Expr.compatible_while!/4
(nx 0.5.3) lib/nx/defn/expr.ex:354: Nx.Defn.Expr.defn_while/6
(axon 0.5.1) lib/axon/loop.ex:468: Axon.Loop."__defn:accumulate_gradients__"/8
(axon 0.5.1) lib/axon/loop.ex:420: anonymous fn/6 in Axon.Loop.train_step/4
(axon 0.5.1) lib/axon/loop.ex:1925: anonymous fn/6 in Axon.Loop.build_batch_fn/2
(nx 0.5.3) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
(exla 0.5.3) lib/exla/defn.ex:385: anonymous fn/4 in EXLA.Defn.compile/8
(exla 0.5.3) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2