Does Nx have an equivalent for JAX tree_map?

Does Nx have an equivalent of JAX's tree_map (jax.tree_util.tree_map in the JAX documentation)?

Thank you!
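For readers coming from JAX: tree_map applies a function to every leaf of a nested structure and returns a structure of the same shape. Below is a minimal plain-Elixir sketch of that single-tree behavior. It assumes only maps and tuples count as nesting (structs and every other value are treated as leaves), and TreeMapSketch is a hypothetical module name, not part of Nx.

```elixir
defmodule TreeMapSketch do
  # Hypothetical helper (not part of Nx): applies `fun` to every leaf of a
  # nested structure of maps and tuples and rebuilds the same shape.
  # Structs (e.g. tensors) and all other values are treated as leaves.
  def tree_map(tree, fun) when is_map(tree) and not is_struct(tree) do
    Map.new(tree, fn {key, value} -> {key, tree_map(value, fun)} end)
  end

  def tree_map(tree, fun) when is_tuple(tree) do
    tree
    |> Tuple.to_list()
    |> Enum.map(&tree_map(&1, fun))
    |> List.to_tuple()
  end

  def tree_map(leaf, fun), do: fun.(leaf)
end

# Usage: multiply every leaf by 10.
TreeMapSketch.tree_map(%{"a" => 1, "b" => {2, 3}}, &(&1 * 10))
# => %{"a" => 10, "b" => {20, 30}}
```

JAX's version also accepts several trees at once, which is the harder part discussed in the rest of this thread.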


I couldn't find one, so I hacked together a quick version that fits my use case. Any feedback or alternatives are more than welcome, as I am pretty new to Elixir.

defmodule TreeMap do
  import Nx.Defn
  import Nx, only: [is_tensor: 1]

  defn fun({a, b, c}) do
    Nx.multiply(c, Nx.add(a, b))
  end

  def process_list_of_arbitrary_nested_maps(list_of_arbitrary_nested_maps) do
    list_of_arbitrary_nested_maps
    |> Enum.map(&Map.to_list/1)
    |> Enum.map(&Enum.sort/1)
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
    |> Enum.map(fn elem ->
      elem |> Enum.map_reduce("", fn x, _ -> {x |> elem(1), x |> elem(0)} end)
    end)
  end

  def traverse(list_of_arbitrary_nested_maps, fun) do
    condition =
      list_of_arbitrary_nested_maps
      |> Enum.map(&Map.values/1)
      |> Enum.map(fn x -> x |> Enum.all?(&is_tensor/1) end)
      |> Enum.all?()

    cond do
      condition ->
        list_of_arbitrary_nested_maps
        |> process_list_of_arbitrary_nested_maps()
        |> Enum.map(fn {l, k} -> %{"#{k}" => fun.(l |> List.to_tuple())} end)
        |> Enum.reduce(&Map.merge/2)

      true ->
        list_of_arbitrary_nested_maps
        |> process_list_of_arbitrary_nested_maps()
        |> Enum.map(fn {l, k} ->
          %{
            # Recurse with the function that was passed in, not the local fun/1
            "#{k}" => traverse(l, fun)
          }
        end)
        |> Enum.reduce(&Map.merge/2)
    end
  end
end

Maybe Nx.Defn.Composite.traverse?


Thank you. I checked polaris/lib/polaris/shared.ex in elixir-nx/polaris on GitHub (commit ad85df596966548b7c38a89e2032263d0a0b4527) and I can see how it is used with 2 arguments, but I'm not sure how to generalize it to more than 2 arguments :thinking:


What do you mean by more than 2 args?


Maybe you want to use Nx.Container.traverse together with Nx.Defn.Composite.flatten_list in a way that your accumulator yields the corresponding “zipped” values for each container.
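One way to read this accumulator suggestion, sketched in plain Elixir without Nx: the accumulator carries the flattened leaf list of every tree, and each leaf visit pops one value from each list, which yields the "zipped" values for that position. ZipTraverseSketch, flatten/1 and walk/3 are hypothetical names; in Nx the flattening would come from Nx.Defn.Composite.flatten_list and the walk from a container traversal. This sketch only treats maps and tuples as nesting and visits map keys in sorted order, so trees of the same shape produce their leaves in the same order.

```elixir
defmodule ZipTraverseSketch do
  # Hypothetical sketch (not Nx API): walk the first tree and, at each leaf,
  # pop the next leaf from every tree's flattened list, so `fun` receives a
  # tuple of corresponding leaves. Only maps and tuples count as nesting.

  # Flattens a tree into its leaves, visiting map keys in sorted order.
  def flatten(tree) when is_map(tree) and not is_struct(tree) do
    tree |> Enum.sort() |> Enum.flat_map(fn {_key, value} -> flatten(value) end)
  end

  def flatten(tree) when is_tuple(tree), do: tree |> Tuple.to_list() |> Enum.flat_map(&flatten/1)
  def flatten(leaf), do: [leaf]

  def tree_map([template | _] = trees, fun) do
    {result, leftovers} = walk(template, Enum.map(trees, &flatten/1), fun)
    # All trees must share the template's shape, so nothing should be left over.
    true = Enum.all?(leftovers, &(&1 == []))
    result
  end

  # The accumulator holds the not-yet-consumed leaves of every tree.
  defp walk(tree, leaf_lists, fun) when is_map(tree) and not is_struct(tree) do
    {pairs, leaf_lists} =
      tree
      |> Enum.sort()
      |> Enum.map_reduce(leaf_lists, fn {key, value}, lists ->
        {new_value, lists} = walk(value, lists, fun)
        {{key, new_value}, lists}
      end)

    {Map.new(pairs), leaf_lists}
  end

  defp walk(tree, leaf_lists, fun) when is_tuple(tree) do
    {items, leaf_lists} =
      tree |> Tuple.to_list() |> Enum.map_reduce(leaf_lists, &walk(&1, &2, fun))

    {List.to_tuple(items), leaf_lists}
  end

  defp walk(_leaf, leaf_lists, fun) do
    # The head of each tree's list holds this position's zipped values.
    heads = Enum.map(leaf_lists, &hd/1)
    tails = Enum.map(leaf_lists, &tl/1)
    {fun.(List.to_tuple(heads)), tails}
  end
end

# Usage with three trees and the same c * (a + b) function as in the thread.
t1 = %{"x" => 1, "y" => {2, 3}}
t2 = %{"x" => 10, "y" => {20, 30}}
t3 = %{"x" => 2, "y" => {2, 2}}
ZipTraverseSketch.tree_map([t1, t2, t3], fn {a, b, c} -> c * (a + b) end)
# => %{"x" => 22, "y" => {44, 66}}
```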


Yet another possibility is to flatten each container and zip them together.
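The flatten-and-zip idea can be sketched in plain Elixir: flatten every tree into its leaves, zip the leaf lists with Enum.zip_with/2, apply the function to each zipped tuple, then refill the first tree's leaves with the results in the same order. In an Nx version, Nx.Defn.Composite.flatten_list and Nx.Defn.Composite.traverse/3 would play the roles of flatten/1 and refill/2 here. The FlattenZipSketch module is hypothetical and only treats maps and tuples as nesting.

```elixir
defmodule FlattenZipSketch do
  # Hypothetical sketch (not Nx API): flatten every tree to its leaves, zip
  # the leaf lists, apply `fun` to each zipped tuple, then refill the first
  # tree with the results. Only maps and tuples count as nesting.

  # Flattens a tree into its leaves, visiting map keys in sorted order.
  def flatten(tree) when is_map(tree) and not is_struct(tree) do
    tree |> Enum.sort() |> Enum.flat_map(fn {_key, value} -> flatten(value) end)
  end

  def flatten(tree) when is_tuple(tree), do: tree |> Tuple.to_list() |> Enum.flat_map(&flatten/1)
  def flatten(leaf), do: [leaf]

  def tree_map([template | _] = trees, fun) do
    results =
      trees
      |> Enum.map(&flatten/1)
      # Enum.zip_with/2 passes a list of values; convert it to a tuple.
      |> Enum.zip_with(&fun.(List.to_tuple(&1)))

    # Every result must be consumed, otherwise the shapes did not match.
    {tree, []} = refill(template, results)
    tree
  end

  # Replaces each leaf of `tree`, in flatten order, with the next value.
  defp refill(tree, values) when is_map(tree) and not is_struct(tree) do
    {pairs, rest} =
      tree
      |> Enum.sort()
      |> Enum.map_reduce(values, fn {key, value}, acc ->
        {new_value, acc} = refill(value, acc)
        {{key, new_value}, acc}
      end)

    {Map.new(pairs), rest}
  end

  defp refill(tree, values) when is_tuple(tree) do
    {items, rest} = tree |> Tuple.to_list() |> Enum.map_reduce(values, &refill/2)
    {List.to_tuple(items), rest}
  end

  defp refill(_leaf, [value | rest]), do: {value, rest}
end

# Usage with three trees and the same c * (a + b) function as in the thread.
t1 = %{"x" => 1, "y" => {2, 3}}
t2 = %{"x" => 10, "y" => {20, 30}}
t3 = %{"x" => 2, "y" => {2, 2}}
FlattenZipSketch.tree_map([t1, t2, t3], fn {a, b, c} -> c * (a + b) end)
# => %{"x" => 22, "y" => {44, 66}}
```

Computing all results up front and then refilling keeps each step a single linear pass, instead of looking results up by index with Enum.at/2.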


I mean more than 2 Nx containers.


Thanks!

This immensely simplified my implementation :smiley:

defmodule TreeMap do
  import Nx.Defn

  defn fun({a, b, c}) do
    Nx.multiply(c, Nx.add(a, b))
  end

  def traverse(list_of_arbitrary_nested_maps, fun) when is_list(list_of_arbitrary_nested_maps) do
    list_of_arbitrary_nested_maps
    |> Enum.map(&List.wrap/1)
    |> Enum.map(&Nx.Defn.Composite.flatten_list/1)
    |> Enum.zip()
    |> Enum.map(fn t -> fun.(t) end)
    |> Kernel.then(fn v ->
      {v, _} =
        Nx.Defn.Composite.traverse(list_of_arbitrary_nested_maps |> Enum.at(0), 0, fn _, acc ->
          {v |> Enum.at(acc), acc + 1}
        end)

      v
    end)
  end
end

Here’s my suggestion:

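  def traverse([template_container | _] = list_of_arbitrary_nested_maps, fun) when is_list(list_of_arbitrary_nested_maps) do
    # Flatten each container into its leaf tensors, then call `fun` once per
    # position. Enum.zip_with/2 passes a list, so convert it to the tuple `fun` expects.
    zipped_containers =
      list_of_arbitrary_nested_maps
      |> Enum.map(&Nx.Defn.Composite.flatten_list([&1]))
      |> Enum.zip_with(&fun.(List.to_tuple(&1)))

    # Refill the first container's leaves, in order, with the computed results.
    {v, []} =
      Nx.Defn.Composite.traverse(template_container, zipped_containers, fn _, [h | t] ->
        {h, t}
      end)

    v
  end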