How to use Swin and Donut models with Bumblebee?

I am trying to add the Swin model to Bumblebee as a first step towards adding Donut, which uses a Swin transformer as its encoder. I am mostly working by analogy with the existing ViT model (because I couldn't find any manual on how to add a model to Bumblebee).
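For context, the load call is roughly the following. This is a minimal sketch, and the checkpoint id is an assumption on my part; any Swin checkpoint with 384x384 inputs should reproduce the same shapes.

# a sketch of the load call; the checkpoint id is an assumption
{:ok, model_info} =
  Bumblebee.load_model({:hf, "microsoft/swin-base-patch4-window12-384"})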

When I try to load the model with Bumblebee.load_model, I get an error:

08:32:20.344 [debug] Axon finished graph traversal in 0.1ms
** (Axon.CompileError) exception found when compiling layer Axon.Layers.add/2 named add_3:

    ** (ArgumentError) cannot broadcast tensor of dimensions {1, 9216, 256} to {1, 9216, 128}
        (nx 0.7.1) lib/nx/shape.ex:345: Nx.Shape.binary_broadcast/4
        (nx 0.7.1) lib/nx.ex:5407: Nx.devectorized_element_wise_bin_op/4
        (elixir 1.16.0) lib/enum.ex:2528: Enum."-reduce/3-lists^foldl/2-0-"/3

The layer was defined at:

    (axon 0.6.1) lib/axon.ex:344: Axon.layer/3
    (bumblebee 0.5.3) lib/bumblebee/layers/transformer.ex:513: Bumblebee.Layers.Transformer.block_impl/4
    (bumblebee 0.5.3) lib/bumblebee/layers/transformer.ex:479: Bumblebee.Layers.Transformer.block/2
    (bumblebee 0.5.3) lib/bumblebee/vision/swin.ex:265: Bumblebee.Vision.Swin.layer/4
    (bumblebee 0.5.3) lib/bumblebee/vision/swin.ex:238: anonymous fn/4 in Bumblebee.Vision.Swin.encoder/3
    (elixir 1.16.0) lib/enum.ex:4391: Enum.reduce_range/5

Obviously I am missing a layer that would make the shapes compatible, but I cannot find where in the model structure the error happens. Namely, if I add IO.inspect calls in the code, I get the following:

layer_0: {#Axon<
   inputs: %{"pixel_values" => {1, 384, 384, 3}}
   outputs: "encoder.block_4.output_norm"
   nodes: 51
 >,
 #Axon<
   inputs: %{"pixel_values" => {1, 384, 384, 3}}
   outputs: "dropout_1"
   nodes: 32
 >,
 #Axon<
   inputs: %{}
   outputs: "none_0"
   nodes: 1
 >,
 #Axon<
   inputs: %{"pixel_values" => {1, 384, 384, 3}}
   outputs: "custom_0"
   nodes: 25
 >,
 #Axon<
   inputs: %{}
   outputs: "none_0"
   nodes: 1
 >}

which shows the structure of the layer, but obviously it does not carry any information about the output dimension of each layer.

I tried attaching an Axon hook, but without success. My guess is that hooks only fire once the model has been compiled and runs, and because of the above error I never get that far.
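For reference, attaching the hook looked roughly like this (a sketch: Axon.attach_hook is the real API, the IO.inspect callback is my own):

# a sketch of the hook I tried; :forward hooks fire when the
# compiled model runs, which never happens here
hidden_state =
  Axon.attach_hook(hidden_state, fn value -> IO.inspect(Nx.shape(value)) end,
    on: :forward
  )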

Is there a way to get the output dimensions of each step so I can identify where the error happens?

Hey @bosko. For the Axon hook you probably used Nx.Defn.Kernel.print_value, which prints the actual value after everything is compiled and actually runs. You can inspect the shape at compilation time using print_expr:

hidden_state
|> Axon.nx(fn x ->
  # x is a symbolic expression at this point, so the shape
  # is printed while the graph is compiled, before it runs
  Nx.Defn.Kernel.print_expr(Nx.shape(x))
  x
end)
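If you need this in several places, you can wrap it in a small helper so each call site is tagged. debug_shape/2 below is a hypothetical helper, not part of Axon or Bumblebee:

# hypothetical helper: reuses the snippet above and tags each
# call site so the printed shapes can be told apart
defp debug_shape(hidden_state, label) do
  Axon.nx(hidden_state, fn x ->
    # runs while the graph is built, before the model executes
    Nx.Defn.Kernel.print_expr({label, Nx.shape(x)})
    x
  end)
end

Threading hidden_state |> debug_shape("block_#{idx}") (or whatever label fits) through the Swin blocks should show where the 128/256 mismatch first appears.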

In the corresponding hf/transformers Python code you can use print(x.shape) to compare.

That’s exactly what I was looking for. Thanks!