Tuple of tensors as Axon.input

I’ve been working on adding ControlNet support to bumblebee.

This picture shows how this works. Basically, a ControlNet is only the encoding part of a UNet, and its residuals get passed as additional input to the regular UNet of Stable Diffusion.

The point I struggle with is how to pass the residuals to the UNet.
The ControlNet outputs a tuple of residuals of different shape, something like:

{
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][32][32][320]>,
  #Nx.Tensor<f32[1][32][32][640]>,
  #Nx.Tensor<f32[1][32][32][640]>,
  #Nx.Tensor<f32[1][16][16][640]>,
  #Nx.Tensor<f32[1][16][16][1280]>,
  #Nx.Tensor<f32[1][16][16][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>,
}

The exact tuple size and tensor shapes depend on the configuration.
Now, I want to pass the tuple as input to the UNet, but with Axon.input I seem to be able to specify only the shape of a single tensor.

For now, I solved this by calculating the shapes of the tensors of the tuple from the configuration and then creating a list with a single input per tensor:

  defp inputs_with_controlnet(spec) do
    sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels}

    {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec)

    down_residuals =
      for {shape, i} <- Enum.with_index(out_shapes) do
        Axon.input("controlnet_down_residual_#{i}", shape: shape)
      end

    mid_dim = List.last(spec.hidden_sizes)

    mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_dim}

    Bumblebee.Utils.Model.inputs_to_map(
      [
        Axon.input("sample", shape: sample_shape),
        Axon.input("timestep", shape: {}),
        Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}),
        Axon.input("controlnet_mid_residual", shape: mid_residual_shape)
      ] ++ down_residuals
    )
  end

However, I feel like I’m missing something that would make this part way more “idiomatic”.
Is there a better way?


@joelpaulkoch the input can be a tuple (more generally, any Nx.Container). Note that the :shape option is just additional validation; it’s not required : )
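For example (a minimal sketch; the input name and tuple contents are made up for illustration), the whole residual tuple can be declared as a single input with no :shape option:

```elixir
# Hypothetical sketch: one input node receiving a whole tuple of residuals.
# With no :shape option given, any Nx.Container (here a tuple) is accepted.
residuals = Axon.input("controlnet_down_residuals")

# At prediction time the tuple is supplied as a single entry in the input map:
#
#   Axon.predict(model, params, %{
#     "controlnet_down_residuals" => {res0, res1, res2}
#   })
```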


Thanks, so this just works. A point of confusion for me was how to “unpack” the tuple input, since it comes in as an Axon struct. I found a solution in the unwrap_tuple function in lib/bumblebee/layers.ex:

    controlnet_down_residuals =
      for i <- 0..(num_down_residuals - 1) do
        Axon.nx(inputs["controlnet_down_residuals"], &elem(&1, i))
      end

It makes sense that you can do that. I feel like sometimes I’m not thinking in terms of Axon…

Ah yeah, there are two contexts, the outer Axon graph and the inner Nx. If you have an Axon.container node and pass it as input to Axon.layer, the Nx function gets the actual container content (like a tuple):

x = Axon.input("x")
container = Axon.container({x, x})

Axon.layer(
  fn {left, right}, _opts ->
    Nx.add(left, right)
  end,
  [container]
)

Going the other way is a bit more awkward. If the current Axon node returns a map, you can “pick” a specific field with |> Axon.nx(& &1.key), and the resulting node returns just that field. However, there is no way to automatically unwrap the inner map into a map of Axon nodes; you need to do it by hand with %{key1: Axon.nx(x, & &1.key1), key2: Axon.nx(x, & &1.key2)}. With tuples it’s analogous, except you use elem/2.
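Spelled out for the tuple case (a sketch; the node names and tuple arity are illustrative), the by-hand unwrapping looks like this:

```elixir
x = Axon.input("x")

# A node whose inner result is a tuple, built with Axon.container/1.
tuple_node = Axon.container({x, x, x})

# Unwrap by hand: one Axon node per tuple element, via Axon.nx/2 and elem/2.
unwrapped =
  for i <- 0..2 do
    Axon.nx(tuple_node, &elem(&1, i))
  end
```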

The reason we can’t unwrap the “inner” tuple automatically is that we don’t actually know its shape; it is only known when compiling. Consider this contrived example:

x = Axon.input("x")

model =
  Axon.layer(
    fn x, _opts ->
      case Nx.shape(x) do
        {_} -> %{key1: x}
        {_, _} -> %{key2: x}
      end
    end,
    [x]
  )

There is no way we could do Magic.unwrap(model), because we have no idea how the map looks. It is only known once we compile the model with a specific input template.
