I am trying to add the Swin model to Bumblebee as a first step toward adding Donut, which uses a Swin transformer as its encoder. I am mostly working by analogy with the existing ViT model, because I couldn't find any guide on how to add a model to Bumblebee.
When I try to load the model with Bumblebee.load_model, I get the following error:
08:32:20.344 [debug] Axon finished graph traversal in 0.1ms
** (Axon.CompileError) exception found when compiling layer Axon.Layers.add/2 named add_3:
** (ArgumentError) cannot broadcast tensor of dimensions {1, 9216, 256} to {1, 9216, 128}
(nx 0.7.1) lib/nx/shape.ex:345: Nx.Shape.binary_broadcast/4
(nx 0.7.1) lib/nx.ex:5407: Nx.devectorized_element_wise_bin_op/4
(elixir 1.16.0) lib/enum.ex:2528: Enum."-reduce/3-lists^foldl/2-0-"/3
The layer was defined at:
(axon 0.6.1) lib/axon.ex:344: Axon.layer/3
(bumblebee 0.5.3) lib/bumblebee/layers/transformer.ex:513: Bumblebee.Layers.Transformer.block_impl/4
(bumblebee 0.5.3) lib/bumblebee/layers/transformer.ex:479: Bumblebee.Layers.Transformer.block/2
(bumblebee 0.5.3) lib/bumblebee/vision/swin.ex:265: Bumblebee.Vision.Swin.layer/4
(bumblebee 0.5.3) lib/bumblebee/vision/swin.ex:238: anonymous fn/4 in Bumblebee.Vision.Swin.encoder/3
(elixir 1.16.0) lib/enum.ex:4391: Enum.reduce_range/5
I am evidently missing a layer that makes the shapes compatible, but I cannot find where in the model structure the error occurs. When I add IO.inspect calls to the code, I get the following:
layer_0: {#Axon<
inputs: %{"pixel_values" => {1, 384, 384, 3}}
outputs: "encoder.block_4.output_norm"
nodes: 51
>,
#Axon<
inputs: %{"pixel_values" => {1, 384, 384, 3}}
outputs: "dropout_1"
nodes: 32
>,
#Axon<
inputs: %{}
outputs: "none_0"
nodes: 1
>,
#Axon<
inputs: %{"pixel_values" => {1, 384, 384, 3}}
outputs: "custom_0"
nodes: 25
>,
#Axon<
inputs: %{}
outputs: "none_0"
nodes: 1
>}
This shows the structure of the layer, but it carries no information about the output dimensions of each layer.
I tried attaching an Axon hook, but without success. My guess is that hooks only run once the model has been loaded, so the error above prevents me from using one.
Is there a way to get the output dimensions of each step so I can identify where the error happens?
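One approach that might narrow this down, as a minimal sketch: Axon.get_output_shape/2 can be called on any intermediate Axon graph with an input template, so graphs like the ones printed by IO.inspect could each be checked before they reach the failing add. The two-layer model below is a hypothetical stand-in, not the Swin code. The "hidden_state" input name and the "projection" layer are assumptions for illustration, chosen to mirror the {1, 9216, 128} and {1, 9216, 256} shapes in the error.

```elixir
Mix.install([{:axon, "~> 0.6"}, {:nx, "~> 0.7"}])

# Hypothetical stand-in for two branches that later meet in a residual add.
# Neither name comes from Bumblebee; they only mirror the shapes in the error.
shortcut = Axon.input("hidden_state", shape: {nil, 9216, 128})
projected = Axon.dense(shortcut, 256, name: "projection")

# A template carries only shape and type, so no real data is needed.
template = %{"hidden_state" => Nx.template({1, 9216, 128}, :f32)}

# Print the output shape of each intermediate graph separately.
for {label, graph} <- [shortcut: shortcut, projected: projected] do
  IO.inspect(Axon.get_output_shape(graph, template), label: label)
end
```

In the Swin implementation the template would presumably be %{"pixel_values" => Nx.template({1, 384, 384, 3}, :f32)}, matching the inputs shown in the inspected graphs. The call builds and traces the graph, so it should raise the same compile error once it is applied to a graph that already contains the mismatched add, which makes it usable for bisecting toward the offending layer.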