```elixir
test "conditional generation" do
  {:ok, model_info} = Bumblebee.load_model({:hf, "facebook/bart-large-cnn"})
  {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "facebook/bart-large-cnn"})

  assert %Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

  article = """
  PG&E stated it scheduled the blackouts in response to forecasts for high \
  winds amid dry conditions. The aim is to reduce the risk of wildfires. \
  Nearly 800 thousand customers were scheduled to be affected by the shutoffs \
  which were expected to last through at least midday tomorrow.
  """

  inputs = Bumblebee.apply_tokenizer(tokenizer, article)

  generate =
    Bumblebee.Text.Generation.build_generate(model_info.model, model_info.spec,
      min_length: 0,
      max_length: 8
    )

  token_ids = generate.(model_info.params, inputs)

  assert_equal(token_ids, Nx.tensor([[2, 0, 8332, 947, 717, 1768, 5, 2]]))

  assert Bumblebee.Tokenizer.decode(tokenizer, token_ids) == ["PG&E scheduled the"]
end
```
Edit:
We are switching from text classification to text generation, which is the task question generation falls under. Your model should be supported, since it is a fine-tuned BART model.
You will only need to change the model name and the inputs.
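To make the adaptation concrete, here is a sketch of the same flow pointed at a question-generation checkpoint. Note that `"your-name/bart-question-generation"` is a placeholder for your actual fine-tuned model repository on Hugging Face, the input format is an assumption (question-generation models are often trained on a context passage, sometimes with the answer highlighted by special tokens), and the generated output depends entirely on your model:

```elixir
# Sketch: same pipeline as the test above, with the model name and
# inputs swapped out. "your-name/bart-question-generation" is a
# placeholder, not a real repository.
{:ok, model_info} = Bumblebee.load_model({:hf, "your-name/bart-question-generation"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "your-name/bart-question-generation"})

# A fine-tuned BART model should still load with this architecture.
%Bumblebee.Text.Bart{architecture: :for_conditional_generation} = model_info.spec

# Instead of an article to summarize, pass the context your model
# was trained to generate questions from.
context =
  "Elixir is a dynamic, functional language for building scalable and maintainable applications."

inputs = Bumblebee.apply_tokenizer(tokenizer, context)

generate =
  Bumblebee.Text.Generation.build_generate(model_info.model, model_info.spec,
    min_length: 0,
    # questions are typically longer than the 8 tokens used in the test
    max_length: 32
  )

token_ids = generate.(model_info.params, inputs)
Bumblebee.Tokenizer.decode(tokenizer, token_ids)
```

The `min_length`/`max_length` values are illustrative; tune them to the length of questions your model produces.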