Hey, I know compiling with EMLX is unstable at the moment.
Is there already a plan for how to make the compiler work?
I still tried to compile SmolLM2 with EMLX and was left with some questions.
Here is a script for reproduction; I'm running this on an M2 Mac.
Mix.install([
  {:bumblebee, "~> 0.6.0"},
  {:nx, "~> 0.9.0"},
  {:emlx, github: "elixir-nx/emlx"}
])
backend = {EMLX.Backend, device: :gpu}
compiler = Nx.Defn.Evaluator
Nx.global_default_backend(backend)
repo = {:hf, "HuggingFaceTB/SmolLM2-135M-Instruct"}
{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)
generation_config =
  Bumblebee.configure(generation_config,
    max_new_tokens: 256,
    strategy: %{type: :multinomial_sampling, top_p: 0.6}
  )
serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 512],
    stream: false,
    defn_options: [compiler: compiler]
  )
{:ok, _pid} =
  Supervisor.start_link([{Nx.Serving, name: Serving, serving: serving}], strategy: :one_for_one)
prompt = """
<|im_start|>user
Tell me a joke<|im_end|>
<|im_start|>assistant
"""
Nx.Serving.run(serving, prompt) |> dbg
Nx.Serving.batched_run(Serving, prompt) |> dbg
1. Not compiling
When I don’t set the compile option, everything just works, since compilation is disabled (just an observation, no question here).
serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
-   compile: [batch_size: 1, sequence_length: 512],
    stream: false,
    defn_options: [compiler: compiler]
  )
2. Compiling with Nx.Defn.Evaluator
When I set the compile option, compilation with Nx.Defn.Evaluator gets stuck, similar to what’s described in “Basic Bumblebee text embedding run seemingly stalling” and “Bumblebee: Slow load_model in GenServer, slow Nx.Serving.run in exs file”.
Nx.Serving.run finishes, but Nx.Serving.batched_run doesn’t.
I dug deeper and found that this is the place in Nx.Serving where it gets stuck:
wrapped_function = fn ->
  :telemetry.span([:nx, :serving, :execute], %{module: module}, fn ->
    if hooks_table do
      :ets.insert(hooks_table, {partition, ref_sizes})
    end

    {output, metadata} = function.()

    for {[ref | pids], start, size} <- ref_sizes do
      send(ref, {ref, {:batch, {start, size, output, metadata}}})

      for pid <- pids do
        send(pid, {ref, size})
      end
    end

    {:done, %{metadata: metadata, module: module}}
  end)
end
- task = Task.Supervisor.async_nolink(state.task_supervisor, wrapped_function)
+ :persistent_term.put(:fun, wrapped_function)
+
+ task =
+   Task.Supervisor.async_nolink(state.task_supervisor, fn ->
+     fun = :persistent_term.get(:fun)
+     fun.()
+   end)
Looks to me like the call to Task.Supervisor.async_nolink/2 never returns.
I actually ran into a similar issue when using the EMLX compiler in nif_call; I’ll talk about that below.
There, I found a solution, which I applied here again: putting wrapped_function into :persistent_term and reading it back inside the task.
So, this actually works.
I just have no clue why it’s not working in the original version.
I first noticed that wrapped_function is much larger when using Nx.Defn.Evaluator than what EXLA produces (sizes retrieved using :erts_debug.size/1: 431971 vs. 31294 words).
That’s why I tried storing and reading from :persistent_term in the first place.
However, I can run closures of a similar size as tasks in IEx, and it doesn’t seem huge to me anyway: at 8 bytes per word on a 64-bit system, 431971 words come out to around 3.5 MB.
So, if the size is not the problem, why does the original version not work, while the :persistent_term version does?
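For what it’s worth, the pattern itself can be reproduced in isolation. Here’s a minimal, self-contained sketch (plain Elixir, no Nx; the :fun key and the closure are just illustrative stand-ins) of handing a closure to a supervised task via :persistent_term instead of capturing it in the task function directly:

```elixir
# Start a task supervisor (in the real code this is state.task_supervisor).
{:ok, sup} = Task.Supervisor.start_link()

# A closure standing in for wrapped_function; it captures some local data.
data = Enum.to_list(1..1_000)
wrapped_function = fn -> Enum.sum(data) end

# Instead of passing the closure to the task directly, store it in
# :persistent_term and have the task read it back by key.
:persistent_term.put(:fun, wrapped_function)

task =
  Task.Supervisor.async_nolink(sup, fn ->
    fun = :persistent_term.get(:fun)
    fun.()
  end)

Task.await(task)
# => 500500
```

The task closure now only captures the key, not the (potentially large) wrapped_function itself.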
3. Compiling with EMLX
When using the EMLX compiler, I ran into a similar issue in nif_call: it gets stuck when passing the callback function during registration; a GenServer.call never returns.
Here, too, :persistent_term turned out to be the solution.
In lib/emlx.ex:
callback = fn args ->
  args = Enum.map(args, fn ref -> fn -> EMLX.Backend.to_nx({device, ref}) end end)

  eval_fun.([args])
  |> Nx.Defn.Composite.flatten_list()
  |> Enum.map(fn %Nx.Tensor{data: %{ref: {_device, ref}}} -> ref end)
end
- fun = NifCall.run(EMLX.Runner, callback, &nif_compile(nif_args, &1))
+ read_key = {__MODULE__, :read, key}
+ :persistent_term.put(read_key, callback)
+
+ receive_fun = fn -> :persistent_term.get(read_key) end
+ fun = NifCall.run(EMLX.Runner, receive_fun, &nif_compile(nif_args, &1))
To make that work, I added an option in nif_call to register functions by passing an arity-0 receive function.
In lib/runner.ex:
def handle_call({:register, owner, receive_function}, _from, state)
    when is_function(receive_function, 0) do
  function = receive_function.()
  ref = Process.monitor(owner)
  {:reply, {self(), ref}, %{state | refs: Map.put(state.refs, ref, function)}}
end
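The registration idea can be demonstrated standalone. This is a toy GenServer (not the actual NifCall.Runner; module and key names are made up for illustration) that accepts an arity-0 receive function, fetches the real callback inside the server process, and later invokes it by monitor ref:

```elixir
defmodule ToyRunner do
  use GenServer

  def start_link(_), do: GenServer.start_link(__MODULE__, %{refs: %{}})

  def init(state), do: {:ok, state}

  # Register by handing over an arity-0 "receive function" instead of the
  # callback itself; the callback is only fetched inside the server process.
  def handle_call({:register, owner, receive_function}, _from, state)
      when is_function(receive_function, 0) do
    function = receive_function.()
    ref = Process.monitor(owner)
    {:reply, {self(), ref}, %{state | refs: Map.put(state.refs, ref, function)}}
  end

  # Look up the registered callback by ref and invoke it.
  def handle_call({:run, ref, arg}, _from, state) do
    {:reply, state.refs[ref].(arg), state}
  end
end

{:ok, runner} = ToyRunner.start_link([])

# Store the actual callback in :persistent_term and register only a reader.
:persistent_term.put({:toy, :cb}, fn x -> x * 2 end)

{_pid, ref} =
  GenServer.call(runner, {:register, self(), fn -> :persistent_term.get({:toy, :cb}) end})

GenServer.call(runner, {:run, ref, 21})
# => 42
```

The caller never sends the callback closure itself across the GenServer boundary, only a tiny reader function.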
Same question here: what’s the actual problem?
4. eval in mlx and compilation
With this out of the way, compilation still fails because of eval calls on the C++ side, which are not allowed while compiling with mlx.
I tried to avoid functions that call eval, but it looks to me like it’s necessary for the way Nx backends work: at some point we must convert a tensor to a number, transfer it to another backend, etc., so the actual tensor data must be available at that point.
Is there actually a way around that? Is there already a plan for how to move forward?