Piolho
January 6, 2024, 1:23pm
1
Hey there. Is there a way to use `is_list` inside a `defn` function definition?

This is code for a normal distribution density:

```
@spec pdf([number], number, number) :: number
defn pdf(x, mu \\ 0, sigma \\ 1) do
  x = Nx.tensor(x)
  pi = pi()
  expon = -Nx.pow((x - mu) / sigma, 2) / 2
  denom = 1 / (sigma * (2 * pi) ** 0.5)
  Nx.exp(expon) * denom
end
```
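
For reference, the code above appears to evaluate the standard normal density formula, with `expon` as the exponent and `denom` as the normalizing factor (the symbols $\mu$ and $\sigma$ stand for the `mu` and `sigma` arguments):

$$
f(x) = \frac{1}{\sigma \sqrt{2\pi}} \exp\left(-\frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^{2}\right)
$$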

I want to check if x is a list so that I know when to use Nx.tensor and stop receiving the following warning:

```
warning: Nx.tensor/2 inside defn expects the first argument to be a literal (such as a list)
```

```
defn pdf(x, mu \\ 0, sigma \\ 1) when is_list(x) do
  x = Nx.tensor(x)
  pi = pi()
  expon = -Nx.pow((x - mu) / sigma, 2) / 2
  denom = 1 / (sigma * (2 * pi) ** 0.5)
  Nx.exp(expon) * denom
end

defn pdf(x, mu \\ 0, sigma \\ 1) when is_map(x) do
  tensor =
    x
    |> Map.to_list
    |> Nx.tensor(x)
  ....
```

For some backends, `defn` will recompile its code for every input shape. Calling `Nx.tensor` inside is going to circumvent that.

What about a regular function that handles the argument-juggling?

```
def pdf(numbers, mu \\ 0, sigma \\ 1)

def pdf(numbers, mu, sigma) when is_list(numbers) do
  pdf_nx(Nx.tensor(numbers), mu, sigma)
end

def pdf(numbers, mu, sigma), do: pdf_nx(numbers, mu, sigma)

defn pdf_nx(x, mu \\ 0, sigma \\ 1) do
  pi = pi()
  expon = -Nx.pow((x - mu) / sigma, 2) / 2
  denom = 1 / (sigma * (2 * pi) ** 0.5)
  Nx.exp(expon) * denom
end
```
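
Here is a rough, self-contained version of the same idea that can be run as a script, in case it helps to see both call styles side by side. The module name `Normal`, the `Mix.install` version constraint, and the use of `Nx.Constants.pi()` with `Nx.sqrt` (in place of the imported `pi()` and `** 0.5`) are my assumptions, not something taken from the original code.

```elixir
# Assumption: Nx is fetched with Mix.install; the version constraint is illustrative.
Mix.install([{:nx, "~> 0.7"}])

defmodule Normal do
  import Nx.Defn

  # Plain Elixir entry point: guards such as is_list/1 are fine here.
  def pdf(numbers, mu \\ 0, sigma \\ 1)

  # Lists are converted to a tensor before entering defn.
  def pdf(numbers, mu, sigma) when is_list(numbers) do
    pdf_nx(Nx.tensor(numbers), mu, sigma)
  end

  # Anything else (for example an existing tensor) is passed through unchanged.
  def pdf(numbers, mu, sigma), do: pdf_nx(numbers, mu, sigma)

  # Numerical definition: only tensor operations happen inside defn.
  defn pdf_nx(x, mu, sigma) do
    expon = -Nx.pow((x - mu) / sigma, 2) / 2
    denom = 1 / (sigma * Nx.sqrt(2 * Nx.Constants.pi()))
    Nx.exp(expon) * denom
  end
end

# Both calls go through the same defn; values are roughly 0.3989 and 0.2420.
Normal.pdf([0.0, 1.0])
Normal.pdf(Nx.tensor([0.0, 1.0]))
```

The list check lives entirely in the regular `def` clauses, so `pdf_nx` only ever sees tensors (or numbers, which `defn` converts to tensors on its own).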


The short answer is no. There are no lists inside `defn`. You must build your tensor outside of the `defn` and then pass it as an argument to the `defn`, as @al2o3cr outlined above.


Piolho
January 6, 2024, 8:34pm
5
Thanks, it helped a lot! @josevalim @al2o3cr @BradS2S