Bumblebee - New Text Generation Model

I plan to use the following model

{:ok, bertqca} = Bumblebee.load_model({:hf, "voidful/bart-eqg-question-generator"})

with tokenizer

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-base"})

But on running

serving = Bumblebee.Text.text_classification(bertqca, tokenizer)

I get

** (ArgumentError) expected a model with architecture :for_sequence_classification, got :for_conditional_generation
    (bumblebee 0.1.0) lib/bumblebee/shared.ex:209: Bumblebee.Shared.validate_architecture!/2
    (bumblebee 0.1.0) lib/bumblebee/text/text_classification.ex:9: Bumblebee.Text.TextClassification.text_classification/3
    /Documents/qca.livemd#cell:tzn72neqi6ygxd42wtiunchytys3yjz2:1: (file)

So the text classification module does not support :for_conditional_generation.

I am new to machine learning and would like to extend Bumblebee to support the model above.
Please point me in the right direction. @josevalim

I am willing to open a PR.

Thank you.

One thing I noticed is that I am using the wrong function for this model.

It should not be text_classification in the first place.

Try this from the Bumblebee tests (bart_test.exs in elixir-nx/bumblebee, at commit 9f38d63175e392eb2e7f3f74f635a57021f1e595):

  test "conditional generation" do
    {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})

    assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

    article = """
    PG&E stated it scheduled the blackouts in response to forecasts for high \
    winds amid dry conditions. The aim is to reduce the risk of wildfires. \
    Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
    which were expected to last through at least midday tomorrow.
    """

    inputs = Bumblebee.apply_tokenizer(tokenizer, article)

    # Build the generation function from the model and its spec
    generate =
      Bumblebee.Text.Generation.build_generate(model_info.model, model_info.spec,
        min_length: 0,
        max_length: 8
      )

    token_ids = generate.(model_info.params, inputs)

    assert_equal(token_ids, Nx.tensor([[2, 0, 8332, 947, 717, 1768, 5, 2]]))

    assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
  end

This switches from text classification to text generation, which question generation falls under. Your model should be supported because it is a fine-tuned BART model.

You will need to change the model name and the inputs.
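
As a rough sketch of that change, the test above could be adapted as follows. It assumes Bumblebee 0.1.0 (the version in your stack trace) and reuses only the calls from the test. The context string and max_length: 32 are placeholder assumptions; the model may expect a specific input format, so check its model card on Hugging Face. I have not run this against the model.

```elixir
# Sketch only: mirrors the Bumblebee BART test, with the model swapped for the
# one from the original question. Not verified against this particular model.
{:ok, model_info} = Bumblebee.load_model({:hf, "voidful/bart-eqg-question-generator"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-base"})

# Placeholder input (assumption): the model card may require a particular
# format, for example markers around the answer span.
context = "Harry Potter is a series of seven fantasy novels written by J. K. Rowling."

inputs = Bumblebee.apply_tokenizer(tokenizer, context)

# max_length: 32 is an arbitrary choice so that a full question can fit.
generate =
  Bumblebee.Text.Generation.build_generate(model_info.model, model_info.spec,
    min_length: 0,
    max_length: 32
  )

token_ids = generate.(model_info.params, inputs)

# decode/2 returns a list with one string per generated sequence.
tokenizer
|> Bumblebee.Tokenizer.decode(token_ids)
|> IO.inspect(label: "generated question")
```

If the output looks truncated or off-topic, the input format and max_length are the first things to adjust.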


This helps! Thank you!