Here is one way to do this. It may not be the most efficient approach, but it works.
The Code
defmodule AxonParams do
  @moduledoc """
  Flattens a nested Axon parameter map into a single rank-1 `Nx` tensor
  and rebuilds the nested structure from a shape/name template.

  Keys are sorted explicitly in `flatten/1` and `unflatten/2` so the two
  always agree on layout — plain map iteration order is unspecified in
  Erlang/Elixir for maps with more than 32 keys, so relying on it is fragile.
  """

  @doc """
  Flattens a (possibly nested) map of `Nx` tensors into one rank-1 tensor.

  Leaves are ordered by their dotted path name (e.g. `"dense_0.kernel"`).
  Raises if `params` contains no tensors (nothing to concatenate).
  """
  def flatten(params) do
    params
    |> to_flat_map()
    # Deterministic layout: sort by dotted key, matching unflatten/2.
    |> Enum.sort_by(fn {key, _tensor} -> key end)
    # Flatten each leaf once and concatenate in a single pass instead of
    # re-flattening the accumulator on every reduce step.
    |> Enum.map(fn {_key, tensor} -> Nx.flatten(tensor) end)
    |> Nx.concatenate()
  end

  @doc """
  Rebuilds the nested parameter map from a flat tensor `params` and a
  `template` produced by `extract_template/1`.

  Slices `params` in sorted-key order (the same order `flatten/1` used),
  reshapes each slice, then expands dotted keys back into nested maps.
  """
  def unflatten(params, template) do
    template
    |> Enum.sort_by(fn {key, _meta} -> key end)
    |> Enum.reduce({%{}, 0}, fn {key, {shape, names}}, {acc, offset} ->
      # Tuple.product({}) == 1, so scalar-shaped entries consume one element.
      size = Tuple.product(shape)
      slice = params[offset..(offset + size - 1)]
      {Map.put(acc, key, Nx.reshape(slice, shape, names: names)), offset + size}
    end)
    |> then(&elem(&1, 0))
    |> Enum.reduce(%{}, fn {key, tensor}, acc ->
      # "a.b.c" keys become nested maps: %{"a" => %{"b" => %{"c" => tensor}}}.
      put_in(acc, Enum.map(String.split(key, "."), &Access.key(&1, %{})), tensor)
    end)
  end

  @doc """
  Builds a template mapping each dotted leaf key to `{shape, names}`,
  sufficient for `unflatten/2` to reconstruct the original structure.
  """
  def extract_template(params) do
    params
    |> to_flat_map()
    |> Map.new(fn {key, tensor} -> {key, {Nx.shape(tensor), Nx.names(tensor)}} end)
  end

  @doc """
  Collapses a nested map into a flat map whose keys are dot-joined paths.

  Non-map leaves (tensors) are returned wrapped in a list by the fallback
  clause so the comprehension can iterate them uniformly.
  """
  def to_flat_map(params) when is_map(params) and not is_struct(params) do
    for {k, v} <- params, sub_key = to_string(k), sub_map <- to_flat_map(v), into: %{} do
      case sub_map do
        # Nested map contributed a {path, leaf} pair: extend the path.
        {key, val} -> {sub_key <> "." <> key, val}
        # Direct leaf (tensor): keep this level's key.
        val -> {sub_key, val}
      end
    end
  end

  def to_flat_map(params), do: [params]
end
Testing
require Axon

# Tiny two-layer network so the round trip is easy to eyeball.
model =
  Axon.input({nil, 2})
  |> Axon.dense(2, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

params = Axon.init(model, compiler: EXLA)
IO.inspect(params, label: "Initialization params")

# Capture the shape/name template first, then collapse to one tensor.
template = AxonParams.extract_template(params)
flat_params = AxonParams.flatten(params)
IO.inspect(flat_params, label: "Flattened params")
IO.inspect(template, label: "Params template")

# Round trip: the restored map should match the initial params.
restored = AxonParams.unflatten(flat_params, template)
IO.inspect(restored, label: "Unflattened params")
Output
Initialization params: %{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[2]
EXLA.Backend<host:0, 0.3675893082.1876557832.9714>
[0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[2][2]
EXLA.Backend<host:0, 0.3675893082.1876557832.9715>
[
[0.2250952422618866, -0.2300528585910797],
[0.8318504691123962, 1.00990629196167]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
EXLA.Backend<host:0, 0.3675893082.1876557832.9716>
[0.0]
>,
"kernel" => #Nx.Tensor<
f32[2][1]
EXLA.Backend<host:0, 0.3675893082.1876557832.9717>
[
[0.5544151663780212],
[1.0918326377868652]
]
>
}
}
Flattened params: #Nx.Tensor<
f32[9]
EXLA.Backend<host:0, 0.3675893082.1876557832.9722>
[0.0, 0.0, 0.2250952422618866, -0.2300528585910797, 0.8318504691123962, 1.00990629196167, 0.0, 0.5544151663780212, 1.0918326377868652]
>
Params template: %{
"dense_0.bias" => {{2}, [nil]},
"dense_0.kernel" => {{2, 2}, [nil, nil]},
"dense_1.bias" => {{1}, [nil]},
"dense_1.kernel" => {{2, 1}, [nil, nil]}
}
Unflattened params: %{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[2]
EXLA.Backend<host:0, 0.3675893082.1876557832.9723>
[0.0, 0.0]
>,
"kernel" => #Nx.Tensor<
f32[2][2]
EXLA.Backend<host:0, 0.3675893082.1876557832.9725>
[
[0.2250952422618866, -0.2300528585910797],
[0.8318504691123962, 1.00990629196167]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[1]
EXLA.Backend<host:0, 0.3675893082.1876557832.9726>
[0.0]
>,
"kernel" => #Nx.Tensor<
f32[2][1]
EXLA.Backend<host:0, 0.3675893082.1876557832.9728>
[
[0.5544151663780212],
[1.0918326377868652]
]
>
}
}