Nx vs. Python performance for sentence-transformer encoding

I’ve been thinking about the following: couldn’t we also allow a dynamic sequence length and just-in-time compile the first time we get an input with a specific sequence length? Subsequent requests with that length would then reuse the already compiled version. Since the set of possible sequence lengths is finite, one could either precompile every sequence length or “warm up” the serving.

We can certainly do that, but it means you may compile the program several times. Still, it is something I will consider while exploring these ideas. :slight_smile:


Having multiple variants sounds great! Both 1. and 2. make certain tradeoffs, and which is better depends on the length distribution. If longer inputs are rare, then with 2. the batch for long inputs will often hit the batch timeout and be padded with empty entries, when we could have filled those slots with shorter inputs instead. Note, however, that we pad on the client as part of tokenization, which affects all of the input tensors (input ids, attention mask), so padding shorter inputs up to a higher length would mean padding on the server. With 2. we always know what length to pad to.
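To put a rough number on why the length distribution matters, here is a minimal sketch (pure Elixir, hypothetical module name `PaddingOverhead`) that computes what fraction of the tokens the model processes are padding when every input is padded to one fixed length. The token counts in the usage line are made-up illustrative numbers, not measurements.

```elixir
defmodule PaddingOverhead do
  # Hypothetical helper: given the real token count of each input and a fixed
  # sequence length, return the fraction of processed tokens that are padding.
  # Inputs longer than `max_len` are truncated, so they contribute no padding.
  def fraction([], _max_len), do: 0.0

  def fraction(token_counts, max_len) when max_len > 0 do
    real = token_counts |> Enum.map(&min(&1, max_len)) |> Enum.sum()
    total = max_len * length(token_counts)
    (total - real) / total
  end
end

# Illustrative numbers only: mostly short inputs padded to 128 tokens.
IO.inspect(PaddingOverhead.fraction([16, 16, 32, 64], 128))
# prints 0.75
```

With mostly short inputs, three quarters of the work in this example goes into padding, which is why batching or compiling by length can pay off when long inputs are rare.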


Hey,

I found this thread very helpful, but I would like to ask some additional questions / clarifications!

I am trying to speed up text embedding creation.
I have a data migration where we want to back-fill embeddings for existing entries. I tried various batch sizes for this and settled on 500.
I tried creating the changesets with the embeddings using both Enum.map and Task.async_stream, but they take the same time. I realized that almost all of the time is spent generating the embeddings.
I also tried to create multiple servings, both manually and with nimble_pool, but the results were the same. This leads me to believe that even though I had multiple servings the embedding creation is still sequential.
Then I found this thread and went back to a single serving and tried to tweak it.
All these finish the embedding changeset creation step in 55-57 seconds.

My vector size is 384.
I am caching the serving, is that a bad idea?

I tried the following setups:

1. In my config I have: config :nx, default_backend: EXLA.Backend
and I create the serving by just calling Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer) as is. I used this setup in all the scenarios above.

2. I tried the example from here:

    Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
      compile: [batch_size: 32, sequence_length: 8],
      defn_options: [compiler: EXLA]
    )

This is slower, it takes 76 seconds.

How do I determine the batch size and the sequence_length?

Thank you

> This leads me to believe that even though I had multiple servings the embedding creation is still sequential.

Only one computation can be running at a time on the given device, so spawning multiple servings is not going to help (unless you have a cluster of multiple nodes, each with its own serving).

Just to be sure, you are using Nx.Serving.batched_run/2 and not Nx.Serving.run/2 right?

> This is slower, it takes 76 seconds.

Just the computation or everything including application boot? Note that with compiler: EXLA the serving is going to compile everything into an efficient computation, which may take a bit, but then the computations themselves are faster.

> How do I determine the batch size and the sequence_length?

sequence_length is used to pad/trim the input. Usually you want to set it to the longest sequence the model supports, unless you know your inputs are always short and it can be reduced. batch_size depends on several factors. When running on a CPU, increasing the batch_size beyond some point makes the computation time grow linearly, since there is no more room for parallelisation. When running on a GPU, it depends on how much computation you can fit into the GPU memory. In both cases it also depends on the expected number of requests (if it’s small, it doesn’t make sense to use a large batch_size).

Hey,

Thanks for the reply!

I am using Nx.Serving.run/2 now. I will try out Nx.Serving.batched_run/2.
These are the docs: Nx.Serving — Nx v0.5.3
I imagine I can just use Bumblebee.Text.TextEmbedding.text_embedding/3 in place of Nx.Serving.new(Nx.Defn.jit(&print_and_multiply/1))?

So sequence_length is the length of my entry? What is the default?
What is the upside of having a high batch_size on a CPU? What range would you recommend trying?

If you are just running it once for back-filling, then Nx.Serving.run/2 may be fine, but it runs right away, so you want to chunk the list and pass a list of inputs to Nx.Serving.run/2 instead of using Task.async_stream (in which case all runs are sequential anyway). Starting a serving under your app’s supervision tree and using batched_run is mainly for long-running systems that handle concurrent requests.
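A minimal sketch of that chunked back-fill. The module and function names are hypothetical, and `embed_fun` stands in for something like `fn batch -> Nx.Serving.run(serving, batch) end`, assuming (as with Bumblebee’s text embedding serving) that a list of texts yields a list of results in the same order.

```elixir
defmodule Backfill do
  # Hypothetical helper: split `texts` into chunks of `batch_size` and embed
  # each chunk with a single call, instead of one call per text.
  #
  # `embed_fun` receives a list of texts and must return a list of results in
  # the same order, e.g. (assumption) fn batch -> Nx.Serving.run(serving, batch) end
  def embed_all(texts, batch_size, embed_fun) when batch_size > 0 do
    texts
    |> Enum.chunk_every(batch_size)
    |> Enum.flat_map(embed_fun)
  end
end

# Stand-in embedding function so the sketch runs without a model.
fake_embed = fn batch -> Enum.map(batch, &String.length/1) end
IO.inspect(Backfill.embed_all(["a", "bb", "ccc"], 2, fake_embed))
# prints [1, 2, 3]
```

Each chunk becomes one model invocation, so the chunk size plays the role the batch size of 500 played in the migration above.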

> So sequence_length is the length of my entry?

The input text is tokenized into a sequence of numbers (roughly one per word, sometimes several per word). The model needs a fixed-length sequence, so if the text is short we usually pad the sequence with zeros, and if it’s too long it is truncated. The reason for the fixed length is basically so that we can compile the model upfront, once, with a known input shape and run inference quickly. If you don’t set :compile, then :sequence_length will effectively be the length of the input each time, but this means that the model is compiled multiple times (once per distinct input length).
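To illustrate the pad/truncate step on plain lists, here is a sketch with a hypothetical module; Bumblebee does this internally on tensors, and the token ids and the pad id `0` below are made up. Real tokenizers usually keep special tokens (such as the end-of-sequence marker) when truncating, whereas this sketch simply cuts the list.

```elixir
defmodule FixedLength do
  # Hypothetical helper: truncate or pad `ids` to exactly `target_len` entries
  # and build the matching attention mask (1 = real token, 0 = padding).
  def pad_or_truncate(ids, target_len, pad_id \\ 0) do
    kept = Enum.take(ids, target_len)
    pad_count = target_len - length(kept)

    input_ids = kept ++ List.duplicate(pad_id, pad_count)
    attention_mask = List.duplicate(1, length(kept)) ++ List.duplicate(0, pad_count)
    {input_ids, attention_mask}
  end
end

# A short input gets padded, a long one gets truncated.
IO.inspect(FixedLength.pad_or_truncate([101, 7592, 102], 5))
# prints {[101, 7592, 102, 0, 0], [1, 1, 1, 0, 0]}
IO.inspect(FixedLength.pad_or_truncate([101, 1, 2, 3, 4, 102], 4))
# prints {[101, 1, 2, 3], [1, 1, 1, 1]}
```

The attention mask is what lets later steps (such as mean pooling) ignore the padded positions.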

> What is the upside of having a high batch_size on a CPU?

Even though a CPU doesn’t offer as much parallelism as a GPU, the XLA compiler can still do some parallelisation, so it could be the case that a computation with batch_size: 4 is just a bit slower than with batch_size: 1 (and not 4x slower). It doesn’t hurt to have a bit larger :batch_size. It depends on the model, but I would try values like 8 or 16.
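One way to pick a batch size empirically is to time a few candidates on your own data. A rough sketch with a hypothetical module: `build_run_fun` is assumed to return a function that runs one batch for a given batch size, for example by building `Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer, compile: [batch_size: bs, sequence_length: 128], defn_options: [compiler: EXLA])` and calling `Nx.Serving.run/2` on it. A warm-up call is made first so compilation time is not counted.

```elixir
defmodule BatchSweep do
  # Measures throughput (inputs per second) for each candidate batch size.
  #
  # `build_run_fun` is a hypothetical 1-arity function: given a batch size, it
  # returns a function that processes one batch (a list of inputs).
  def sweep(inputs, batch_sizes, build_run_fun) do
    for batch_size <- batch_sizes, into: %{} do
      run_fun = build_run_fun.(batch_size)

      # Warm-up call, so one-off compilation is not included in the timing.
      run_fun.(Enum.take(inputs, batch_size))

      {micros, :ok} =
        :timer.tc(fn ->
          inputs
          |> Enum.chunk_every(batch_size)
          |> Enum.each(run_fun)
        end)

      {batch_size, length(inputs) * 1_000_000 / max(micros, 1)}
    end
  end
end

# Stand-in workload so the sketch runs without a model.
fake_build = fn _batch_size -> fn batch -> Enum.sum(batch) end end
IO.inspect(BatchSweep.sweep(Enum.to_list(1..1_000), [1, 8, 16], fake_build))
```

On a CPU you would expect throughput to rise with batch size and then flatten once there is no more room for parallelisation; the point where it flattens is a reasonable choice.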


I’m using Bumblebee to compute the embeddings of user messages (with the HF model sentence-transformers/paraphrase-multilingual-mpnet-base-v2), with the goal of semantically caching the responses of an LLM and saving CPU time (and money). Bumblebee needs on average 50% more time than the equivalent Python code wrapped in a FastAPI server.

I went through this thread, but it’s not clear to me what I should do to improve Bumblebee’s performance. I’m compiling the model (Axon.compile(model_info.model, template, %{}, compiler: EXLA)), which improves performance relative to JIT compilation, but it’s still slow compared to Python.

Thanks!

It is hard to say without more information on how you are running the model. Per above, the results will vary depending if you are batching or not and if you are padding or not. Can you provide a snippet with more information on how you are starting the serving and calling it?

To arrive at a number for sequence_length, how can I get the exact sequence lengths of my string entries?

Sure! Here is some more info:

I init the transformer in this way:

```elixir
template = %{
  "attention_mask" => Nx.template({1, 128}, :u32),
  "input_ids" => Nx.template({1, 128}, :u32),
  "token_type_ids" => Nx.template({1, 128}, :u32)
}

model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
{:ok, model_info} = Bumblebee.load_model({:hf, model_name})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
{_init_fn, predict_fn} = Axon.compile(model_info.model, template, %{}, compiler: EXLA)
```

predict_fn and the tokenizer are then stored in a GenServer’s state to be re-used.

To compute the embeddings, I do:

```elixir
inputs = Bumblebee.apply_tokenizer(tokenizer, text, length: 128)
result = predict_fn.(model_info.params, inputs)
input_mask_expanded = Nx.new_axis(inputs["attention_mask"], -1)
result.hidden_state
  |> Nx.multiply(input_mask_expanded)
  |> Nx.sum(axes: [1])
  |> Nx.divide(Nx.sum(input_mask_expanded, axes: [1]))
```
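For reference, that pipeline is masked mean pooling: with hidden states $h_t$ for token positions $t = 1, \dots, L$ and attention mask values $m_t \in \{0, 1\}$, the sentence embedding is the average of the hidden states over the real (non-padding) tokens:

$$
e = \frac{\sum_{t=1}^{L} m_t \, h_t}{\sum_{t=1}^{L} m_t}
$$

Padding positions have $m_t = 0$, so they contribute to neither the numerator nor the denominator; in this snippet $L = 128$.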

I hope that helps, thank you!

Hi @aus, that looks good to me. Just make sure that input_mask_expanded and result.hidden_state are allocated on the EXLA Backend and not the binary backend. You can print them to the terminal to confirm.

At the end, this may still be slower than the Python version for two reasons:

  1. EXLA for CPU is not as fast as it should be

  2. We don’t support dynamic shapes, which means you need to pad to 128. You could try passing longer inputs (close to 128 tokens) to both and check whether they perform the same at that length

I believe you can call the tokenizer without a length and check the size of the tensor it returns.
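A sketch of that approach for picking a sequence_length: count the real tokens in each attention mask and look at a high percentile. How you obtain the masks is an assumption here, e.g. `Bumblebee.apply_tokenizer(tokenizer, texts)["attention_mask"] |> Nx.to_list()` without a `:length` option; the helper module and the sample masks below are hypothetical.

```elixir
defmodule SeqLengths do
  # Number of real (non-padding) tokens in each attention mask row.
  def from_masks(masks), do: Enum.map(masks, &Enum.sum/1)

  # Nearest-rank percentile: the smallest length covering at least a `p`
  # fraction of the inputs (0 < p <= 1).
  def percentile(lengths, p) when lengths != [] and p > 0 and p <= 1 do
    sorted = Enum.sort(lengths)
    index = max(ceil(p * length(sorted)) - 1, 0)
    Enum.at(sorted, index)
  end
end

# Hypothetical masks for three texts padded to the longest one (6 tokens).
masks = [
  [1, 1, 1, 0, 0, 0],
  [1, 1, 1, 1, 1, 1],
  [1, 1, 1, 1, 0, 0]
]

lengths = SeqLengths.from_masks(masks)
IO.inspect(lengths)
# prints [3, 6, 4]
IO.inspect(SeqLengths.percentile(lengths, 0.5))
# prints 4
```

You would then round the chosen percentile up to a convenient value, staying within the model’s maximum sequence length.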


Btw, folks can also try running ONNX versions of the models if the lack of dynamic shapes is being a hindrance.


I have implemented batch_keys in Nx (soon to be v0.6): Support batch keys in Nx.Serving by josevalim · Pull Request #1268 · elixir-nx/nx · GitHub

We will later ship a new bumblebee version that supports multiple sequence lengths out of the box. :slight_smile:


I experimented with running a model through Ortex today, but I don’t see how this improves the situation with different sequence lengths. As soon as I send inputs with different sizes to my serving, I get an error that I cannot merge batches due to incompatible templates, which makes sense. Sending the full tokenizer output to the Ortex model does not offer a performance gain compared to running EXLA with the full sequence length.

Running a separate serving for each sequence length, I see similar performance for the all-MiniLM-L6-v2 model on EXLA (see Nx vs. Python performance for sentence-transformer encoding - #18 by steffend) and Ortex:

For small inputs, Ortex seems to perform a little better, for larger inputs a little worse.

Here is the Livebook I used:

Notice the part that uses the smallest input sequence length (by looking at the attention mask) in the client_preprocessing:

  def serving(model, tokenizer) do
    Nx.Serving.new(Ortex.Serving, model)
    |> Nx.Serving.client_preprocessing(fn inputs ->
      {:ok, encodings} = Tokenizers.Tokenizer.encode_batch(tokenizer, inputs)

      # get the maximum sequence length from the input by looking at the attention mask
      max_length =
        encodings
        |> Enum.map(&Tokenizers.Encoding.get_attention_mask/1)
        |> Enum.map(fn tensor -> Enum.sum(tensor) end)
        |> Enum.max(fn -> nil end)

      encodings =
        if max_length do
          for e <- encodings, do: Tokenizers.Encoding.truncate(e, max_length)
        else
          encodings
        end

      input_ids = for i <- encodings, do: Tokenizers.Encoding.get_ids(i)
      input_mask = for i <- encodings, do: Tokenizers.Encoding.get_attention_mask(i)
      token_type_ids = for i <- encodings, do: Tokenizers.Encoding.get_type_ids(i)

      inputs =
        Enum.zip_with([input_ids, input_mask, token_type_ids], fn [a, b, c] ->
          {Nx.tensor(a), Nx.tensor(b), Nx.tensor(c)}
        end)
        |> Nx.Batch.stack()

      {inputs, %{attention_mask: Nx.tensor(input_mask)}}
    end)
    |> Nx.Serving.client_postprocessing(fn {{output}, _meta}, client_info ->
      mean_pooling(output, client_info.attention_mask)
    end)
  end

I have to say that besides that Ortex works flawlessly and I’m able to run models that Bumblebee does not support yet (e.g. all-mpnet-base-v2): Running the all-mpnet-base-v2 sentence transformer in Elixir using Ortex · GitHub


Could you benchmark Ortex without the serving? That is, calling it as each input arrives (which is what PyTorch would do)? Otherwise, you are right, the serving will still be the limitation (unless you want to give the main branch a try and provide multiple functions for different sequence lengths).

And thanks for sharing the Ortex notebook!


Ah, that makes sense. Ortex embeds/second by sequence length:

This also achieves full CPU utilization on my MacBook in comparison to EXLA :smiley:

So it seems like Ortex does not benefit that much from using Nx.Serving, as the model does not need to be precompiled to certain input shapes, in contrast to when using EXLA, is that right?


I am not sure if we should generalize that to Ortex but we can conclude that’s true for Ortex+sberts running on CPU. I assume GPUs will be happier with batching than CPUs, even if not compiled. :slight_smile:


@steffend, on Bumblebee main you can specify multiple sequence lengths, in which case we compile multiple versions of the computation and inputs are batched depending on their length. This way short sequences don’t get overly long padding. Here’s an example:

# Text embedding with multiple lengths

```elixir
Mix.install([
  {:bumblebee, github: "elixir-nx/bumblebee"},
  {:rustler, ">= 0.0.0", optional: true},
  {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
  {:exla, github: "elixir-nx/nx", sparse: "exla", override: true},
  {:kino, "~> 0.10.0"}
])

Nx.global_default_backend(EXLA.Backend)
```

## 🐈‍⬛

```elixir
repo = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, model_info} = Bumblebee.load_model({:hf, repo})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, repo})

serving =
  Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
    compile: [batch_size: 32, sequence_length: [16, 32, 64, 128, 512]],
    defn_options: [compiler: EXLA]
  )

Kino.start_child({Nx.Serving, serving: serving, name: MyServing})
```

```elixir
short_text = "this is a test"
Nx.Serving.batched_run(MyServing, short_text)
```

```elixir
long_text = String.duplicate("this is a test with a much longer text", 50)
Nx.Serving.batched_run(MyServing, long_text)
```

The first input falls under a shorter sequence length, meaning we use less padding and the computation is faster. The second input falls under the largest length, so we pad to 512 and the computation takes longer.
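The routing idea can be sketched in plain Elixir: each input goes to the smallest configured length that fits its token count, and anything longer than the largest length falls into the largest one (and gets truncated). This illustrates the concept with a hypothetical module; it is not Bumblebee’s actual implementation.

```elixir
defmodule LengthBuckets do
  # Smallest configured sequence length that fits `token_count`; inputs longer
  # than every configured length fall into the largest one (and get truncated).
  def pick(token_count, lengths) do
    sorted = Enum.sort(lengths)
    Enum.find(sorted, List.last(sorted), &(&1 >= token_count))
  end

  # Groups token counts by the length they would be padded to, i.e. which
  # compiled variant of the computation would run them.
  def group(token_counts, lengths) do
    Enum.group_by(token_counts, &pick(&1, lengths))
  end
end

lengths = [16, 32, 64, 128, 512]
IO.inspect(LengthBuckets.group([5, 100, 10, 700], lengths))
# prints %{16 => [5, 10], 128 => [100], 512 => [700]}
```

Inputs that share a bucket can be batched together, so a rare long input no longer forces every short input in the batch to be padded to 512.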

Thank you very much, this is awesome. Indeed, a quick benchmark shows a staircase-like pattern at the configured sequence lengths:

So to summarize this topic: Elixir + Nx perform quite well for generating sentence embeddings. It is very important to make sure that the model is compiled for the right sequence lengths though, as the input is padded.
For varying input lengths, the newest Bumblebee code on GitHub now supports specifying multiple sequence lengths, as seen in post 36.

Another option is using Ortex and an ONNX model. Here is an example livebook that uses this approach: Running the all-mpnet-base-v2 sentence transformer in Elixir using Ortex · GitHub. In that case, using a serving does not improve the performance that much, at least when running on CPU. Therefore one can also just call Ortex.run directly.

Thanks for all the replies and insights. I’m happy to see that this led to some improvements in Bumblebee and Nx!

I’m also happy to report that we’re currently working on moving our production setup from Python to Elixir+Nx at my job :smiley:
