Memory explosion when accessing a large map while using Nx+EXLA

Hey there,

I’m currently benchmarking some sentence transformers (see “Nx vs. Python performance for sentence-transformer encoding”) and stumbled upon something weird:

Basically I’m encoding a sentence into a vector using Bumblebee (Nx+EXLA) and then calculating the cosine similarity against a list of ~150k pre-calculated vectors. I’m benchmarking the whole process to see what performance I can achieve:

[%{embedding: query}] = Nx.Serving.batched_run(SentenceTransformer, ["this is a test input"])

sim =
  for chunk <- vectors do
    Bumblebee.Utils.Nx.cosine_similarity(query, chunk)
  end
  |> Nx.concatenate()

{similarity, labels} = Nx.top_k(sim, k: 10)

indexes = Nx.to_flat_list(labels)
scores = Nx.to_flat_list(similarity)

for {idx, score} <- Enum.zip(indexes, scores) do
  # accessing the sentence_map here leads to the issue
  %{sentence: sentence_map[idx], score: score}
end

This works fine until I try to map the resulting top-k indexes back to the sentences using a 150k-key map of index → sentence entries. When I do this, EXLA suddenly starts to consume a LOT of memory.

I think the memory usage is inside EXLA, as it is not visible in Observer.

Running on CUDA seems to confirm this; there it even runs out of memory completely.

This is my script to reproduce:

# System.put_env("XLA_TARGET", "cuda118")

Mix.install([
  {:bumblebee, github: "elixir-nx/bumblebee", ref: "23de64b1b88ed3aad266025c207f255312b80ba6"},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:axon, "~> 0.5.1"},
  {:kino, "~> 0.9"}
])

Nx.global_default_backend(EXLA.Backend)
# Nx.Defn.global_default_options(compiler: EXLA, client: :cuda)
Nx.Defn.global_default_options(compiler: EXLA, client: :host)

model_name = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, model_info} = Bumblebee.load_model({:hf, model_name})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 64, sequence_length: 128],
    defn_options: [compiler: EXLA],
    output_attribute: :hidden_state,
    output_pool: :mean_pooling
  )

Kino.start_child({Nx.Serving, serving: serving, name: SentenceTransformer, batch_size: 64, batch_timeout: 50})

defmodule ConcurrentBench do
  def run(fun, concurrency \\ System.schedulers_online(), timeout \\ 10_000) do
    # use an erlang counter to count the number of function invocations
    counter = :counters.new(1, [:write_concurrency])

    # returns time in microseconds
    {taken, _} =
      :timer.tc(fn ->
        tasks =
          for _i <- 1..concurrency do
            Task.async(fn ->
              Stream.repeatedly(fn ->
                fun.()
                # only count after the function ran successfully
                :counters.add(counter, 1, 1)
              end)
              |> Stream.run()
            end)
          end

        results = Task.yield_many(tasks, timeout)

        # shut down any tasks that are still running after the timeout
        Enum.map(results, fn {task, res} ->
          res || Task.shutdown(task, :brutal_kill)
        end)
      end)

    runs = :counters.get(counter, 1)
    ips = runs / (taken / 1_000_000)

    %{runs: runs, ips: ips}
  end
end

n = 150_000

sentence_map = Map.new(1..n, fn i -> {i, i} end)
[%{embedding: vector}] = Nx.Serving.batched_run(SentenceTransformer, ["This is a test sentence"])

IO.puts("encoded sample input")

vectors =
  Stream.duplicate(vector, n)
  |> Stream.chunk_every(10000)
  |> Stream.map(fn chunk -> Nx.stack(chunk) end)
  |> Enum.to_list()

IO.puts("created dummy comparison vectors")
IO.puts("Running concurrent bench now")

ConcurrentBench.run(
  fn ->
    [%{embedding: query}] = Nx.Serving.batched_run(SentenceTransformer, ["this is a test input"])

    sim =
      for chunk <- vectors do
        Bumblebee.Utils.Nx.cosine_similarity(query, chunk)
      end
      |> Nx.concatenate()

    {similarity, labels} = Nx.top_k(sim, k: 10)

    indexes = Nx.to_flat_list(labels)
    scores = Nx.to_flat_list(similarity)

    # uncomment this
    # for {idx, score} <- Enum.zip(indexes, scores) do
    #   %{sentence: sentence_map[idx], score: score}
    # end

    nil
  end,
  16, 60_000
) |> IO.inspect()

See the “uncomment this” section.

I guess this is a bug? At least accessing the map should not lead to EXLA allocating memory, right?

Thanks for any suggestions!

Access will end up allocating memory in this case because you’re outside of defn-world, so EXLA can’t optimize things out properly. This means you’ll rely on just Nx, which, to maintain a functional interface, will end up doing a copy on read.

The access is happening on a map. Nothing in the commented-out code looks slow to me, unless labels and similarity are really large, but I don’t think that is the case. They are only 10 elements, right? Could we be looking at the wrong culprit?

Yes, these are only 10 results from top_k.

Last week we looked into this a bit, and there is probably something in the benchmarking infrastructure itself that provokes data copying between processes. We think the actual difference the commented-out for makes is that it keeps the processes alive longer, thereby changing how garbage collection can work.

The initial hypothesis was that the loop would make the processes take longer to finish, therefore holding onto the memory copied into them, but I didn’t measure this. Instead, it behaved the same as a long Process.sleep: memory usage increased more slowly, but reached the same level as without the for/Process.sleep call.

I tried a few things, like storing the vectors in :persistent_term and reading them from there, but without success.

@josevalim perhaps there’s a chance each process is getting a JIT cache miss and triggering its own compilation?


We lock the cache key to avoid that, so it shouldn’t be the case, unless there is a bug.
