Context
I am experimenting with text embeddings with the hope of implementing semantic similarity search inside a Phoenix application.
My target use case: a user writes a short sentence (typically 5 to 30 words), and within a few seconds I want to present the user with similar sentences drawn from a collection of equally short sentences previously written by other users.
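To be concrete about what I mean by "similar", the comparison I have in mind is a plain cosine similarity between embedding vectors, roughly along these lines (just a sketch; embedding_a and embedding_b stand for two precomputed sentence embeddings as Nx tensors):

# Rough sketch of the similarity comparison I have in mind.
# Assumes the embeddings have already been computed elsewhere.
cosine_similarity = fn a, b ->
  Nx.divide(Nx.dot(a, b), Nx.multiply(Nx.LinAlg.norm(a), Nx.LinAlg.norm(b)))
end

cosine_similarity.(embedding_a, embedding_b) |> Nx.to_number()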
The test that puzzles me
As a first quick test of feasibility, I am playing with the example posted by @jonatanklosko at Add text embedding serving · Issue #206 · elixir-nx/bumblebee · GitHub:
# Load the pre-trained BERT model and its tokenizer from the Hugging Face Hub
{:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"}, architecture: :base)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

# Tokenize the input text
text = "Hello, world!"
inputs = Bumblebee.apply_tokenizer(tokenizer, text)

# Run the model and take the hidden states of the first (and only) sequence in the batch
Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
The code executes without error, but when I run it locally on my machine (a MacBook Air less than 4 years old), the last line
Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
takes more than a minute to complete.
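For reference, this is roughly how I am measuring it (just a sketch; the exact figure varies a bit between runs):

# Timing the prediction step alone, reported in seconds
{time_us, _result} =
  :timer.tc(fn ->
    Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
  end)

IO.puts("Axon.predict took #{time_us / 1_000_000} s")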
In contrast, the Python equivalent presented at the top of the same GitHub thread (Add text embedding serving · Issue #206 · elixir-nx/bumblebee · GitHub) completes almost instantaneously (a fraction of a second) on the same machine:
from transformers import AutoTokenizer, AutoModel
import torch

# Load pre-trained model tokenizer and model weights
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Tokenize input text
text = "Hello, world!"
tokens = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")

# Generate model embeddings
with torch.no_grad():
    embeddings = model(tokens)[0].squeeze(0)  # Remove batch dimension

# Print the embeddings for the first token
print(embeddings[0])
I am guessing this is not normal and that I am doing something wrong. Any idea what that could be? Is there a better way to retrieve text embedding vectors than the script from Add text embedding serving · Issue #206 · elixir-nx/bumblebee · GitHub that I am playing with?
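For instance, that same issue seems to have led to Bumblebee.Text.text_embedding/3 being added. Is something along these lines the intended way to use it? (Just a sketch of my understanding, untested; the compile and defn options are guesses on my part and assume the EXLA compiler is available.)

# Sketch: using the text embedding serving instead of calling Axon.predict directly.
# The :compile and :defn_options values below are my guesses, not a known-good config.
{:ok, model_info} = Bumblebee.load_model({:hf, "bert-base-uncased"}, architecture: :base)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "bert-base-uncased"})

serving =
  Bumblebee.Text.text_embedding(model_info, tokenizer,
    compile: [batch_size: 1, sequence_length: 64],
    defn_options: [compiler: EXLA]
  )

%{embedding: embedding} = Nx.Serving.run(serving, "Hello, world!")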
Notes
- I am using (see also the last note below about backend configuration):
{:bumblebee, "~> 0.5.3"},
{:nx, "~> 0.7.0"},
- The ~1 minute runtime I am reporting above is for the last
Axon.predict(model_info.model, model_info.params, inputs).hidden_state[0]
step alone (it does not include the model/tokenizer loading steps).
- I did read through Nx vs. Python performance for sentence-transformer encoding. I am guessing my issue is different from what's discussed there, since that post is "only" about a ~2x slowdown compared to equivalent Python code, which is much smaller than the gap I am seeing. For my application, I'd be more than happy with 2x the Python runtime.
- Given my use case, and since Python is fast enough, I realize I could let Python handle the embedding part and pick things up inside Phoenix once the embedding is done. But I'd prefer keeping it all in Elixir if possible.
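- One thing I am not sure about (hence the note on my deps above): whether I need to explicitly configure an Nx backend or compiler. From the Nx/EXLA docs, my understanding is that it would look roughly like the following, though I have not tried it yet and the exla version is a guess on my part:

# In mix.exs, alongside the deps listed above (version is a guess):
{:exla, "~> 0.7.0"},

# In config/config.exs, to make EXLA the default backend for Nx tensors:
config :nx, default_backend: EXLA.Backend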