Nx vs. Python performance for sentence-transformer encoding

Hey there,

I’ve got a project where I need to encode sentences using a sentence-transformer model. Currently, I’m using Python and the sentence-transformers package, but as the rest of the project is in Elixir, I’d like to switch to Nx instead.

Using Bumblebee and Axon, I already built a small proof of concept, and with the recent addition of a text embedding serving to Bumblebee, I wanted to run a quick benchmark to see how many encodes per second I can achieve on my CPU.

tl;dr: with a simple Python script I can achieve ~115 encodes per second with ~350% CPU load (-> ~4 cores) on my MacBook Pro (M1 Max), and ~190 encodes per second when starting two separate Python processes (nearly full CPU utilization). Using Elixir and Nx I can only achieve ~55 encodes per second, with more than double the average latency. Elixir also only reaches ~300% CPU usage. By starting multiple BEAM instances I can get to ~95 encodes per second with full CPU utilization.

The last point is the main one I’m interested in: there seems to be some kind of bottleneck that prevents me from achieving performance similar to Python using only a single BEAM process. Does anyone have an idea why that’s the case? (It’s very possible that I’m just doing something wrong!) I expected the BEAM to be able to use all cores for encoding.

Apart from that, even at full CPU utilization, Nx only achieves about half of Python’s encode performance, so there seem to be other factors at play too.

I’ve documented this and the code snippets here: GitHub - SteffenDE/nx-sentence-transformer-bench

One thing I noticed is that it does not look like you are setting the compiler for your Nx serving, so you are losing a lot of optimizations there. Try setting defn_options: [compiler: EXLA] when creating the serving.

I’m also not sure what the batch size you’re setting is. You can fiddle with higher and lower batch sizes to see if it improves latency.

Servings also have some built-in latency. I’m not familiar with how the benchmark works, but you can fiddle with the batch timeout settings to achieve better latency as well.

Finally, if the server sends sequences of different lengths, you eat a compilation cost with every request. You should set a static sequence length.

Thank you for the suggestion!

I think Nx.global_default_backend(EXLA.Backend) might already do this? At least I don’t measure any real difference when setting this on my serving. In general, I think that the serving is not the limiting factor. I added a script that does not use Nx.Serving at all and basically just calls Axon.predict, and the performance is very similar (nx_axon.exs).

I also tried with different batch sizes and batch timeouts, but again without any measurable differences.

Concerning sequence lengths: that shouldn’t be an issue here as the benchmark is always encoding the same sentence, but good to know!

My main question is whether there might be some bottleneck involving EXLA and the dirty NIF schedulers.

Here you can see the scheduler usage while running the benchmark 3 times for 10 seconds. It looks like only one dirty CPU scheduler is used at a time, although which one changes. I’m no expert on NIFs at all, so maybe that’s common knowledge, but if Nx can only use one dirty scheduler at a time, this might become a bottleneck in other cases as well? Only speculation on my side, though.

Hey @steffend, there is a bit of a difference between backend and compiler. You can read a bit about it here (it may be somewhat outdated): Nx Tip of the Week #6 – Compiler or Backend? – Sean Moriarity

First I ran your benchmarks and got:

Running 1m test @ http://127.0.0.1:5001
  8 threads and 32 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency   550.10ms   30.31ms 851.38ms   99.08%
    Req/Sec     6.76      2.49    30.00     96.83%
  3488 requests in 1.00m, 439.41KB read
Requests/sec:     58.03
Transfer/sec:      7.31KB

Then changing the serving to look like:

serving = Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
  compile: [batch_size: 32, sequence_length: 8],
  defn_options: [compiler: EXLA]
)

We get:

Running 1m test @ http://127.0.0.1:5001
  8 threads and 32 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency    25.69ms    2.21ms  91.61ms   95.06%
    Req/Sec   156.34     12.48   202.00     86.03%
  74832 requests in 1.00m, 9.21MB read
Requests/sec:   1245.80
Transfer/sec:    156.94KB

So that is a pretty significant speedup just from compiling the serving. There are some other config options you can mess with, but you probably won’t get much more of a speedup than that.
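
As a sanity check on the two wrk runs above: wrk keeps a fixed number of connections busy, so throughput should come out at roughly connections divided by average latency (Little's law). A minimal sketch using only the figures from the output above; the function name is made up for illustration:

```python
def expected_throughput(connections: int, avg_latency_ms: float) -> float:
    """Requests per second for a closed-loop load test (Little's law)."""
    return connections / (avg_latency_ms / 1000.0)


# 32 connections; average latencies taken from the two wrk runs above.
before = expected_throughput(32, 550.10)  # serving without compile options
after = expected_throughput(32, 25.69)    # compile: + defn_options: [compiler: EXLA]

print(f"before: {before:.1f} req/s (wrk reported 58.03)")
print(f"after:  {after:.1f} req/s (wrk reported 1245.80)")
```

Both estimates land close to what wrk reported, which suggests the benchmark is latency-bound per connection: the change in Requests/sec follows directly from the drop in per-request latency.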

Oh wow, that’s indeed a very significant difference. I now also know what I did wrong: I tried to set the defn_options in the child specification instead of the serving function…

I’ll update the repo with the updated results later. Thank you! :smiley:

I changed Nx main so we raise if the wrong option is given when starting the serving. :slight_smile:

To be fair to Python, I realized that this is probably just because the sequence length was limited to 8. I am pretty sure that the sequence length the Python library uses is 128 (see sentence-transformers/all-MiniLM-L6-v2 · Hugging Face):

The sequence length was limited to 128 tokens.

When using a sequence length of 128, Bumblebee+EXLA achieves ~120 encodes per second, which is basically the same performance as the Python server.

This brings us back to what I was wondering in Nx vs. Python performance for sentence-transformer encoding - #3 by steffend: the BEAM with Bumblebee+EXLA does not seem to be able to fully utilize all CPU cores. If I start two instances of the nx_serving script on different ports and then execute 2 instances of the benchmark, I can achieve ~190 encodes per second (with a reverse proxy it strangely gets slower).

Optimizing the Python code by using a dedicated WSGI server running 8 processes (gunicorn -w 8 -b 0.0.0.0:5001 simple:app) instead of the Flask development server gives:

$ wrk http://127.0.0.1:5001 -t 8 -c 32 -d 60
Running 1m test @ http://127.0.0.1:5001
  8 threads and 32 connections
  Thread Stats   Avg      Stdev     Max   +/- Stdev
    Latency    72.36ms   23.81ms 314.19ms   93.41%
    Req/Sec    56.71     12.13    90.00     61.91%
  26917 requests in 1.00m, 3.95MB read
Requests/sec:    447.96
Transfer/sec:     67.37KB

Two additional considerations:

  1. batch_timeout and batch_size are going to impact latency and memory usage, so I recommend playing with those numbers if you haven’t yet. Does the Python version have anything along those lines?

  2. EXLA assumes a computation will use all cores and it puts a lock around it. You can set XLA_FLAGS=--xla_force_host_platform_device_count=8 and it will start several CPU devices. You can then pass partitions: true to your Nx.Serving (in the child spec/sup tree). I am hoping this will at least allow you to use all cores within a single BEAM instance.
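
To make the first point concrete, here is a toy model of how a batch size and a batch timeout interact. It is a simplified sketch and not Nx.Serving's actual implementation: it assumes a batch is dispatched either when it is full or when the timeout has elapsed since its first request arrived. All function names are hypothetical.

```python
def simulate_batches(arrivals_ms, batch_size, batch_timeout_ms):
    """Group arrival times into batches.

    A batch is dispatched when it is full, or batch_timeout_ms after its
    first request arrived, whichever happens first (a simplified model).
    Returns a list of (dispatch_time_ms, [arrival times]) tuples.
    """
    batches = []
    current = []
    for t in sorted(arrivals_ms):
        # Flush the open batch if its timeout expired before this arrival.
        if current and t >= current[0] + batch_timeout_ms:
            batches.append((current[0] + batch_timeout_ms, current))
            current = []
        current.append(t)
        if len(current) == batch_size:
            batches.append((t, current))  # full: dispatch immediately
            current = []
    if current:
        batches.append((current[0] + batch_timeout_ms, current))
    return batches


def mean_wait_ms(batches):
    """Average time requests spend queued before their batch is dispatched."""
    waits = [dispatch - t for dispatch, items in batches for t in items]
    return sum(waits) / len(waits)


# One request every 10 ms for 80 ms, batch_size 4, two different timeouts.
arrivals = list(range(0, 80, 10))
for timeout in (10, 100):
    batches = simulate_batches(arrivals, batch_size=4, batch_timeout_ms=timeout)
    print(f"timeout={timeout}ms: {len(batches)} batches, "
          f"mean queue wait {mean_wait_ms(batches):.1f}ms")
```

In this toy setup, the short timeout dispatches every request in its own nearly empty batch, while the long timeout fills both batches at the cost of a slightly longer queue wait. Real numbers depend on the arrival rate and on how long the model takes per batch, which the sketch ignores.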

Yes, I already played with the batch settings and 32 seems to be a good batch size for the sequence length of 128. I did not play with the batch timeout yet, but latency is not my focus currently.

The Python version is very barebones and I don’t think that it performs any kind of batching at all. To batch I’d probably need to use something like Ray Serve: Scalable and Programmable Serving — Ray 2.9.0.

Ah that’s interesting and could very well explain what I’m seeing!

I tried this here (nx-sentence-transformer-bench/nx_serving_partitions.exs at main · SteffenDE/nx-sentence-transformer-bench · GitHub) and am seeing lots of errors when benchmarking. It seems like EXLA does not like this:

17:49:30.868 [error] GenServer #PID<0.6906.0> terminating
** (stop) exited in: Nx.Serving.local_batched_run(MyServing, ["this is a test"])
    ** (EXIT) an exception was raised:
        ** (RuntimeError) Expected buffer to be placed on device 7
            (exla 0.5.3) lib/exla/executable.ex:56: EXLA.Executable.unwrap!/1
            (exla 0.5.3) lib/exla/executable.ex:19: EXLA.Executable.run/3
            (exla 0.5.3) lib/exla/defn.ex:346: EXLA.Defn.maybe_outfeed/7
            (stdlib 4.3.1.1) timer.erl:235: :timer.tc/1
            (exla 0.5.3) lib/exla/defn.ex:283: anonymous fn/7 in EXLA.Defn.__compile__/4
            (nx 0.5.3) lib/nx/defn.ex:313: anonymous fn/4 in Nx.Defn.compile/3
            (nx 0.5.3) lib/nx/serving.ex:1107: anonymous fn/2 in Nx.Serving.Default.handle_batch/3
            (nx 0.5.3) lib/nx/serving.ex:957: anonymous fn/3 in Nx.Serving.server_task_or_enqueue/3
    (nx 0.5.3) lib/nx/serving.ex:620: Nx.Serving.local_batched_run!/3
    nx_serving_partitions.exs:33: MyPlug.call/2
    (bandit 1.0.0-pre.9) lib/bandit/pipeline.ex:110: Bandit.Pipeline.call_plug/2
    (bandit 1.0.0-pre.9) lib/bandit/pipeline.ex:25: Bandit.Pipeline.run/6
    (bandit 1.0.0-pre.9) lib/bandit/http1/handler.ex:27: Bandit.HTTP1.Handler.handle_data/3
    (bandit 1.0.0-pre.9) lib/bandit/delegating_handler.ex:18: Bandit.DelegatingHandler.handle_data/3
    (bandit 1.0.0-pre.9) /Users/steffen/Library/Caches/mix/installs/elixir-1.14.5-erts-13.2.2.1/9825c67976b12b773e3fa7c81710c74f/deps/thousand_island/lib/thousand_island/handler.ex:399: Bandit.DelegatingHandler.handle_continue/2
    (stdlib 4.3.1.1) gen_server.erl:1123: :gen_server.try_dispatch/4

To reproduce:

$ XLA_FLAGS=--xla_force_host_platform_device_count=8 elixir nx_serving_partitions.exs
# then
$ wrk http://127.0.0.1:5001 -t 8 -c 32

Nice, I will investigate. Also, I recommend playing a bit with the timeout just in case (try 10ms and 1000ms as a double check).

@steffend it has been fixed in main here: 4e21e0467ccd5ff6a54a0115f0fe79420e089f5a

You may need to have both nx and exla pointing at that commit. If you have any questions, please let me know. :slight_smile:

Also, please double check that both operations return the final data, as frameworks (both Elixir and Python) can return the output tensors without the computation fully concluding.

Finally, please double check if the SentenceTransformer is indeed padding. IIRC padding is not applied on PyTorch if you are not batching.

Yes, indeed that fixes the particular error. Thank you for looking into this!
Interestingly, the performance is still the same with 8 local devices (~117 encodes/second), though the scheduler usage in the observer looks much messier:

I’ll try. I still have much to learn in the ML space. I guess what you’re trying to say is that if the Python version does not pad the input, my short test sentence would lead to wrong results? Looking through the code I think it might pad the input (sentence-transformers/sentence_transformers/models/Transformer.py at 179b659621c680371394d507683b25ba7faa0dd8 · UKPLab/sentence-transformers · GitHub), but I’m not sure if that’s really the correct piece of code.

When I find the time I will also try to compare the results of the Python and Elixir code. I have a Livebook that computes the same cosine similarities as Python using Bumblebee+Axon (no serving, as the mean pooling of the serving has some issues - Bumblebee.Text.TextEmbedding output_pool crashes · Issue #216 · elixir-nx/bumblebee · GitHub). When I have more results, I’ll update the repo and this thread.

Looks like the correct piece of code to me. So it pads to the longest input sequence in the batch (so without batching there is no padding at all):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
tokenizer(["hey", "hello world"], padding=True, truncation='longest_first', return_tensors="pt", max_length=100)
#=> {'input_ids': tensor([[ 101, 4931,  102,    0], [ 101, 7592, 2088,  102]]), ...}

In contrast, we always pad to the maximum sequence length, so that we only compile once.
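
The difference between the two padding strategies can be sketched in plain Python. This is only an illustration on lists of token ids (taken from the tokenizer output above), not what either library does internally; the helper names are made up, and the truncation here is naive, whereas a real tokenizer keeps the special tokens.

```python
PAD_ID = 0  # padding token id, as seen in the tokenizer output above


def pad_to_longest(batch):
    """Pad every sequence to the length of the longest one in the batch."""
    target = max(len(seq) for seq in batch)
    return [seq + [PAD_ID] * (target - len(seq)) for seq in batch]


def pad_to_fixed(batch, sequence_length):
    """Truncate or pad every sequence to one fixed length."""
    return [
        seq[:sequence_length] + [PAD_ID] * max(0, sequence_length - len(seq))
        for seq in batch
    ]


# Token ids for "hey" and "hello world" from the output above.
batch = [[101, 4931, 102], [101, 7592, 2088, 102]]

print(pad_to_longest(batch))       # shape depends on the batch contents
print(pad_to_longest(batch[:1]))   # a single input gets no padding at all
print(pad_to_fixed(batch, 8))      # shape is always (len(batch), 8)
```

With padding to the longest input, the tensor shape changes from batch to batch, which would trigger a new compilation for every new shape. With a fixed length, the shape is always the same, at the price of processing padding tokens that carry no information.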

Keep in mind you may not want to run eight instances. When I tried this, XLA took all cores and we could not push traffic enough to the serving. :smiley:

@steffend @jonatanklosko @seanmor5 I have been thinking about this, and it is clear that we are more performant, but forcing a certain sequence length is going to be an issue because we are always working with the worst case.

I can think of two solutions to the problem. Both are based on allowing multiple sequence lengths. For example, instead of 128, we could say 16, 32, 64, 96, and 128. If we do so, we have two options:

  1. Allow multiple sequence lengths in the same batch and then pad to the highest. For example, if we get 18, 23, 42, 55, and 90 on a batch, we will pad to 96.

  2. Allow multiple batch keys. In the example above, 18 and 23 go to the “32-padding batch”, 42 and 55 go to the “64-padding batch”, and 90 goes to the “96-padding batch”. Each batch has its own size and individual timeout. This means better performance, but you will need to balance the batch size and batch timeout accordingly (if the timeout is high, it is more likely you will always hit the timeout).

I am thinking the batch keys approach makes the most sense but I would love to hear your thoughts. :slight_smile:
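
The two options can be sketched as follows, using the bucket lengths and input lengths from the example above. This is a plain-Python illustration of the routing logic only, with hypothetical function names; it says nothing about how this would be implemented in Nx.Serving.

```python
from bisect import bisect_left
from collections import defaultdict

BUCKETS = [16, 32, 64, 96, 128]  # allowed padded lengths from the example


def bucket_for(length, buckets=BUCKETS):
    """Smallest allowed length that fits `length` tokens."""
    i = bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"{length} exceeds the maximum sequence length")
    return buckets[i]


def pad_length_for_batch(lengths):
    """Option 1: one mixed batch, padded to the bucket of its longest input."""
    return bucket_for(max(lengths))


def group_by_batch_key(lengths):
    """Option 2: route each input to the batch keyed by its own bucket."""
    groups = defaultdict(list)
    for n in lengths:
        groups[bucket_for(n)].append(n)
    return dict(groups)


lengths = [18, 23, 42, 55, 90]
print(pad_length_for_batch(lengths))  # 96
print(group_by_batch_key(lengths))    # {32: [18, 23], 64: [42, 55], 96: [90]}
```

Either way, the number of distinct shapes to compile is bounded by the number of buckets rather than by the number of distinct input lengths.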

I’ve been running some tests comparing the results more thoroughly this week and will probably post an update tomorrow. I can confirm that EXLA performs better than Python when using the full sequence length. I also started playing with CUDA on AWS, but there I still need to run some more tests.

To measure the impact of the sequence length, I adapted my serving to always tokenize twice: once with the full sequence length and once limited to the actual sequence length of the input. The encodes/second graph looks like this for EXLA (x-axis: sequence length, y-axis: encodes/sec):

This is the graph for Python (not quite fair as it goes through an extra HTTP request):

And finally I’m attaching the Livebook I used to generate these graphs.

All in all, Elixir and EXLA perform well. The only thing remaining is that I could not get the CPU to be fully loaded with EXLA (the same for CUDA).

The first one seems similar to what Python does, always using the longest input sequence length, if I understood that right.
I’ve been thinking about the following: couldn’t we also allow a dynamic sequence length and just-in-time compile when we first get an input with a specific sequence length? Subsequent requests with that length would then reuse the compiled version. As the sequence length is finite, this would mean that one could either pre-compile every sequence length or “warm up” the serving.

I’ve been thinking about the following: couldn’t we also allow a dynamic sequence length and just-in-time compile when we first get an input with a specific sequence length? Subsequent requests with that length would then reuse the compiled version. As the sequence length is finite, this would mean that one could either pre-compile every sequence length or “warm up” the serving.

We can do that for sure, but it means you may compile the program several times. But it is something I will consider while exploring these ideas. :slight_smile:
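
The idea of compiling lazily per sequence length, with an optional warmup, can be modelled with a small cache. This is a sketch under the assumption that one compiled program only accepts one input length; `CompileCache` and `fake_compile` are hypothetical stand-ins, not part of Nx, EXLA or Bumblebee.

```python
class CompileCache:
    """Stand-in for a JIT cache keyed by input sequence length."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._compiled = {}
        self.compilations = 0  # how often the expensive step ran

    def get(self, sequence_length):
        """Return the program for this length, compiling it on first use."""
        if sequence_length not in self._compiled:
            self.compilations += 1
            self._compiled[sequence_length] = self._compile_fn(sequence_length)
        return self._compiled[sequence_length]

    def warmup(self, sequence_lengths):
        """Pre-compile a set of lengths so no request pays the cost later."""
        for n in sequence_lengths:
            self.get(n)


def fake_compile(sequence_length):
    # Placeholder for the expensive compilation step; returns a "program"
    # that only accepts inputs of exactly this length.
    def run(token_ids):
        assert len(token_ids) == sequence_length
        return sum(token_ids)  # dummy computation
    return run


cache = CompileCache(fake_compile)
for tokens in ([1, 2, 3], [4, 5, 6], [7, 8, 9, 10]):
    cache.get(len(tokens))(tokens)
print(cache.compilations)  # 2: one for length 3, one for length 4
```

Calling `cache.warmup(range(1, 129))` would pre-compile every length up to 128, which is the "compile the program several times" cost mentioned above: up to one compilation per distinct length.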

Having multiple variants sounds great! Both 1. and 2. make certain tradeoffs, and which one is better depends on the length distribution. If longer inputs are rare, then with 2. their batch will hit the batch timeout and we will pad it with empty batch items, when we might as well have put some shorter inputs there. But then note that we pad on the client as part of tokenization, and this affects all of the input tensors (input ids, attention mask), whereas padding to a higher length afterwards means we need to pad on the server. With 2. we always know what length to pad to.
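
To get a rough feel for the tradeoff, one can count how many padding tokens each approach processes for the example inputs from earlier in the thread (lengths 18, 23, 42, 55 and 90 with buckets 16, 32, 64, 96 and 128). This sketch only counts padding along the sequence dimension; it deliberately ignores the empty batch items that option 2 adds when a rarely used batch hits its timeout.

```python
from bisect import bisect_left

BUCKETS = [16, 32, 64, 96, 128]
lengths = [18, 23, 42, 55, 90]


def bucket_for(length):
    """Smallest bucket that fits `length` tokens (lengths must be <= 128)."""
    return BUCKETS[bisect_left(BUCKETS, length)]


real_tokens = sum(lengths)

fixed = len(lengths) * BUCKETS[-1]               # always pad to 128
mixed = len(lengths) * bucket_for(max(lengths))  # option 1: pad batch to 96
keyed = sum(bucket_for(n) for n in lengths)      # option 2: per-input bucket

for name, total in [("fixed 128", fixed), ("option 1", mixed), ("option 2", keyed)]:
    print(f"{name}: {total} tokens processed, {total - real_tokens} of them padding")
```

For this particular set of inputs, option 2 processes the fewest padding tokens, but whether that wins overall depends on how full its batches are, which is exactly the length-distribution question raised above.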
