I’m probably making a rookie mistake, but I’m writing some Nx code and noticed that I hit an `ArithmeticError` when I change `def` to `defn`. The context is that I’m generating random unit vectors by norming a vector whose entries are all picked from a Gaussian distribution (see: Generating random unit vectors in n-dimensional space).
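For reference, this is roughly where I’m headed. It’s only a sketch of my intent; the `random_unit_vector` name and the default `{3}` shape are placeholders of mine:

```elixir
defmodule UnitVec do
  import Nx.Defn

  # Sketch of the end goal: draw a Gaussian vector and divide it by its
  # Euclidean norm so the result lands on the unit sphere.
  defn random_unit_vector(key, opts \\ []) do
    opts = keyword!(opts, shape: {3})
    {v, new_key} = Nx.Random.normal(key, shape: opts[:shape])
    {v / Nx.sqrt(Nx.sum(v * v)), new_key}
  end
end
```

The part that trips me up comes earlier, when splitting the key.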
This works:
```elixir
import Nx.Defn

defmodule Foo do
  def bar(key, d) do
    Nx.Random.split(key, parts: d)
  end
end

key = Nx.Random.key(42)
Foo.bar(key, 1)
```
But if I change the `def` to `defn`, I get:

```
** (ArithmeticError) bad argument in arithmetic expression: #Nx.Tensor<
s64
Nx.Defn.Expr
parameter a:1 s64
> * 1
(erts 15.0) :erlang.*(#Nx.Tensor<
s64
Nx.Defn.Expr
parameter a:1 s64
>, 1)
(elixir 1.17.1) lib/tuple.ex:167: Tuple.product/2
(elixir 1.17.1) lib/tuple.ex:167: Tuple.product/2
(nx 0.7.2) lib/nx/random.ex:189: Nx.Random."__defn:threefry2x32__"/2
(nx 0.7.2) lib/nx/random.ex:128: Nx.Random."__defn:split__"/2
(nx 0.7.2) lib/nx/defn/compiler.ex:218: Nx.Defn.Compiler.__remote__/4
#cell:452eidny2v67fxva:7: Foo."__defn:bar"/2
...
```
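For completeness, here is the exact `defn` variant that produces the error; it’s just the working snippet above with `def` changed to `defn`:

```elixir
import Nx.Defn

defmodule Foo do
  # Identical to the working version, except def -> defn.
  defn bar(key, d) do
    Nx.Random.split(key, parts: d)
  end
end

key = Nx.Random.key(42)
Foo.bar(key, 1)  # raises the ArithmeticError above
```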
I noticed that replacing `parts: d` with `parts: 1` also works. This made me think maybe `d` has to be an `Nx.Tensor`, but calling `Foo.bar(key, Nx.tensor(1))` still causes the same error.
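Concretely, this variant compiles and runs under `defn` (I dropped the now-unused argument for the sketch, and `FooLiteral` is just a throwaway name), while passing the count in as a tensor does not help:

```elixir
import Nx.Defn

defmodule FooLiteral do
  # Hard-coding the number of parts works fine inside defn.
  defn bar(key) do
    Nx.Random.split(key, parts: 1)
  end
end

key = Nx.Random.key(42)
FooLiteral.bar(key)          # works
Foo.bar(key, Nx.tensor(1))   # still raises the same ArithmeticError
```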
Does `Nx.Random.split` just not work inside of `defn`?