Hey!
In the model card, they give the following Python example:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)
It translates to this Bumblebee code:
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "BAAI/bge-reranker-large"})
{:ok, model_info} = Bumblebee.load_model({:hf, "BAAI/bge-reranker-large"})
pairs = [
  {"what is panda?", "hi"},
  {"what is panda?",
   "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."}
]
inputs = Bumblebee.apply_tokenizer(tokenizer, pairs)
outputs = Axon.predict(model_info.model, model_info.params, inputs)
scores = Nx.to_flat_list(outputs.logits)
#=> [-5.608547687530518, 5.762268543243408]
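These scores are raw logits, so they are unbounded and only meaningful relative to each other. If you'd rather have values in the 0..1 range, one common option is to pass each logit through a sigmoid. Here is a minimal sketch in plain Elixir; ScoreSketch is a made-up module for illustration, not something Bumblebee provides:

```elixir
defmodule ScoreSketch do
  # Hypothetical helper (not part of Bumblebee): logistic sigmoid of a raw logit.
  # Both branches only ever call :math.exp/1 on a non-positive number,
  # so very large logits cannot overflow.
  def sigmoid(logit) when is_number(logit) and logit >= 0 do
    1.0 / (1.0 + :math.exp(-logit))
  end

  def sigmoid(logit) when is_number(logit) do
    e = :math.exp(logit)
    e / (1.0 + e)
  end
end

# The two logits from the example above
logits = [-5.608547687530518, 5.762268543243408]

# The first pair ends up close to 0, the second close to 1
probabilities = Enum.map(logits, &ScoreSketch.sigmoid/1)
IO.inspect(probabilities)
```

For pure reranking this step is optional, since the sigmoid is monotonic and does not change the ordering.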
Note that Axon.predict compiles the model on each call. For production use we would want to compile it once with Axon.compile, and ideally use Nx.Serving. This actually fits into the Bumblebee.Text.text_classification serving, except that currently inputs are required to be strings, not string pairs. I've just pushed a change to main to allow string pairs, so this works:
serving = Bumblebee.Text.text_classification(model_info, tokenizer, scores_function: :none)
Nx.Serving.run(serving, pairs)
#=> [
#=> %{predictions: [%{label: "LABEL_0", score: -5.608547687530518}]},
#=> %{predictions: [%{label: "LABEL_0", score: 5.762268543243408}]}
#=> ]
Or in a more production-like setup:
serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    scores_function: :none,
    compile: [batch_size: 1, sequence_length: 128],
    defn_options: [compiler: EXLA]
  )
# Start under supervision tree
Kino.start_child({Nx.Serving, serving: serving, name: RerankerServing})
Nx.Serving.batched_run(RerankerServing, pairs)
#=> [
#=> %{predictions: [%{label: "LABEL_0", score: -5.608547687530518}]},
#=> %{predictions: [%{label: "LABEL_0", score: 5.762268543243408}]}
#=> ]
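To actually rerank, you can zip the candidate documents with the serving results and sort by score. A rough sketch, assuming the result shape shown above with a single prediction per pair; RerankSketch and its rank/2 function are hypothetical names, and the results are hard-coded here so the snippet runs without loading the model:

```elixir
defmodule RerankSketch do
  # Hypothetical helper: pairs each document with the score of its first
  # prediction and sorts the pairs from most to least relevant.
  def rank(documents, results) do
    documents
    |> Enum.zip(results)
    |> Enum.map(fn {document, %{predictions: [%{score: score} | _]}} ->
      {document, score}
    end)
    |> Enum.sort_by(fn {_document, score} -> score end, :desc)
  end
end

# Candidate documents, in the same order as the pairs sent to the serving
documents = [
  "hi",
  "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
]

# Stand-in for the output of Nx.Serving.batched_run(RerankerServing, pairs)
results = [
  %{predictions: [%{label: "LABEL_0", score: -5.608547687530518}]},
  %{predictions: [%{label: "LABEL_0", score: 5.762268543243408}]}
]

ranked = RerankSketch.rank(documents, results)
IO.inspect(ranked)
```

The panda description comes first and "hi" last, since a higher logit means a more relevant document.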