Nx.Serving.batched_run for a Bumblebee model

I want to do sentiment analysis and language prediction on text in a Livebook. Both are available with Bumblebee.
I am hitting the problem that the number of texts is too large for the RAM. In addition, I am wondering why Nx.Serving.run only uses a single batch. Therefore I looked into the Nx.Serving documentation and thought that Nx.Serving.batched_run was the way to go. I tried to follow the documentation and came up with:

{:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})

english_sentiment_serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    defn_options: [compiler: EXLA]
  )

children = [
  {Nx.Serving,
   serving: english_sentiment_serving,
   name: EngSentimentServering,
   batch_size: 16}
]

Supervisor.start_child(children, strategy: :one_for_one)

What I think I am doing is starting a GenServer with the model and tokenizer on each of the 16 cores. What I actually get is an error saying that GenServer.whereis got the wrong type of arguments.

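Hey @sehHeiden, you need Supervisor.start_link rather than Supervisor.start_child (start_child expects an already running supervisor as its first argument, which is why GenServer.whereis complained about its arguments). Also, to specify the batch size, pass :compile to the serving configuration, so the computation is compiled on serving startup.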

{:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})

english_sentiment_serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    compile: [batch_size: 16, sequence_length: 130],
    defn_options: [compiler: EXLA]
  )

children = [
  {Nx.Serving, serving: english_sentiment_serving, name: EngSentimentServering}
]

Supervisor.start_link(children, strategy: :one_for_one)

When in Livebook, the best way to start the serving is Kino.start_child({Nx.Serving, ...}).

> that I start a Genserver with the model,tokenizer on each of the 16 cores

A single serving process is started. Tokenization happens in the calling process to get raw inputs; the raw inputs are sent to the server and batched, then the model runs on that batch and utilizes multiple cores (how exactly is up to the XLA compiler).

Also note that this particular model accepts a maximum input of 130 tokens (model_info.spec.max_positions), so longer texts need to be truncated (which we ensure by compiling for sequence_length: 130 above).


Yeah, in a Livebook it makes more sense to use Kino.start_child to start a single child. The code works.

The next question is whether it makes sense to start a GenServer for it in a Livebook at all, as I expect to run a model only a few times in a Livebook.

What works differently than expected is running the model with Nx.Serving.batched_run(EngSentimentServer, english_toots), with english_toots being a Series of 60 texts.

This throws an error saying that 60 is larger than the batch size of 16. That was unexpected, because I read the documentation and the function names as meaning that 16 of the 60 texts would be executed at a time.

I assume reshaping the texts to (x, 16) would not change that, and that I have to reshape and then execute each batch manually?

It should behave the way you expected (splitting the 60 texts into batches of 16) in Nx v0.6.1. :slight_smile:
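On Nx versions before v0.6.1, as reported in this thread, batched_run raises when a single call contains more inputs than the serving's batch size. A minimal sketch of the manual alternative the question describes: split the texts into chunks of at most 16 and run each chunk separately. `ChunkedRun` and `run_fun` are hypothetical names for illustration, not part of Nx or Bumblebee. The sketch assumes the serving is given a plain list of strings; if `english_toots` is an Explorer Series, it could be converted with `Explorer.Series.to_list/1` first.

```elixir
defmodule ChunkedRun do
  # Hypothetical helper (not part of Nx or Bumblebee): calls `run_fun` on
  # consecutive chunks of at most `batch_size` inputs and concatenates the
  # per-chunk result lists, preserving the original input order.
  def run(inputs, batch_size, run_fun)
      when is_list(inputs) and is_integer(batch_size) and batch_size > 0 do
    inputs
    |> Enum.chunk_every(batch_size)
    |> Enum.flat_map(run_fun)
  end
end

# Possible usage against the serving from this thread (assumes it is already
# started under the name EngSentimentServer and english_toots is a list):
#
#     predictions =
#       ChunkedRun.run(english_toots, 16, &Nx.Serving.batched_run(EngSentimentServer, &1))
```

Because every chunk is at most 16 texts, each call fits the serving compiled with compile: [batch_size: 16, ...]; with 60 texts this makes four calls of 16, 16, 16 and 12 texts.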


@josevalim

a) Great to know.
b) It works, after removing some dependencies that are no longer needed (kino_bumblebee requires Nx 0.5).

The method @jonatanklosko proposed worked very well for the Language Classification model.

But it did not work for the sentiment analysis model above:

{:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})

english_sentiment_serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    compile: [batch_size: 16, sequence_length: 130],
    defn_options: [compiler: EXLA]
  )

Kino.start_child({
  Nx.Serving,
  serving: english_sentiment_serving, name: EngSentimentServer
})

predictions = Nx.Serving.batched_run(EngSentimentServer, english_toots)

This also did not work with `Nx.Serving.run()`.

The error is:
** (ArgumentError) top_k input last axis size must be greater than or equal to k, got size=3 and k=5
(nx 0.6.1) lib/nx/shape.ex:2151: Nx.Shape.top_k/3
(nx 0.6.1) lib/nx.ex:14975: anonymous fn/2 in Nx.top_k/2
(nx 0.6.1) lib/nx.ex:5431: Nx.apply_vectorized/2
(bumblebee 0.4.0) lib/bumblebee/text/text_classification.ex:40: anonymous fn/5 in Bumblebee.Text.TextClassification.text_classification/3
(nx 0.6.1) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
(exla 0.6.1) lib/exla/defn.ex:387: anonymous fn/4 in EXLA.Defn.compile/8
(exla 0.6.1) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
/home/path.livemd#cell:l6ujx36qnbng7fsjrpn4jlifxaaaz77j:1: (file)
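The message says the serving asks Nx.top_k for k = 5 predictions while this model only has 3 output labels (size=3). Until a Bumblebee release with a fix is available, one possible workaround, assumed here and not verified in this thread, is to request no more predictions than there are labels via the :top_k option of Bumblebee.Text.text_classification, whose default is 5:

```elixir
{:ok, model_info} = Bumblebee.load_model({:hf, "finiteautomata/bertweet-base-sentiment-analysis"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "vinai/bertweet-base"})

english_sentiment_serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    # Assumed workaround: ask for at most 3 predictions, the number of labels
    # this model has according to the error above (default top_k is 5).
    top_k: 3,
    compile: [batch_size: 16, sequence_length: 130],
    defn_options: [compiler: EXLA]
  )
```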

@josevalim it looks like this is directly related to the EXLA/Nx update.

Nx.Serving.run works with:
{:nx, "~> 0.5.1"},
{:bumblebee, "~> 0.3"},
{:exla, "~> 0.5.1"},
… etc.

But it does not work when I update to:
Nx 0.6.1
EXLA 0.6.1
Bumblebee 0.4

I just checked that by downgrading to Nx 0.5.1, EXLA 0.5.1 and Bumblebee 0.3.
But that breaks Nx.Serving.batched_run ^^

@sehHeiden ah yeah, I fixed this on main, please try {:bumblebee, github: "elixir-nx/bumblebee"}.
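In a Livebook, that dependency change would go into the setup cell. A sketch using Mix.install; the Nx/EXLA pins mirror the 0.6.1 versions discussed in this thread and the Kino version is an assumption, so the requirements may need adjusting as Bumblebee's main branch moves on:

```elixir
Mix.install([
  # Bumblebee from the main branch, which contains the top_k fix
  {:bumblebee, github: "elixir-nx/bumblebee"},
  # Nx v0.6.1 is where batched_run splits inputs larger than the batch size
  {:nx, "~> 0.6.1"},
  {:exla, "~> 0.6.1"},
  # Provides Kino.start_child for running the serving in the Livebook (version assumed)
  {:kino, "~> 0.10"}
])
```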


Works!