How to create a batched input from a DataFrame when model expects multiple inputs?

My model expects multiple named inputs, which means I need to pass the data as a map keyed by input name, like this:

inputs = %{
  "hours_per_week" => Nx.tensor([1, 1]),
  "capital_loss" => Nx.tensor([1, 1]),
  ...
}

predict_fn.(params, inputs)

If I have an Explorer.DataFrame I can build that map using DataFrame.to_columns(df) and converting each column to a tensor. The problem is that this is an expensive operation when the dataframe is big.
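For reference, here is a minimal sketch of that conversion (assuming Explorer.DataFrame is aliased as DF):

```elixir
# Turn each DataFrame column into a named tensor input.
# DF.to_columns/1 returns a map of column name => list of values.
inputs =
  df
  |> DF.to_columns()
  |> Map.new(fn {key, values} -> {key, Nx.tensor(values)} end)
```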

Another issue is that I want to batch the inputs to feed the training loop. Normally I would reach for Nx.to_batched to do the job, but that function works on a single tensor, not on a map of multiple inputs.
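For a single tensor the batching itself is straightforward:

```elixir
# Nx.to_batched/2 splits one tensor into a lazy stream of batches.
Nx.iota({6})
|> Nx.to_batched(2)
|> Enum.to_list()
# three tensors of shape {2}
```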

So, how can I create batched inputs for my trainer loop from my dataframes, lazily, so that I don't need to keep the full dataset in memory the whole time?

Here is what I was able to make work, but it generates all the batches in memory, which will give me an OOM depending on the dataset size:

create_batches = fn features, target, batches ->
  # For each feature column: build a tensor, split it into batches,
  # and tag every batch with its column name.
  features_batches =
    features
    |> DF.to_columns()
    |> Enum.map(fn {key, values} ->
      values
      |> Nx.tensor()
      |> Nx.to_batched(batches)
      |> Enum.to_list()
      |> Enum.map(fn tensor -> {key, tensor} end)
    end)
    # Zip across columns so each element holds one batch of every feature,
    # then turn the {name, tensor} pairs back into the input map.
    |> Enum.zip()
    |> Enum.map(fn batch -> batch |> Tuple.to_list() |> Map.new() end)

  # Batch the target column the same way, adding a trailing axis.
  target_batches =
    target
    |> DF.to_series(atom_keys: true)
    |> Map.fetch!(:income_bracket)
    |> Series.to_tensor()
    |> Nx.new_axis(-1)
    |> Nx.to_batched(batches)
    |> Enum.to_list()

  Enum.zip(features_batches, target_batches)
end
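One direction I've been considering (just a sketch, untested) is to push the slicing into a Stream, so only one batch is ever materialized at a time. It assumes DF.slice/3 and DF.n_rows/1 are available and reuses the income_bracket column and the DF/Series aliases from above:

```elixir
create_batch_stream = fn features, target, batch_size ->
  n = DF.n_rows(features)

  # Step through row offsets lazily; each batch is built on demand
  # when the training loop consumes the stream.
  0..(n - 1)//batch_size
  |> Stream.map(fn offset ->
    size = min(batch_size, n - offset)

    feature_batch =
      features
      |> DF.slice(offset, size)
      |> DF.to_columns()
      |> Map.new(fn {key, values} -> {key, Nx.tensor(values)} end)

    target_batch =
      target
      |> DF.slice(offset, size)
      |> DF.to_series(atom_keys: true)
      |> Map.fetch!(:income_bracket)
      |> Series.to_tensor()
      |> Nx.new_axis(-1)

    {feature_batch, target_batch}
  end)
end
```

Since Axon's training loop consumes any Enumerable, this stream could be passed directly where the eager list of batches went, but I'm not sure whether slicing the dataframe per batch is actually cheaper than one big to_columns call.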