Help needed fine-tuning a sentence transformer model

We are attempting to fine-tune a sentence transformer model and we’re getting an error when we run the Axon training loop. Our attempt can be found in this notebook: AudioTagger.Trainer · GitHub.

The error is:

(ArgumentError) cannot convert %{cache: #Axon.None<...>, hidden_state: #Nx.Tensor<
    f32[5][64][384]
    ...
to tensor because it represents a collection of tensors, use Nx.stack/2 or Nx.concatenate/2 instead
    (nx 0.6.4) lib/nx.ex:2168: Nx.to_tensor/1
    (nx 0.6.4) lib/nx.ex:5470: Nx.apply_vectorized/2
    (axon 0.6.0) lib/axon/losses.ex:895: Axon.Losses."__defn:cosine_similarity__"/3
    (axon 0.6.0) lib/axon/loop.ex:376: anonymous fn/8 in Axon.Loop.train_step/4
    (nx 0.6.4) lib/nx/defn/grad.ex:23: Nx.Defn.Grad.transform/3
    (axon 0.6.0) lib/axon/loop.ex:402: anonymous fn/6 in Axon.Loop.train_step/4
    (axon 0.6.0) lib/axon/loop.ex:1925: anonymous fn/6 in Axon.Loop.build_batch_fn/2
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.6.4) lib/exla/defn.ex:387: anonymous fn/4 in EXLA.Defn.compile/8
    (exla 0.6.4) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 5.1.1) timer.erl:270: :timer.tc/2
    (exla 0.6.4) lib/exla/defn.ex:385: EXLA.Defn.compile/8
    (exla 0.6.4) lib/exla/defn.ex:270: EXLA.Defn.__compile__/4
    (exla 0.6.4) lib/exla/defn.ex:256: EXLA.Defn.__jit__/5
    (nx 0.6.4) lib/nx/defn.ex:443: Nx.Defn.do_jit_apply/3
    (stdlib 5.1.1) timer.erl:270: :timer.tc/2
    (axon 0.6.0) lib/axon/loop.ex:1805: anonymous fn/4 in Axon.Loop.run_epoch/5
    (elixir 1.15.7) lib/enum.ex:4830: Enumerable.List.reduce/3
    (elixir 1.15.7) lib/enum.ex:2564: Enum.reduce_while/3
    (stdlib 5.1.1) timer.erl:295: :timer.tc/3
    (axon 0.6.0) lib/axon/loop.ex:1685: anonymous fn/6 in Axon.Loop.run/4
    (elixir 1.15.7) lib/range.ex:526: Enumerable.Range.reduce/5
    (elixir 1.15.7) lib/enum.ex:2564: Enum.reduce_while/3
    (axon 0.6.0) lib/axon/loop.ex:1669: Axon.Loop.run/4
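Reading the trace, it looks like `Axon.Losses.cosine_similarity` is receiving the model's full output map (`%{hidden_state: ..., cache: ...}`) rather than a plain tensor. We suspect the model output needs to be narrowed to a single tensor before the loss runs — a sketch of what we mean (assuming Axon 0.6 / Bumblebee 0.4, and that `:hidden_state` plus mean pooling is what the loss should see):

```elixir
# Sketch, not a verified fix: Bumblebee models output a map of tensors,
# but Axon's loss functions expect a plain tensor. Axon.nx/2 can wrap the
# model so only the pooled hidden_state flows into the training loop.
{:ok, %{model: model} = _model_info} =
  Bumblebee.load_model({:hf, "sentence-transformers/all-MiniLM-L6-v2"})

embedding_model =
  Axon.nx(model, fn %{hidden_state: hidden_state} ->
    # mean-pool over the token axis: {batch, seq, 384} -> {batch, 384}
    Nx.mean(hidden_state, axes: [1])
  end)
```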

We’ve been using this HF guide as inspiration: Train and Fine-Tune Sentence Transformers Models, along with Bumblebee’s fine-tuning guide: Fine-tuning — Bumblebee v0.4.2.

If anyone has suggestions as to what we are missing, please let us know.

Thank you!

HTH - a working script that uses the same Bumblebee model as your code. You might need to remove the GPU config to run on a CPU…

#!/usr/bin/env elixir 

Mix.install(
  [
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:bumblebee, "~> 0.4"},
    {:jason, "~> 1.4"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ],
    exla: [
      default_client: :cuda,
      clients: [
        host: [platform: :host],
        cuda: [platform: :cuda]
      ]
    ]
  ],
  system_env: [
    TF_CPP_MIN_LOG_LEVEL: "3",
    XLA_TARGET: "cuda120"
  ]
)

defmodule CLI do
  def exec(args) do
    case args do
      [] ->
        {:error, "No argument was provided."}

      [arg] ->
        {:ok, Generator.runner(arg)}

      _ ->
        {:error, "More than one argument provided."}
    end
  end
end

defmodule Generator do
  def serving do
    {:ok, model_info} = Bumblebee.load_model({:hf, "sentence-transformers/all-MiniLM-L6-v2"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "sentence-transformers/all-MiniLM-L6-v2"})
    Bumblebee.Text.TextEmbedding.text_embedding(model_info, tokenizer,
      output_pool: :mean_pooling,
      output_attribute: :hidden_state,
      embedding_processor: :l2_norm
    )
  end

  # Note: this builds a new serving (and reloads the model) on every call;
  # for repeated use you would start the serving once and reuse it.
  def runner(text) do
    serving()
    |> Nx.Serving.run(text)
  end

  def to_json(tensor) do
    tensor
    |> Nx.to_list()
    |> Jason.encode!()
  end
end

case CLI.exec(System.argv()) do
  {:ok, result} ->
    result.embedding
    |> IO.inspect()
    # |> Generator.to_json() 
    # |> IO.puts()

  {:error, msg} ->
    IO.puts("ERROR: #{msg}")
end
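
If you don’t have a GPU, the `Mix.install` section can be trimmed so everything runs on the CPU — a sketch of what I mean by “remove the GPU config” (drop the `:exla` client config and the `XLA_TARGET` env var):

```elixir
# CPU-only setup: EXLA still compiles defn code, just on the host platform.
Mix.install(
  [
    {:nx, "~> 0.6"},
    {:exla, "~> 0.6"},
    {:bumblebee, "~> 0.4"},
    {:jason, "~> 1.4"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend,
      default_defn_options: [compiler: EXLA]
    ]
  ]
)
```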