Flatten and unflatten Axon parameters

,

To train Axon models with algorithms other than gradient descent (e.g. genetic algorithms), one needs to flatten all parameters of the model into a single vector during optimization and then unflatten them back into the original structure for inference. For example, the following model and its parameters…

Model

# Two stacked dense layers (2 units with ReLU, then 1 unit with sigmoid)
# applied to an input declared with shape {nil, 2}.
model =
  {nil, 2}
  |> Axon.input()
  |> Axon.dense(2, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

Parameters

%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[2]
      EXLA.Backend<host:0, 0.1979063076.4252631046.85519>
      [0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][2]
      EXLA.Backend<host:0, 0.1979063076.4252631046.85520>
      [
        [0.2250952422618866, -0.2300528585910797],
        [0.8318504691123962, 1.00990629196167]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend<host:0, 0.1979063076.4252631046.85521>
      [0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][1]
      EXLA.Backend<host:0, 0.1979063076.4252631046.85522>
      [
        [0.5544151663780212],
        [1.0918326377868652]
      ]
    >
  }
}

Would be transformed to:

[0.0, 0.0, 0.2250952422618866, -0.2300528585910797, 0.8318504691123962,
 1.00990629196167, 0.0, 0.5544151663780212, 1.0918326377868652]

Flattening the parameters can be done with a recursive reduce over the parameters map. However, I have been unable to come up with a way to unflatten them. I would greatly appreciate ideas on how to do that, as well as suggestions and alternatives for the flattening and unflattening procedure.

I think I figured out how to do this. This may not be the most efficient way of doing it but it works.

The Code

defmodule AxonParams do
  @moduledoc """
  Converts a nested map of parameter tensors to a single rank-1 tensor and back.

  The typical round trip is:

      template = AxonParams.extract_template(params)
      flat = AxonParams.flatten(params)
      params = AxonParams.unflatten(flat, template)

  Nested keys are joined with `"."` (for example `"dense_0.kernel"`), so:

    * keys that themselves contain `"."` cannot be round-tripped, because
      `unflatten/2` splits on every dot;
    * keys are converted with `to_string/1`, so non-string keys (e.g. atoms)
      come back as strings after `unflatten/2`.

  The template records only shapes and names. The element type of the
  rebuilt tensors is whatever the flat tensor passed to `unflatten/2` has.
  """

  @typedoc "Dotted parameter path mapped to the tensor's `{shape, names}`."
  @type template :: %{optional(String.t()) => {tuple(), list()}}

  @doc """
  Flattens every tensor in `params` and concatenates them into one rank-1 tensor.

  Tensors are laid out in ascending order of their dotted key (see
  `to_flat_map/1`), which is the same order `unflatten/2` reads them back in.
  The result is always rank-1, including when `params` holds a single tensor.
  """
  @spec flatten(map()) :: Nx.Tensor.t()
  def flatten(params) do
    params
    |> to_flat_map()
    |> sorted_entries()
    |> Enum.map(fn {_key, tensor} -> Nx.flatten(tensor) end)
    # One concatenate call instead of re-concatenating a growing accumulator;
    # it also guarantees a flat result when there is only one tensor.
    |> Nx.concatenate()
  end

  @doc """
  Rebuilds the nested parameter map from a flat tensor.

  `params` is a rank-1 tensor as produced by `flatten/1` and `template` is the
  map returned by `extract_template/1` for the original parameters. Each entry
  of the template consumes `Tuple.product(shape)` consecutive elements of
  `params`, which are reshaped to `shape` with the recorded `names`.

  Elements of `params` beyond what the template consumes are ignored.
  """
  @spec unflatten(Nx.Tensor.t(), template()) :: map()
  def unflatten(params, template) do
    {entries, _next_offset} =
      template
      |> sorted_entries()
      |> Enum.map_reduce(0, fn {key, {shape, names}}, offset ->
        size = Tuple.product(shape)

        tensor =
          params
          |> Nx.slice([offset], [size])
          |> Nx.reshape(shape, names: names)

        {{key, tensor}, offset + size}
      end)

    Enum.reduce(entries, %{}, fn {key, tensor}, nested ->
      # Access.key/2 with a %{} default creates intermediate maps on demand.
      path = key |> String.split(".") |> Enum.map(&Access.key(&1, %{}))
      put_in(nested, path, tensor)
    end)
  end

  @doc """
  Returns a map from dotted parameter path to `{shape, names}` for each tensor
  in `params`. The result is what `unflatten/2` expects as its `template`.
  """
  @spec extract_template(map()) :: template()
  def extract_template(params) do
    params
    |> to_flat_map()
    |> Map.new(fn {key, tensor} -> {key, {Nx.shape(tensor), Nx.names(tensor)}} end)
  end

  @doc """
  Collapses a nested map into a single-level map keyed by dotted path.

  Plain (non-struct) maps are descended into; anything else, including
  structs such as tensors, is treated as a leaf. When called directly on a
  leaf, the leaf is returned wrapped in a one-element list, which is what the
  recursive step iterates over.
  """
  @spec to_flat_map(term()) :: map() | [term()]
  def to_flat_map(params) when is_map(params) and not is_struct(params) do
    for {key, value} <- params,
        prefix = to_string(key),
        entry <- to_flat_map(value),
        into: %{} do
      case entry do
        # Entry coming from a nested map: extend its path with our key.
        {sub_key, leaf} -> {prefix <> "." <> sub_key, leaf}
        # Entry coming from the leaf clause: the path is just our key.
        leaf -> {prefix, leaf}
      end
    end
  end

  def to_flat_map(params), do: [params]

  # Orders {key, value} pairs by key so that flatten/1 and unflatten/2 agree on
  # the layout without relying on map iteration order.
  defp sorted_entries(map), do: Enum.sort_by(map, fn {key, _value} -> key end)
