I’ve recently been interested in trying out Nx with Bumblebee, so I’ve been following along with the fantastic DockYard article on creating a basic RAG.
However, I’m seemingly falling at the first hurdle: just getting the predict call, which uses Nx.Serving, to work.
I have the following:
defmodule Rag.Embedding do
  # Builds a Bumblebee text-embedding serving for the e5-large-v2 model.
  def serving() do
    {:ok, model} = Bumblebee.load_model({:hf, "intfloat/e5-large-v2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "intfloat/e5-large-v2"})

    Bumblebee.Text.text_embedding(model, tokenizer,
      embedding_processor: :l2_norm,
      defn_options: [compiler: EMLX]
    )
  end

  # Sends the text to the named serving process started by the application.
  def predict(text) do
    Nx.Serving.batched_run(__MODULE__, text)
  end
end
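To isolate whether the hang is in the computation itself or in the batched serving process, my understanding (untested sketch) is that the serving can also be exercised inline, without the supervised process, via Nx.Serving.run/2:

# Sketch: run the serving inline instead of going through the
# named Nx.Serving process started by the application.
serving = Rag.Embedding.serving()
Nx.Serving.run(serving, "Hello, world")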
and starting the serving with:
defmodule Rag.Application do
  use Application

  @impl true
  def start(_type, _args) do
    Nx.Defn.default_options(compiler: EMLX)
    Nx.default_backend(EMLX.Backend)

    children = [
      # Start the embedding serving under the app supervisor so that
      # Nx.Serving.batched_run/2 can find it by name.
      {Nx.Serving, serving: Rag.Embedding.serving(), name: Rag.Embedding}
    ]

    opts = [strategy: :one_for_one, name: Rag.Supervisor]
    Supervisor.start_link(children, opts)
  end
end
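In case it matters, I’m not passing any batching options to the child spec. As far as I can tell, batch_size and batch_timeout can be set there; something like this (values picked arbitrarily, not what I’m actually running):

children = [
  {Nx.Serving,
   serving: Rag.Embedding.serving(),
   name: Rag.Embedding,
   # explicit batching options; my actual code relies on the defaults
   batch_size: 4,
   batch_timeout: 100}
]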
The model downloads successfully; however, it seems to hang on the predict call.
rag on main [!?] is 📦 v0.1.0 via 💧 v1.18.3 (OTP 27) via ❄️ impure (nix-shell-env) took 47s
λ iex -S mix
Erlang/OTP 27 [erts-15.2.6] [source] [64-bit] [smp:12:12] [ds:12:12:10] [async-threads:1] [jit]
Interactive Elixir (1.18.3) - press Ctrl+C to exit (type h() ENTER for help)
iex(1)> Process.whereis(Rag.Embedding)
#PID<0.220.0>
iex(2)> Rag.Embedding.predict("Hello, world") # stalls here.
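The only further introspection I’ve thought of is peeking at the serving process while the call is stalled; I assume something along these lines would show where it’s blocked (sketch, using the names above):

# Inspect the stalled serving process from the same IEx shell
pid = Process.whereis(Rag.Embedding)
Process.info(pid, :current_stacktrace)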
I’m coming from a Python machine-learning background, so it may be something completely obvious.
Any help here would be appreciated!