Piolho
January 6, 2024, 1:23pm
Hey there. Is there a way to use is_list inside a defn function definition?
This is code for a normal distribution density:
@spec pdf([number], number, number) :: Nx.Tensor.t()
defn pdf(x, mu \\ 0, sigma \\ 1) do
x = Nx.tensor(x)
pi = pi()
expon = - Nx.pow((x - mu) / sigma, 2) / 2
denom = 1 / (sigma * (2 * pi) ** 0.5)
Nx.exp(expon) * denom
end
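For reference, this is the standard normal density the function computes. In the code, expon is the exponent, and denom holds the normalizing factor 1/(σ√(2π)), so despite its name it is already the reciprocal of the denominator ((2 * pi) ** 0.5 is √(2π)).

```latex
$$
f(x \mid \mu, \sigma) = \frac{1}{\sigma\sqrt{2\pi}}\,\exp\!\left(-\frac{1}{2}\left(\frac{x-\mu}{\sigma}\right)^{2}\right)
$$
```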
I want to check if x is a list so that I know when to use Nx.tensor and stop receiving the following warning:
warning: Nx.tensor/2 inside defn expects the first argument to be a literal (such as a list)
defn pdf(x, mu \\ 0, sigma \\ 1) when is_list(x) do
x = Nx.tensor(x)
pi = pi()
expon = - Nx.pow((x - mu) / sigma, 2) / 2
denom = 1 / (sigma * (2 * pi) ** 0.5)
Nx.exp(expon) * denom
end
defn pdf(x, mu \\ 0, sigma \\ 1) when is_map(x) do
tensor =
x
|> Map.to_list
|> Nx.tensor()
....
For some backends, defn will recompile its code for every input shape. Calling Nx.tensor inside the defn is going to circumvent that.
What about a regular function that handles the argument-juggling?
def pdf(numbers, mu \\ 0, sigma \\ 1)
def pdf(numbers, mu, sigma) when is_list(numbers) do
pdf_nx(Nx.tensor(numbers), mu, sigma)
end
def pdf(numbers, mu, sigma), do: pdf_nx(numbers, mu, sigma)
defn pdf_nx(x, mu \\ 0, sigma \\ 1) do
pi = pi()
expon = - Nx.pow((x - mu) / sigma, 2) / 2
denom = 1 / (sigma * (2 * pi) ** 0.5)
Nx.exp(expon) * denom
end
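The same split also covers the map case from the question: do every conversion in the plain def, so defn only ever sees tensors or numbers. Below is a minimal sketch under a few assumptions: the module name Stats is hypothetical, map values are assumed to be numbers and are ordered by sorting on key, and the pi() helper is replaced by a module attribute holding √(2π) as a plain float (module attributes are expanded to literals at compile time, so defn can use them). Note the `not is_struct(map)` guard: an Nx.Tensor is itself a struct, and therefore a map, so a bare is_map/1 clause would also catch tensors.

```elixir
Mix.install([{:nx, "~> 0.6"}])

defmodule Stats do
  # Hypothetical module name, used only for illustration.
  import Nx.Defn

  # sqrt(2 * pi) computed once at compile time; defn sees it as a float literal.
  @sqrt_two_pi :math.sqrt(2 * :math.pi())

  # Plain `def` wrapper: all list/map juggling happens here, outside defn.
  def pdf(input, mu \\ 0, sigma \\ 1)

  def pdf(list, mu, sigma) when is_list(list), do: pdf_nx(Nx.tensor(list), mu, sigma)

  # Assumption: map values are numbers; sorting by key makes the output order predictable.
  # `not is_struct/1` keeps Nx.Tensor structs out of this clause.
  def pdf(map, mu, sigma) when is_map(map) and not is_struct(map) do
    map
    |> Enum.sort()
    |> Enum.map(fn {_key, value} -> value end)
    |> Nx.tensor()
    |> pdf_nx(mu, sigma)
  end

  # Tensors and plain numbers go straight to the numerical definition.
  def pdf(tensor_or_number, mu, sigma), do: pdf_nx(tensor_or_number, mu, sigma)

  # Numerical definition: receives only tensors/numbers, never lists or maps.
  defn pdf_nx(x, mu, sigma) do
    expon = -Nx.pow((x - mu) / sigma, 2) / 2
    norm = 1 / (sigma * @sqrt_two_pi)
    Nx.exp(expon) * norm
  end
end
```

Calling Stats.pdf([0.0, 1.0]) or Stats.pdf(%{a: 0.0, b: 1.0}) should both return a tensor of roughly [0.3989, 0.2420], the standard normal density at 0 and 1, while an existing tensor is passed through unchanged to pdf_nx.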
The short answer is no. There are no lists inside defn. You must build your tensor outside of the defn and then pass it as an argument to the defn, as @al2o3cr outlined above.
Piolho
January 6, 2024, 8:34pm
Thanks, it helped a lot! @josevalim @al2o3cr @BradS2S