end

Testing

require Axon

# Same two-layer model as in the question.
model =
  {nil, 2}
  |> Axon.input()
  |> Axon.dense(2, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

# IO.inspect/2 returns its argument, so each value is printed and bound in one step.
params =
  model
  |> Axon.init(compiler: EXLA)
  |> IO.inspect(label: "Initialization params")

flat_params =
  params
  |> AxonParams.flatten()
  |> IO.inspect(label: "Flattened params")

template =
  params
  |> AxonParams.extract_template()
  |> IO.inspect(label: "Params template")

# Round trip: the rebuilt map should match the initialization params above.
unflattened_params =
  flat_params
  |> AxonParams.unflatten(template)
  |> IO.inspect(label: "Unflattened params")

Output

Initialization params: %{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[2]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9714>
      [0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][2]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9715>
      [
        [0.2250952422618866, -0.2300528585910797],
        [0.8318504691123962, 1.00990629196167]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9716>
      [0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][1]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9717>
      [
        [0.5544151663780212],
        [1.0918326377868652]
      ]
    >
  }
}
Flattened params: #Nx.Tensor<
  f32[9]
  EXLA.Backend<host:0, 0.3675893082.1876557832.9722>
  [0.0, 0.0, 0.2250952422618866, -0.2300528585910797, 0.8318504691123962, 1.00990629196167, 0.0, 0.5544151663780212, 1.0918326377868652]
>
Params template: %{
  "dense_0.bias" => {{2}, [nil]},
  "dense_0.kernel" => {{2, 2}, [nil, nil]},
  "dense_1.bias" => {{1}, [nil]},
  "dense_1.kernel" => {{2, 1}, [nil, nil]}
}
Unflattened params: %{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[2]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9723>
      [0.0, 0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][2]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9725>
      [
        [0.2250952422618866, -0.2300528585910797],
        [0.8318504691123962, 1.00990629196167]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9726>
      [0.0]
    >,
    "kernel" => #Nx.Tensor<
      f32[2][1]
      EXLA.Backend<host:0, 0.3675893082.1876557832.9728>
      [
        [0.5544151663780212],
        [1.0918326377868652]
      ]
    >
  }
}