Hey!
In the model card, they give the following Python example:
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
model.eval()

pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    # One relevance logit per (query, passage) pair, flattened to a 1-D tensor
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    print(scores)
```
It translates to this Bumblebee code:
```elixir
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "BAAI/bge-reranker-large"})
{:ok, model_info} = Bumblebee.load_model({:hf, "BAAI/bge-reranker-large"})

pairs = [
  {"what is panda?", "hi"},
  {"what is panda?",
   "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."}
]

inputs = Bumblebee.apply_tokenizer(tokenizer, pairs)
outputs = Axon.predict(model_info.model, model_info.params, inputs)

# The logits hold one value per pair; flatten them into a plain list
scores = Nx.to_flat_list(outputs.logits)
#=> [-5.608547687530518, 5.762268543243408]
```
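Since the reranker emits one raw logit per (query, passage) pair, with higher meaning more relevant, the actual reranking step is just a sort on those scores. A minimal Python sketch, assuming the logits have already been computed (the values are the ones printed above); `rerank` is a hypothetical helper, not part of Bumblebee or transformers:

```python
def rerank(passages, scores):
    """Return (passage, score) tuples sorted from most to least relevant.

    Hypothetical helper: scores[i] is assumed to be the reranker logit for
    the pair (query, passages[i]); a higher logit means a better match.
    """
    if len(passages) != len(scores):
        raise ValueError("expected exactly one score per passage")
    return sorted(zip(passages, scores), key=lambda item: item[1], reverse=True)


passages = [
    "hi",
    "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear "
    "or simply panda, is a bear species endemic to China.",
]
# Logits copied from the output above, in the same order as `passages`.
scores = [-5.608547687530518, 5.762268543243408]

for passage, score in rerank(passages, scores):
    print(f"{score:8.3f}  {passage[:40]}")
```

The sort only relies on the relative order of the logits, so no normalization is needed for ranking.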
Note that `Axon.predict` compiles the model on every call. For a production use case we would want to compile it once with `Axon.compile`, and ideally use `Nx.Serving`. This actually fits the `Bumblebee.Text.text_classification` serving, except that currently inputs are required to be strings, not string pairs. I've just pushed a change to main to allow string pairs, so this works:
```elixir
serving = Bumblebee.Text.text_classification(model_info, tokenizer, scores_function: :none)

Nx.Serving.run(serving, pairs)
#=> [
#=>   %{predictions: [%{label: "LABEL_0", score: -5.608547687530518}]},
#=>   %{predictions: [%{label: "LABEL_0", score: 5.762268543243408}]}
#=> ]
```
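The `scores_function: :none` option matters because this model has a single output label (`LABEL_0`): a softmax over one logit is always 1.0, which would erase the ranking signal, while the raw logit (or a sigmoid of it, if values in (0, 1) are preferred) keeps it. A small Python illustration of that arithmetic, independent of Bumblebee; the `softmax` and `sigmoid` helpers are written out here for illustration:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one pair's list of label logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    """Logistic function, mapping a single logit into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


# One logit per (query, passage) pair, as returned by the reranker above.
pair_logits = [-5.608547687530518, 5.762268543243408]

# Softmax runs per pair over that pair's labels; with a single label every
# pair collapses to 1.0 and the ranking information is lost.
print([softmax([logit])[0] for logit in pair_logits])  # [1.0, 1.0]

# Sigmoid is monotonic, so it keeps the ordering while bounding the scores.
print([round(sigmoid(logit), 4) for logit in pair_logits])
```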
Or, in a more production-ready setup:
```elixir
serving =
  Bumblebee.Text.text_classification(model_info, tokenizer,
    scores_function: :none,
    compile: [batch_size: 1, sequence_length: 128],
    defn_options: [compiler: EXLA]
  )

# Start under a supervision tree (in Livebook, Kino.start_child/1 does this)
Kino.start_child({Nx.Serving, serving: serving, name: RerankerServing})

Nx.Serving.batched_run(RerankerServing, pairs)
#=> [
#=>   %{predictions: [%{label: "LABEL_0", score: -5.608547687530518}]},
#=>   %{predictions: [%{label: "LABEL_0", score: 5.762268543243408}]}
#=> ]
```
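The `compile: [batch_size: 1, sequence_length: 128]` options are what let the serving compile once up front: with a static input shape, requests are padded (or, by assumption here, truncated) to that shape and split into batches, instead of compiling again per input as with `Axon.predict`. The Python sketch below only illustrates that idea with a hypothetical `FixedShapeRunner` and a fake "compile" step; it is not how Bumblebee, Nx.Serving, or EXLA are implemented.

```python
from typing import Callable, List, Sequence


class FixedShapeRunner:
    """Hypothetical illustration of compile-once, fixed-shape inference.

    `compile_fn` stands in for an expensive compilation step; it is called
    exactly once, for the (batch_size, sequence_length) shape.
    """

    def __init__(self, compile_fn: Callable[[int, int], Callable],
                 batch_size: int, sequence_length: int, pad_id: int = 0):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.pad_id = pad_id
        self.compiled = compile_fn(batch_size, sequence_length)

    def _pad(self, ids: Sequence[int]) -> List[int]:
        # Truncate long inputs and right-pad short ones to the static length.
        ids = list(ids[: self.sequence_length])
        return ids + [self.pad_id] * (self.sequence_length - len(ids))

    def run(self, inputs: Sequence[Sequence[int]]) -> List[float]:
        results: List[float] = []
        for start in range(0, len(inputs), self.batch_size):
            chunk = [self._pad(ids) for ids in inputs[start : start + self.batch_size]]
            real = len(chunk)
            # Fill a partial final batch so the compiled shape always matches,
            # then drop the outputs that belong to the filler rows.
            chunk += [[self.pad_id] * self.sequence_length] * (self.batch_size - real)
            results.extend(self.compiled(chunk)[:real])
        return results


compile_calls = []


def fake_compile(batch_size, sequence_length):
    """Hypothetical stand-in for compilation; records each call."""
    compile_calls.append((batch_size, sequence_length))

    def model(batch):
        # Every batch must arrive with exactly the compiled shape.
        assert len(batch) == batch_size
        assert all(len(row) == sequence_length for row in batch)
        # Dummy "score": number of non-padding tokens in each row.
        return [float(sum(1 for t in row if t != 0)) for row in batch]

    return model


runner = FixedShapeRunner(fake_compile, batch_size=1, sequence_length=128)
scores = runner.run([[101, 2054, 2003], [101] * 200])
print(scores)              # [3.0, 128.0]
print(len(compile_calls))  # 1
```

The trade-off is that short inputs pay for padding up to `sequence_length`, in exchange for never recompiling at request time.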