Question on defn function definition resulting in 'invalid dimension in axis 0 in shape'


I'm having some trouble getting started with Nx. What is the problem with the following code?

defmodule MyNxModule do
    import Nx.Defn

    defn vector_with_all_zeros(length) do
        Nx.broadcast(0.0, {length})
    end
end

When I run it from IEx:

iex> MyNxModule.vector_with_all_zeros(10)

I get the following error:

** (ArgumentError) invalid dimension in axis 0 in shape. Each dimension must be a positive integer, got #Nx.Tensor<
  parameter a:0   s64
> in shape {#Nx.Tensor<
   parameter a:0   s64
>}
    (nx 0.1.0) lib/nx/shape.ex:35: Nx.Shape.validate!/3
    (nx 0.1.0) lib/nx.ex:2285: Nx.broadcast/3

It works when I replace the 'length' function parameter with a fixed integer value. Is there a problem with the function definition?

Thank you for your help.

From the Nx.Defn docs:

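defn expects all inputs to be tensors, with the exception of a default argument (declared with \\ ) which will be treated as options.

So inside the defn, 'length' is no longer the integer 10 but a tensor parameter (the "parameter a:0 s64" in the error), and a tensor cannot be used as a dimension in a shape.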
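If the function only needs to build a tensor of a given size, one way around this is to skip defn for that step. This is a minimal sketch, not something taken from the Nx docs: it assumes a regular `def` is acceptable here, and the guard clause is my own addition. Because it is a plain Elixir function, `length` stays an ordinary integer and can be used directly in the shape tuple.

```elixir
defmodule MyNxModule do
  # Plain `def`, not `defn`: the argument is NOT converted to a tensor,
  # so `length` is still a regular integer when it reaches Nx.broadcast/2.
  # The guard is an illustrative assumption; it rejects invalid sizes early.
  def vector_with_all_zeros(length) when is_integer(length) and length > 0 do
    Nx.broadcast(0.0, {length})
  end
end
```

Calling `MyNxModule.vector_with_all_zeros(10)` should then return a tensor of shape `{10}`. The trade-off is that this function is not compiled by defn, so it suits setup code rather than the numerical core.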

Thank you!

So is it best to pass tensor dimensions through opts, like this?

    defn vector_all_zeros(opts \\ [length: 5]) do
        l = transform(opts, &Keyword.fetch!(&1, :length))
        Nx.broadcast(0.0, {l})
    end