Can I use `is_list/1` with Nx.Defn?

Hey there. Is there a way to use `is_list/1` inside a `defn` function definition?

This is my code for the probability density function of the normal distribution:

  @spec pdf([number], number, number) :: number
  defn pdf(x, mu \\ 0, sigma \\ 1) do
    x = Nx.tensor(x)
    pi = pi()
    expon = - Nx.pow((x - mu) / sigma, 2) / 2
    denom = 1 / (sigma * (2 * pi) ** 0.5)
    Nx.exp(expon) * denom
  end
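For reference, the code computes the normal density below. `expon` is the exponent and `denom` is the normalizing constant:

```latex
f(x \mid \mu, \sigma) = \frac{1}{\sigma\sqrt{2\pi}}\,\exp\!\left(-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^{2}\right)
```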

I want to check whether `x` is a list, so that I know when to call `Nx.tensor` and stop receiving the following warning:

  warning: Nx.tensor/2 inside defn expects the first argument to be a literal (such as a list)

Ideally, I would write something like this:

  defn pdf(x, mu \\ 0, sigma \\ 1) when is_list(x) do
    x = Nx.tensor(x)
    pi = pi()
    expon = - Nx.pow((x - mu) / sigma, 2) / 2
    denom = 1 / (sigma * (2 * pi) ** 0.5)
    Nx.exp(expon) * denom
  end

  defn pdf(x, mu \\ 0, sigma \\ 1) when is_map(x) do
    tensor =
      x
      |> Map.to_list()
      |> Nx.tensor()
....

For some backends, `defn` recompiles its code for every input shape. Calling `Nx.tensor` inside the `defn` circumvents that.

What about a regular function (`def`) that handles the argument juggling and then calls the `defn`?

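  # Function head: declares the default arguments once for all clauses.
  def pdf(numbers, mu \\ 0, sigma \\ 1)

  # A regular def can use guards such as is_list/1, so lists are
  # converted to tensors here, outside of defn.
  def pdf(numbers, mu, sigma) when is_list(numbers) do
    pdf_nx(Nx.tensor(numbers), mu, sigma)
  end

  # Anything else (tensors, plain numbers) is passed through unchanged.
  def pdf(numbers, mu, sigma), do: pdf_nx(numbers, mu, sigma)

  # The numerical code only ever receives tensors or numbers.
  defn pdf_nx(x, mu \\ 0, sigma \\ 1) do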
    pi = pi()
    expon = - Nx.pow((x - mu) / sigma, 2) / 2
    denom = 1 / (sigma * (2 * pi) ** 0.5)
    Nx.exp(expon) * denom
  end

The short answer is no. There are no lists inside `defn`. You must build your tensor outside of the `defn` and then pass it as an argument to the `defn`, as @al2o3cr outlined above.
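Below is a minimal, self-contained sketch of that approach. It makes a few assumptions. Nx is fetched with `Mix.install`, and the `~> 0.7` version constraint is an assumption. The module name `NormalDist` is illustrative. The thread never shows the `pi()` helper, so this sketch swaps it for a hypothetical module attribute `@sqrt_two_pi`, computed at compile time with Erlang's `:math`.

```elixir
# Fetch Nx for a standalone script; the version constraint is an assumption.
Mix.install([{:nx, "~> 0.7"}])

defmodule NormalDist do
  import Nx.Defn

  # Hypothetical constant sqrt(2 * pi), computed at compile time with Erlang's :math.
  @sqrt_two_pi :math.sqrt(2 * :math.pi())

  # Public entry point: a plain def, so guards such as is_list/1 are allowed.
  def pdf(x, mu \\ 0, sigma \\ 1)

  # Lists are turned into tensors here, before entering defn.
  def pdf(x, mu, sigma) when is_list(x), do: pdf_nx(Nx.tensor(x), mu, sigma)

  # Tensors and plain numbers are passed through unchanged.
  def pdf(x, mu, sigma), do: pdf_nx(x, mu, sigma)

  # Numerical kernel: receives only tensors or numbers, never lists.
  defn pdf_nx(x, mu, sigma) do
    z = (x - mu) / sigma
    Nx.exp(-z * z / 2) / (sigma * @sqrt_two_pi)
  end
end

# Usage: a list and an equivalent tensor both go through the same defn.
NormalDist.pdf([0.0, 1.0]) |> IO.inspect()
NormalDist.pdf(Nx.tensor([0.0, 1.0]), 0, 1) |> IO.inspect()
```

Because the guard and the list conversion live in the plain `def`, `pdf_nx` never sees a list. No `Nx.tensor` call is needed inside the `defn`.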


Thanks, it helped a lot! @josevalim @al2o3cr @BradS2S