Attempting to follow an example on classifying text but running into trouble

Hello all, I’m trying to fix an issue with my code. I was following the example in Machine Learning in Elixir that uses Bumblebee to load a model from Hugging Face and then fine-tune it on my own data, but I ran into a problem. Here’s my code so far:

```elixir
model_name = {:hf, "distilbert-base-cased"}
model_module = Bumblebee.Text.Bert
{:ok, spec} = Bumblebee.load_spec(model_name, module: model_module, architecture: :for_sequence_classification)

spec = Bumblebee.configure(spec, num_labels: 3)

{:ok, %{model: model, params: params}} = Bumblebee.load_model(model_name, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer(model_name)

ignore_phrases = [
  # This is a bunch of phrases I want to strip out of the input because I know it doesn't matter
]

{:ok, pid} = Postgrex.start_link(...)

query =
  """
  ...
  """

%{rows: results} = Postgrex.query!(pid, query, [], timeout: 60_000)

results =
  Enum.map(results, fn [_id, reason, body] ->
    body =
      ignore_phrases
      |> Enum.reduce(body, fn phrase, acc -> String.replace(acc, phrase, "") end)
      |> String.downcase()

    [reason, body]
  end)

batch_size = 32
max_length = 512

train_data =
  results
  |> Stream.chunk_every(batch_size)
  |> Stream.map(fn inputs ->
    {reasons, bodies} =
      inputs
      |> Enum.map(fn [reason, body] -> {reason, body} end)
      |> Enum.unzip()

    # Convert reason to a numerical representation
    reasons =
      reasons
      |> Enum.map(fn
        "n" -> 0
        "x" -> 1
        "y" -> 2
      end)
      |> Nx.tensor()

    tokens = Bumblebee.apply_tokenizer(tokenizer, bodies, length: max_length)

    {reasons, tokens}
  end)

model = Axon.nx(model, fn %{logits: logits} -> logits end)
optimizer = Axon.Optimizers.adamw(5.0e-5)

loss = &Axon.Losses.categorical_cross_entropy(&1, &2,
  from_logits: true,
  sparse: true,
  reduction: :mean
)

trained_model_state =
  model
  |> Axon.Loop.trainer(loss, optimizer, log: 1)
  |> Axon.Loop.metric(:accuracy)
  |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA)
```
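For reference, the loss options above (`from_logits: true`, `sparse: true`, `reduction: :mean`) mean the targets are integer class ids of shape `{batch_size}` rather than one-hot rows, and the model output is raw logits of shape `{batch_size, 3}`. The following is a plain-Elixir illustration of that computation using lists; it is not Axon's implementation, and `SparseCE` is a hypothetical name:

```elixir
defmodule SparseCE do
  # Illustration only: mean sparse categorical cross-entropy computed from logits.
  # class_ids: integer labels, e.g. [0, 2]
  # logits: one row of raw (unnormalized) scores per example
  def loss(class_ids, logits) do
    per_example =
      Enum.zip_with(class_ids, logits, fn class_id, row ->
        # -log(softmax(row)[class_id]) == logsumexp(row) - row[class_id]
        log_sum_exp(row) - Enum.at(row, class_id)
      end)

    # reduction: :mean averages the per-example losses over the batch
    Enum.sum(per_example) / length(per_example)
  end

  # Numerically stable log(sum(exp(x))): shift by the largest value before exponentiating.
  defp log_sum_exp(row) do
    peak = Enum.max(row)
    peak + :math.log(Enum.reduce(row, 0.0, fn x, acc -> acc + :math.exp(x - peak) end))
  end
end

# Uniform logits over 3 classes give log(3), about 1.0986, whatever the label.
SparseCE.loss([0, 2], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
```

The labels only enter the computation as indices into each row of logits, which is why they belong in the target position of each batch.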

When I run this, I get the following exception:

```
** (Axon.CompileError) exception found when compiling layer anonymous fn/2 in Axon.nx/3 named nx_0:

    ** (ArgumentError) given axis (1) invalid for shape with rank 1
        (nx 0.7.3) lib/nx/shape.ex:1121: Nx.Shape.normalize_axis/4
        (nx 0.7.3) lib/nx.ex:4199: Nx.axis_size/2
        (bumblebee 0.5.3) lib/bumblebee/layers.ex:951: anonymous fn/2 in Bumblebee.Layers.default_position_ids/2

(pass debug: true to build/compile see where the layer was defined)

Compiling of the model was initiated at:

    (axon 0.6.1) lib/axon/loop.ex:345: anonymous fn/6 in Axon.Loop.train_step/4
    (nx 0.7.3) lib/nx/defn/compiler.ex:173: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.7.3) lib/exla/defn.ex:534: anonymous fn/4 in EXLA.Defn.compile/8
    (exla 0.7.3) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 6.0) timer.erl:590: :timer.tc/2
    #cell:gcye3lz3uwx7sksq:14: (file)
```
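Going by the trace, a likely cause is the order of the tuple built in `Stream.map/2`. A loop built with `Axon.Loop.trainer` expects each batch as `{inputs, targets}`, but the stream yields `{reasons, tokens}`. That appears to hand the rank-1 label tensor (shape `{batch_size}`) to the model as its input, and `Bumblebee.Layers.default_position_ids/2` then asks for axis 1 of it, which fits the `given axis (1) invalid for shape with rank 1` error. In the original pipeline the change to try would be returning `{tokens, reasons}`. The sketch below uses a hypothetical `BatchSketch` module in plain Elixir (no Nx or Bumblebee, bodies left untokenized) only to show the ordering:

```elixir
defmodule BatchSketch do
  # Hypothetical helper for illustration; the label ids match the stream in the post.
  @label_ids %{"n" => 0, "x" => 1, "y" => 2}

  # Groups [reason, body] rows into batches shaped {inputs, targets}.
  # Inputs (the bodies, which the real pipeline would tokenize) come first;
  # targets (integer class ids) come second.
  def batches(rows, batch_size) do
    rows
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn chunk ->
      {reasons, bodies} =
        chunk
        |> Enum.map(fn [reason, body] -> {reason, body} end)
        |> Enum.unzip()

      {bodies, Enum.map(reasons, &Map.fetch!(@label_ids, &1))}
    end)
  end
end

rows = [["n", "first body"], ["y", "second body"], ["x", "third body"]]
BatchSketch.batches(rows, 2) |> Enum.to_list()
# => [{["first body", "second body"], [0, 2]}, {["third body"], [1]}]
```

Separately, `distilbert-base-cased` is a DistilBERT checkpoint, so passing `module: Bumblebee.Text.Bert` may be worth revisiting. One option is to omit `:module` and keep only `architecture: :for_sequence_classification`, letting Bumblebee infer the model type from the checkpoint's config.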

I’m not very experienced with machine learning, so I wasn’t sure what to search for, and the searches I did try didn’t turn up much.

Some info about my training dataset:

There are three classes (“n”, “x”, and “y”), and each body is between 200 and 2,000 words long. Any assistance would be super appreciated.
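Because the three labels show up in several places (`num_labels: 3` in the spec, the `"n" -> 0` clauses in the stream, and eventually turning predictions back into labels), one option is to keep the mapping in a single module. A hypothetical `Labels` module in plain Elixir, where `decode/1` takes one row of logits as a list:

```elixir
defmodule Labels do
  # Hypothetical module: one place for the label <-> class-id mapping,
  # so training (encode) and inference (decode) cannot drift apart.
  @labels ["n", "x", "y"]

  # Would match the num_labels passed to Bumblebee.configure/2.
  def num_labels, do: length(@labels)

  # "n" -> 0, "x" -> 1, "y" -> 2; raises a descriptive error on an unexpected
  # label instead of a FunctionClauseError inside the training stream.
  def encode(label) do
    case Enum.find_index(@labels, &(&1 == label)) do
      nil -> raise ArgumentError, "unknown label: #{inspect(label)}"
      index -> index
    end
  end

  # Maps one row of logits back to a label by taking the index of the largest score.
  def decode(logits) when length(logits) == length(@labels) do
    {_score, index} =
      logits
      |> Enum.with_index()
      |> Enum.max_by(fn {score, _index} -> score end)

    Enum.at(@labels, index)
  end
end

Labels.decode([0.1, -2.0, 3.5])
# => "y"
```

On length: with `max_length = 512`, every body is padded or truncated to 512 tokens, and by default truncation drops tokens from the end, so for bodies approaching 2,000 words the model only sees the beginning of each text.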

Currently I’m using a model trained with fastText and it works fairly well, but I was curious about building an equivalent in Elixir. The catch is that training a model with fastText doesn’t really teach you much about machine learning, so I still don’t feel like I quite know what I’m doing…

Thanks!