Configuration for Text Generation in Bumblebee

Howdy!

I’m following along with the Generative AI with Large Language Models course and trying to implement the hands-on assignment using Nx.

However, I’m not getting results with the same quality as the Python implementation. I understand that the answers could differ, but the text generation I get is odd and deranged :smiley:

This is the Python code:

from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
model_name = 'google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

dialogue="#Person1#: What time is it, Tom?\n#Person2#: Just a minute. It's ten to nine by my watch.\n#Person1#: Is it? I had no idea it was so late. I must be off now.\n#Person2#: What's the hurry?\n#Person1#: I must catch the nine-thirty train.\n#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there."
inputs = tokenizer(dialogue, return_tensors='pt')
output = tokenizer.decode(
    model.generate(inputs["input_ids"], max_new_tokens=50)[0],
    skip_special_tokens=True
)

print(f'MODEL GENERATION :\n{output}')

which prints:

MODEL GENERATION :
Person1: It’s ten to nine.

And the Elixir code is:

{:ok, model} = Bumblebee.load_model({:hf, "google/flan-t5-base"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-base"})
dialogue = """
#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there
"""
config = %Bumblebee.Text.GenerationConfig{
  max_new_tokens: 50,
  forced_token_ids: [],
  no_repeat_ngram_length: 3,
  bos_token_id: 0, # This is mandatory, I copied from a smart cell!
  pad_token_id: 1 # This is mandatory, I copied from a smart cell!
}
serving = Bumblebee.Text.generation(model, tokenizer, config)
Nx.Serving.run(serving, dialogue) |> IO.inspect()

And I get this back:

%{
  results: [
    %{
      text: "Person1#: It's ten to nine. It'll be ten minutes before the train leaves.nt is ten.nd is nine.m is nine o'clock.f is nine thirty"
    }
  ]
}

I consistently get extra gibberish sentences on other dialogues in the dataset, e.g. on another input:

Python code output:

Person1: I’m worried about my future.

Elixir code output:

#Person1#: I’m worried about my future.tatta i apologise for my paleness.temo i’m very young.mo a

What kind of configuration here is needed to get the same result?

Hey @slashmili :slight_smile: There are two differences between these versions:

  1. You are missing eos_token_id: 1 in the generation config. This is exactly the token the model uses to indicate the end of generation, so it makes sense that without it you get extra continuation.
  2. The Python dialogue ends with a dot, while the Elixir dialogue ends with a newline. This isn’t particularly significant, but it changes the output slightly.
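For reference, the quick fix for 1) with the manual config from the question would be adding the missing field (a sketch, keeping the other values as you had them):

```elixir
config = %Bumblebee.Text.GenerationConfig{
  max_new_tokens: 50,
  forced_token_ids: [],
  no_repeat_ngram_length: 3,
  bos_token_id: 0,
  pad_token_id: 1,
  eos_token_id: 1 # without this, generation never knows when to stop
}
```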

Regarding 1), note that in general it’s best to load the predefined config from the model repository and then only override specific options (like :max_new_tokens), otherwise you need to know what token ids to set.
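To see why the missing EOS id produces exactly this kind of runaway tail, here is a toy greedy decoding loop (plain Python, not Bumblebee; the fake model and ids are made up for illustration). When the loop knows the EOS id it stops there; when it doesn’t, it keeps sampling until max_new_tokens and appends extra tokens:

```python
def generate(step, eos_token_id, max_new_tokens):
    """Toy greedy decoding: call `step` to get the next token id,
    stopping early if it equals `eos_token_id`."""
    tokens = []
    for _ in range(max_new_tokens):
        token = step(tokens)
        if token == eos_token_id:  # stop criterion
            break
        tokens.append(token)
    return tokens

# A fake "model" that emits 5, 6, then EOS (id 1), then noise.
fake_step = lambda toks: [5, 6, 1, 9, 9, 9][len(toks)]

print(generate(fake_step, eos_token_id=1, max_new_tokens=5))     # [5, 6]
print(generate(fake_step, eos_token_id=None, max_new_tokens=5))  # [5, 6, 1, 9, 9]
```

The second call is the situation you were in: the EOS token is generated but never recognized, so the gibberish after it ends up in the decoded text.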

Here’s a version that matches the Python output (note the trailing backslash in the heredoc that prevents the final newline):

{:ok, model} = Bumblebee.load_model({:hf, "google/flan-t5-base"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/flan-t5-base"})
{:ok, config} = Bumblebee.load_generation_config({:hf, "google/flan-t5-base"})
config = Bumblebee.configure(config, max_new_tokens: 50)

dialogue = """
#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.\
"""

serving = Bumblebee.Text.generation(model, tokenizer, config)
Nx.Serving.run(serving, dialogue) |> IO.inspect()

It works now! Thanks a lot for your help.

Now I’m back on track with the course :v: