Does Nx have an equivalent of JAX's `tree_map` (`jax.tree_util.tree_map` in the JAX documentation)?
Thank you!
Does Nx have an equivalent of JAX's `tree_map` (`jax.tree_util.tree_map` in the JAX documentation)?
Thank you!
I couldn't find one, so I hacked together a quick version that fits my use case. Any feedback or alternatives are more than welcome, as I am pretty new to Elixir.
defmodule TreeMap do
  @moduledoc """
  Maps a function over the leaves of several nested maps that share the same key
  structure, similar in spirit to JAX's `jax.tree_util.tree_map`.

  A leaf is any value accepted by the `Nx.is_tensor/1` guard. Every other value is
  treated as a nested map and traversed recursively.
  """
  import Nx.Defn
  import Nx, only: [is_tensor: 1]

  @doc """
  Example leaf function: given a tuple of three tensors `{a, b, c}`, returns `c * (a + b)`.
  """
  defn fun({a, b, c}) do
    Nx.multiply(c, Nx.add(a, b))
  end

  @doc """
  Groups the values of several maps by key.

  Returns a list of `{values, key}` pairs ordered by key. `values` holds the value stored
  under `key` in each input map, in input order. Returns `[]` for an empty list.

  Raises `ArgumentError` if the maps do not all have the same keys.
  """
  def process_list_of_arbitrary_nested_maps([]), do: []

  def process_list_of_arbitrary_nested_maps([first | rest] = list_of_arbitrary_nested_maps) do
    keys = first |> Map.keys() |> Enum.sort()

    # Zipping independently sorted key/value lists would silently pair up unrelated
    # entries when the key sets differ, so compare the key sets explicitly.
    Enum.each(rest, fn map ->
      other_keys = map |> Map.keys() |> Enum.sort()

      if other_keys != keys do
        raise ArgumentError,
              "all maps must have the same keys, got #{inspect(keys)} and #{inspect(other_keys)}"
      end
    end)

    Enum.map(keys, fn key ->
      {Enum.map(list_of_arbitrary_nested_maps, &Map.fetch!(&1, key)), key}
    end)
  end

  @doc """
  Applies `fun` leaf-wise across `list_of_arbitrary_nested_maps`.

  For each key, the values from every map are collected:

    * if they are all leaves, `fun` is called once with them as a tuple (in input order);
    * if none of them are leaves, they are traversed recursively.

  The result is a map with the same (possibly nested) keys as the inputs. Keys are kept
  as-is, not converted to strings.

  Raises `ArgumentError` when the maps disagree on keys, or when the values under one
  key mix leaves and nested maps.
  """
  def traverse(list_of_arbitrary_nested_maps, fun) do
    list_of_arbitrary_nested_maps
    |> process_list_of_arbitrary_nested_maps()
    |> Map.new(fn {values, key} -> {key, traverse_values(values, key, fun)} end)
  end

  # Applies `fun` to a group of leaves, or recurses into a group of nested maps.
  defp traverse_values(values, key, fun) do
    cond do
      Enum.all?(values, &leaf?/1) ->
        fun.(List.to_tuple(values))

      Enum.any?(values, &leaf?/1) ->
        raise ArgumentError, "values under key #{inspect(key)} mix tensors and nested maps"

      true ->
        # Pass the caller's `fun` along; `&fun/1` would capture this module's `fun/1`.
        traverse(values, fun)
    end
  end

  # `is_tensor/1` is a guard macro and cannot be captured with `&`, so wrap it.
  defp leaf?(value) when is_tensor(value), do: true
  defp leaf?(_value), do: false
end
Maybe Nx.Defn.Composite.traverse?
Thank you. I checked `polaris/lib/polaris/shared.ex` (elixir-nx/polaris on GitHub, commit ad85df596966548b7c38a89e2032263d0a0b4527). I can see that it works for two arguments, but I'm not sure how to generalize it to more than two.
What do you mean by more than two args?
Maybe you want to use Nx.Container.traverse together with Nx.Defn.Composite.flatten_list in a way that your accumulator yields the corresponding “zipped” values for each container.
Yet another possibility is to flatten each container and zip them together.
I mean more than two `Nx.Container`s.
Thanks!
This simplified my implementation immensely:
defmodule TreeMap do
  @moduledoc """
  Maps a function over the leaves of several Nx containers with the same structure,
  similar in spirit to JAX's `jax.tree_util.tree_map`.
  """
  import Nx.Defn

  @doc """
  Example leaf function: given a tuple of three tensors `{a, b, c}`, returns `c * (a + b)`.
  """
  defn fun({a, b, c}) do
    Nx.multiply(c, Nx.add(a, b))
  end

  @doc """
  Applies `fun` leaf-wise across `list_of_arbitrary_nested_maps`.

  Each container is flattened with `Nx.Defn.Composite.flatten_list/1`. The i-th leaves
  of all containers are passed to `fun` together as one tuple, in input order. The
  results are then written back into the structure of the first container.

  Raises if the containers do not have the same number of leaves.
  """
  def traverse([template | _] = list_of_arbitrary_nested_maps, fun) do
    results =
      list_of_arbitrary_nested_maps
      |> Enum.map(&Nx.Defn.Composite.flatten_list([&1]))
      |> Enum.zip()
      |> Enum.map(fun)

    # Consume results from the head of the list. Indexing with `Enum.at/2` and a
    # counter is O(n) per leaf, i.e. quadratic overall. Matching `[]` asserts that
    # every result was placed back into the template.
    {mapped, []} =
      Nx.Defn.Composite.traverse(template, results, fn _leaf, [result | remaining] ->
        {result, remaining}
      end)

    mapped
  end
end
Here’s my suggestion:
@doc """
Applies `fun` leaf-wise across `list_of_arbitrary_nested_maps`.

The i-th leaves of every container (flattened with `Nx.Defn.Composite.flatten_list/1`)
are passed to `fun` together as one tuple, in input order. The results are rebuilt in
the structure of the first container.

Raises if the containers do not have the same number of leaves.
"""
def traverse([template_container | _] = list_of_arbitrary_nested_maps, fun)
    when is_list(list_of_arbitrary_nested_maps) do
  # `Enum.zip_with/2` hands each group of leaves over as a list. Nx containers include
  # tuples but not lists, so convert to a tuple to stay compatible with `defn` callbacks.
  zipped_results =
    list_of_arbitrary_nested_maps
    |> Enum.map(&Nx.Defn.Composite.flatten_list([&1]))
    |> Enum.zip_with(&fun.(List.to_tuple(&1)))

  # Pop one result per leaf. The `[]` match asserts that every result was consumed.
  {mapped, []} =
    Nx.Defn.Composite.traverse(template_container, zipped_results, fn _leaf, [result | rest] ->
      {result, rest}
    end)

  mapped
end