Porting Sakana AI's TRINITY Qwen-based model to Elixir/Bumblebee/Nx/Axon

Aloha gang,

I’m working on a port of Sakana AI’s TRINITY, an evolved LLM coordinator:

TRINITY Paper
OpenReview
Downloadable Assets

I started by attempting to reconstruct the work itself, but that isn’t realistic for me given my skill and resource constraints. So I’ve instead pivoted to porting their Python mechanism that uses a base Qwen model to build their coordinator: Trinity Coordinator. (Right now it relies on a local path dependency in mix.exs for an inference library I’m building to abstract over LLM providers, so it isn’t “clone friendly” yet.)

Just seeing if this is of interest to anyone. Certainly open to input/feedback/ideas/critiques on approach. Please respond here or open an issue with your candid feedback. There must be someone out there with more knowledge/experience on such matters who can provide guidance?

I used the original Python scripts to produce a safetensors file, so Nx can read the numbers properly. I’ve been working on a staged verification process so that the resulting Qwen-based coordinator behaves the same as the one built from the .pt checkpoint generated by their Python system. Some NumPy -> Nx nuances (and a few others) may prevent perfect numerical alignment, but I’m aiming for behavioral and functional parity.
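One way to express “behavioral and functional parity” in code is a per-stage tolerance check instead of bit-exact equality, since NumPy/PyTorch and Nx backends can differ in reduction order and dtype handling. This is a hypothetical sketch: the module name, stage names and default tolerances are illustrative assumptions, not taken from trinity_coordinator. In practice the reference map would hold tensors exported from the Python run (for example, read back from the safetensors file).

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule ParityCheck do
  @moduledoc """
  Hypothetical helper: compares per-stage tensors from the Python reference run
  against the Nx port using tolerances instead of bit-exact equality.
  """

  def compare(reference, port, opts \\ []) do
    atol = Keyword.get(opts, :atol, 1.0e-4)
    rtol = Keyword.get(opts, :rtol, 1.0e-3)

    for {stage, ref} <- reference, into: %{} do
      got = Map.fetch!(port, stage)
      shapes = {Nx.shape(ref), Nx.shape(got)}

      result =
        if Nx.shape(got) != Nx.shape(ref) do
          # Refuse to broadcast: a shape mismatch is a porting bug, not numeric noise.
          %{close?: false, max_abs_diff: nil, shapes: shapes}
        else
          max_abs_diff =
            got |> Nx.subtract(ref) |> Nx.abs() |> Nx.reduce_max() |> Nx.to_number()

          # Elementwise check: within atol plus rtol times the value's magnitude.
          close? = Nx.to_number(Nx.all_close(ref, got, atol: atol, rtol: rtol)) == 1
          %{close?: close?, max_abs_diff: max_abs_diff, shapes: shapes}
        end

      {stage, result}
    end
  end
end
```

Reporting the max absolute difference alongside the pass/fail flag makes it easier to tell “slightly outside tolerance” from “genuinely different computation” when a stage fails.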

One thing I often see these days is people creating amazing work and ideas in Elixir that are hard-coupled to specific providers and the like. One goal for trinity_coordinator is to be a working standalone system with built-in routing to LLMs, while also being pluggable/modular for integration into any other codebase/framework/system.

PS: I wasn’t sure if this is the right forum category, but there was a note saying to use the Nx forum for Nx-related topics.

I put in some elbow grease yesterday. Seems to be pretty solid so far. Using a modified, tiny 0.6B Qwen SLM for robust routing might turn out to be quite useful!

Sakana is using their version of TRINITY (obviously not served from Elixir/Bumblebee/etc.) in their forthcoming commercial AI framework, so this is pretty cutting edge. Glad to stand on the shoulders of giants (both in the research community and in the Elixir ML community).

I’m running this on a 5060 Ti 16GB, but since it’s an SLM it should work on almost any GPU.

Right now it only reports which model it would call, so it’s not a complete framework yet, just a proving ground.

$ XLA_TARGET=cuda12 mix run examples/qwen_router_prompt_eval.exs

=== QWEN ROUTER PROMPT EVAL ===

What this does
  Loads the local adapted Qwen router once, sends fixed transcripts through it,
  and checks whether the selected agent slot and role match expectations.

  No external LLM/provider calls are made.

Agent labels
  The agent names are labels from the original Sakana checkpoint.
  Example: agent 4 is labeled "google/gemma-3-27b-it".
  This eval does not call Gemma; it only reports that the router selected
  checkpoint slot 4.

Artifact
  priv/sakana_trinity/adapted_qwen3_0_6b_layer26

Model
  base router model: Qwen/Qwen3-0.6B
  router head shape: [10, 1024]
  assertion mode: strict

Native logs
  hidden in normal mode: tmp/examples/qwen_router_prompt_eval.native.log
  use --debug-native-logs to print XLA/CUDA compiler logs inline


[1/12] math_direct - PASS

Prompt sent to router:
  user: What is 17 + 25? Answer briefly.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 15


[2/12] math_proof - PASS

Prompt sent to router:
  user: Prove that the sum of the first n odd positive integers is n squared. Route this to the best next role.

Expected route:
  agent 0: gpt-5
  role  0: Worker

Router returned:
  agent 0: gpt-5
  role  0: Worker

Router input tokens: 26


[3/12] code_debug - PASS

Prompt sent to router:
  user: A Python function mutates its default list argument across calls. Identify the bug and propose the smallest fix.

Expected route:
  agent 0: gpt-5
  role  0: Worker

Router returned:
  agent 0: gpt-5
  role  0: Worker

Router input tokens: 23


[4/12] security_review - PASS

Prompt sent to router:
  user: Review this login flow for security risks: passwords are hashed, sessions are cookies, and reset tokens never expire.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  1: Thinker

Router returned:
  agent 4: google/gemma-3-27b-it
  role  1: Thinker

Router input tokens: 24


[5/12] planning - PASS

Prompt sent to router:
  user: Create a concise implementation plan for migrating a small Elixir service from in-memory state to Postgres.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  0: Worker

Router returned:
  agent 4: google/gemma-3-27b-it
  role  0: Worker

Router input tokens: 22


[6/12] verification_after_worker - PASS

Prompt sent to router:
  user: Calculate 6 * 7 and verify the answer.
  assistant: Worker answer: 6 * 7 = 42.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 28


[7/12] needs_revision - PASS

Prompt sent to router:
  user: Check whether the answer is correct: 19 + 24 = 41.
  assistant: Worker answer: 19 + 24 = 41.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 38


[8/12] ambiguous_decomposition - PASS

Prompt sent to router:
  user: This problem has unclear requirements. We may need to split it into assumptions, risks, and a concrete next action.

Expected route:
  agent 0: gpt-5
  role  2: Verifier

Router returned:
  agent 0: gpt-5
  role  2: Verifier

Router input tokens: 25


[9/12] creative_but_constrained - PASS

Prompt sent to router:
  user: Draft a friendly but precise support reply explaining a billing correction. Keep it under 120 words.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 23


[10/12] longer_context - PASS

Prompt sent to router:
  system: You are routing work inside a three-role TRINITY loop. Worker solves, Thinker plans or redirects, Verifier checks.
  user: Given a release checklist, identify the next best role. The feature compiles, unit tests pass, docs changed, but no smoke test has been run yet.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 62


[11/12] provider_failure_triage - PASS

Prompt sent to router:
  user: The last provider call timed out after 30 seconds. Decide whether to retry, ask a thinker for a smaller plan, or verify the partial answer.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router returned:
  agent 4: google/gemma-3-27b-it
  role  2: Verifier

Router input tokens: 33


[12/12] final_answer_check - PASS

Prompt sent to router:
  user: Solve and then verify: the capital of France is Paris.
  assistant: Worker answer: Paris is the capital of France.
  assistant: Thinker note: this is a factual lookup and likely ready.

Expected route:
  agent 4: google/gemma-3-27b-it
  role  0: Worker

Router returned:
  agent 4: google/gemma-3-27b-it
  role  0: Worker

Router input tokens: 41


Summary
  passed: 12
  failed: 0
  roles selected: Worker=4, Thinker=1, Verifier=7
  agent slots selected: 0=3, 4=9

PASS qwen_router_prompt_eval

Just for fun, I ported this to work on my 24GB M4 MBP using the Emily backend. During the export, I ran into a limitation of the current native MLX libraries: their SVD functions have no ‘thin’ mode and always materialise the full matrices. For the Qwen embedder that turned out to be ~92GB of memory.

I updated Emily to support this mode (in specific cases) directly via the Gram matrix; with that, the one-time export took ~2s and I was able to run the qwen router example.

If you are interested, the changes I made are here
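For readers unfamiliar with the Gram-matrix route, here is a minimal sketch of the general idea in plain Nx. It is not the Emily patch (whose code isn’t reproduced here), and the module name is hypothetical. For a tall matrix A with shape {m, n}, m ≥ n, eigendecompose the n×n matrix AᵀA = V Λ Vᵀ; then S = √Λ and U = A V S⁻¹. Only {n, n} and {m, n} buffers are created, never {m, m}. It assumes A has full column rank (no zero singular values), and forming AᵀA squares the condition number, so small singular values come out less precise than from a true SVD.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule GramThinSVD do
  @moduledoc """
  Hypothetical sketch: reduced SVD of a tall matrix a ({m, n}, m >= n) through
  the n x n Gram matrix a^T a, so no {m, m} factor is ever allocated.
  Assumes full column rank.
  """

  def thin_svd(a) do
    {m, n} = Nx.shape(a)
    if m < n, do: raise(ArgumentError, "expected a tall matrix, got #{inspect({m, n})}")

    # n x n symmetric Gram matrix.
    gram = Nx.dot(Nx.transpose(a), a)
    {eigvals, eigvecs} = Nx.LinAlg.eigh(gram)

    # Sort eigenpairs by descending eigenvalue to match the usual SVD ordering.
    order = Nx.argsort(eigvals, direction: :desc)
    eigvals = Nx.take(eigvals, order)
    v = Nx.take(eigvecs, order, axis: 1)

    # Clamp tiny negative round-off before taking the square root.
    s = eigvals |> Nx.max(0.0) |> Nx.sqrt()

    # Column i of u is (a . v_i) / s_i; the division broadcasts over the last axis.
    u = a |> Nx.dot(v) |> Nx.divide(s)

    {u, s, Nx.transpose(v)}
  end
end
```

The trade-off is memory for accuracy: this works well when only a reconstruction-quality factorization is needed, but a backend would presumably want to restrict it to shapes and conditioning where the precision loss is acceptable, which may be why it only applies in specific cases.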


Nx provides a default implementation that doesn’t have any explicit evals, so that might be a way to avoid those materializations.


Ah, that makes sense. Thanks both.

@ausimian, the Emily result is extremely useful to know. The fact that native MLX SVD materialized the full matrix, but the Gram-matrix path made the one-time export work on a 24GB M4, is exactly the kind of backend detail I would not have known to check.

@polvalente, your point about Nx’s default implementation is really helpful. If the default implementation avoids explicit evals, then it may avoid the immediate materialization issue seen in the native MLX SVD path.

This helped me separate two concerns: what TRINITY needs from Nx today, and what backend-neutral code can safely assume. In TRINITY, the export path only needs the reduced SVD outputs for reconstruction. As an application author, though, it is tempting to read:


Nx.LinAlg.svd(tensor, full_matrices?: false)

as a “thin SVD” guarantee in both output shape and execution/memory behavior. The safer cross-backend interpretation seems to be:

  1. full_matrices?: false is an output-shape contract.

  2. Execution strategy and memory behavior are backend-dependent.
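To make the output-shape reading concrete, a small sketch (assuming Nx is installed via Mix.install) that checks only what full_matrices?: false changes: with k = min(m, n), the reduced form returns u as {m, k} and vt as {k, n}, while the default returns u as {m, m}. Nothing in the result says how much memory the backend used to get there.

```elixir
Mix.install([{:nx, "~> 0.9"}])

a =
  Nx.tensor([
    [2.0, 0.0, 1.0],
    [1.0, 3.0, 0.0],
    [0.0, 1.0, 4.0],
    [1.0, 1.0, 1.0],
    [3.0, 0.0, 2.0]
  ])

# Reduced form: u is {m, k}, s is {k}, vt is {k, n}, with k = min(m, n).
{u, s, vt} = Nx.LinAlg.svd(a, full_matrices?: false)

# Default (full_matrices?: true): u grows to {m, m}; vt is {n, n}.
{u_full, _s_full, vt_full} = Nx.LinAlg.svd(a)

IO.inspect({Nx.shape(u), Nx.shape(s), Nx.shape(vt)}, label: "reduced")
IO.inspect({Nx.shape(u_full), Nx.shape(vt_full)}, label: "full")

# The reduced factors still reconstruct a; the option constrains shapes, not peak memory.
reconstructed = u |> Nx.multiply(s) |> Nx.dot(vt)
IO.inspect(Nx.to_number(Nx.all_close(a, reconstructed, atol: 1.0e-3)), label: "reconstructs")
```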

If that is the intended boundary, I have two practical questions:

  1. If a backend’s native SVD is known to materialize full matrices anyway, should the preferred backend behavior be to notice full_matrices?: false and fall back to Nx’s default implementation when that is likely to behave better?

  2. Would the Nx team be open to a small docs PR clarifying that full_matrices? controls SVD output shapes but does not guarantee a reduced-memory execution profile?

For the current TRINITY path, this does not force an immediate change: it is still fine to call Nx.LinAlg.svd/2 and let Nx/the selected backend choose native SVD, default Nx lowering, or a backend-specific workaround. But the next rework is to make backend/runtime selection a profile concern instead of something CUDA-shaped in the application code, so I want to make sure I am relying on the right Nx contract before generalizing it. I do not want application code branching on Emily vs EXLA vs Torchx unless there is no better option.
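As a sketch of what “backend selection as a profile concern” could look like, the backend can be chosen in runtime configuration so application code only ever calls Nx. The TRINITY_NX_BACKEND environment variable is a hypothetical name, not an existing TRINITY or Nx setting; Emily is left out because its backend module name isn’t given here.

```elixir
# config/runtime.exs (hypothetical profile-based backend selection)
import Config

# TRINITY_NX_BACKEND is an assumed environment variable. Each profile maps to a
# backend module; only the matching dependency needs to be present at runtime.
backend =
  case System.get_env("TRINITY_NX_BACKEND", "exla") do
    "exla" -> EXLA.Backend
    "emlx" -> EMLX.Backend
    "torchx" -> Torchx.Backend
    "binary" -> Nx.BinaryBackend
    other -> raise "unknown TRINITY_NX_BACKEND: #{inspect(other)}"
  end

config :nx, default_backend: backend
```

With this shape, a backend-specific workaround (such as a thin-SVD path) stays inside the backend, and the profile only decides which backend is loaded.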

Side note: I did go down a rabbit hole thinking about how production code could ask a backend about these paths before starting a large SVD. I wrote that up here: Concrete proposal: SVD execution capabilities in Nx. It is probably too much architecture for this immediate issue, but I am leaving it as background in case backend-specific memory behavior becomes a larger Nx API discussion later.

@gtcode, any backend-specific issues should be addressed in the specific backends. For instance, EMLX, the original MLX-based backend for Nx, does not fall back to the binary backend at all; when something isn’t supported natively, it mostly relies on these default implementations from Nx.

I think it doesn’t really make sense to document this specific issue in Nx because this can happen with literally any option, as different backends are limited by different constraints. Maybe a general comment on this would be welcome.

Also, feel free to test this out with EMLX from the GitHub main branch, as I’m actively working on improvements. v0.3 should be out soon.


I took a very quick pass through the proposal, and I think something is being overlooked there. It focuses on Nx.block, but the same concerns apply to all Nx callbacks.

Also keep in mind that blocks and block structs can be defined by Nx users, not just Nx itself.

Additionally, even Nx.dot might have different execution profiles depending on the backend.

Finally, lowered vs eval’d is also something that’s not a concern for the backend, but for the compiler.

Do feel welcome to open an issue for us to discuss there and hopefully achieve a great feature, but I suggest rethinking the approach to something more general that includes all Nx callbacks.
