I’ve been working on adding ControlNet support to Bumblebee.
This picture shows how this works. Basically, a ControlNet is only the encoding part of a UNet, and its residuals get passed as additional input to the regular UNet of Stable Diffusion.
The point I struggle with is how to pass the residuals to the UNet.
The ControlNet outputs a tuple of residuals with different shapes, something like:
{
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][64][64][320]>,
  #Nx.Tensor<f32[1][32][32][320]>,
  #Nx.Tensor<f32[1][32][32][640]>,
  #Nx.Tensor<f32[1][32][32][640]>,
  #Nx.Tensor<f32[1][16][16][640]>,
  #Nx.Tensor<f32[1][16][16][1280]>,
  #Nx.Tensor<f32[1][16][16][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>,
  #Nx.Tensor<f32[1][8][8][1280]>
}
The exact tuple size and tensor shapes depend on the configuration.
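To make that dependency concrete, here is a minimal sketch of how the residual shapes could be derived from the configuration. It is not Bumblebee code: the `ResidualShapes` module and its function name are hypothetical, and it assumes the spec exposes `sample_size`, `hidden_sizes` and a `depth` field (residual blocks per down block), that the input convolution contributes the first residual, and that every down block except the last halves the spatial size.

```elixir
defmodule ResidualShapes do
  # Hypothetical helper, not part of Bumblebee. `spec` is any map or struct
  # with :sample_size, :hidden_sizes and :depth (an assumed field name).
  def down_residual_shapes(spec) do
    last_index = length(spec.hidden_sizes) - 1

    # Assumption: the input convolution contributes the first residual.
    first = {nil, spec.sample_size, spec.sample_size, hd(spec.hidden_sizes)}

    {shapes, _final_size} =
      spec.hidden_sizes
      |> Enum.with_index()
      |> Enum.flat_map_reduce(spec.sample_size, fn {channels, i}, size ->
        # Each residual block in a down block keeps the current spatial size.
        block_shapes = List.duplicate({nil, size, size, channels}, spec.depth)

        if i == last_index do
          # Assumption: the last down block does not downsample.
          {block_shapes, size}
        else
          # The downsampler halves the spatial size and adds one more residual.
          half = div(size, 2)
          {block_shapes ++ [{nil, half, half, channels}], half}
        end
      end)

    [first | shapes]
  end
end

# Example values chosen to match the tuple shown above.
spec = %{sample_size: 64, hidden_sizes: [320, 640, 1280, 1280], depth: 2}
IO.inspect(ResidualShapes.down_residual_shapes(spec), limit: :infinity)
```

With those example values the sketch should yield twelve shapes in the same order as the tuple above, with `nil` in place of the batch size.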
Now I want to pass the tuple as input to the UNet, but Axon.input only lets me specify the shape of a single tensor per input.
For now, I worked around this by calculating the shape of each tensor in the tuple from the configuration and then creating a list with one input per tensor:
defp inputs_with_controlnet(spec) do
  sample_shape = {nil, spec.sample_size, spec.sample_size, spec.in_channels}

  {mid_spatial, out_shapes} = mid_spatial_and_residual_shapes(spec)

  down_residuals =
    for {shape, i} <- Enum.with_index(out_shapes) do
      Axon.input("controlnet_down_residual_#{i}", shape: shape)
    end

  mid_dim = List.last(spec.hidden_sizes)
  mid_residual_shape = {nil, mid_spatial, mid_spatial, mid_dim}

  Bumblebee.Utils.Model.inputs_to_map(
    [
      Axon.input("sample", shape: sample_shape),
      Axon.input("timestep", shape: {}),
      Axon.input("encoder_hidden_state", shape: {nil, nil, spec.cross_attention_size}),
      Axon.input("controlnet_mid_residual", shape: mid_residual_shape)
    ] ++ down_residuals
  )
end
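With this layout the ControlNet output has to be flattened into named inputs at call time. A minimal sketch of that step, assuming the down residuals arrive as a tuple and the mid residual as a separate value (the `ControlNetInputs` module and function name are hypothetical, only the input names come from the code above):

```elixir
defmodule ControlNetInputs do
  # Hypothetical helper: builds a map whose keys match the input names
  # declared in inputs_with_controlnet/1.
  def to_unet_inputs(down_residuals, mid_residual) when is_tuple(down_residuals) do
    down_residuals
    |> Tuple.to_list()
    |> Enum.with_index()
    # One entry per residual, keyed by its position in the tuple.
    |> Map.new(fn {tensor, i} -> {"controlnet_down_residual_#{i}", tensor} end)
    |> Map.put("controlnet_mid_residual", mid_residual)
  end
end
```

The resulting map can then be merged with the "sample", "timestep" and "encoder_hidden_state" entries before calling the UNet, since the model takes its inputs as a map keyed by input name.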
However, I feel like I’m missing something that would make this part way more “idiomatic”.
Is there a better way?