Nx.Random.split ArithmeticError when I change def to defn

I’m probably making a rookie mistake, but I’m writing some Nx code and noticed that I hit an ArithmeticError when I change def to defn. For context, I’m generating random unit vectors by normalizing a vector whose entries are all drawn from a Gaussian distribution (see “Generating random unit vectors in n-dimensional space”).
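
For reference, the overall goal can be sketched as below. This is a minimal sketch that assumes Nx ~> 0.7. The module name UnitVectors and the :count/:dim options are hypothetical. It draws standard normal entries with Nx.Random.normal/4 and divides each row by its Euclidean norm. Both sizes determine the output shape, so they are passed as options rather than positional arguments, which is the same constraint this thread runs into.

```elixir
# Assumes Nx ~> 0.7 (the stack trace below shows 0.7.2).
Mix.install([{:nx, "~> 0.7"}])

defmodule UnitVectors do
  import Nx.Defn

  # Hypothetical helper: returns a {count, dim} tensor whose rows have unit length.
  # Gaussian vectors are rotationally symmetric, so normalizing them gives
  # directions uniformly distributed on the unit sphere.
  defn sample(key, opts \\ []) do
    # :count and :dim fix the output shape, so they must be options (plain integers).
    opts = keyword!(opts, count: 1, dim: 3)
    {x, _new_key} = Nx.Random.normal(key, 0.0, 1.0, shape: {opts[:count], opts[:dim]})
    # Euclidean norm of each row, kept as a {count, 1} column for broadcasting.
    norms = Nx.sqrt(Nx.sum(x * x, axes: [1], keep_axes: true))
    x / norms
  end
end
```

Usage would look like UnitVectors.sample(Nx.Random.key(42), count: 4, dim: 5).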

This works:

```elixir
defmodule Foo do
  import Nx.Defn

  def bar(key, d) do
    Nx.Random.split(key, parts: d)
  end
end

key = Nx.Random.key(42)
Foo.bar(key, 1)
```

But if I change the def to defn, I get:

```
** (ArithmeticError) bad argument in arithmetic expression: #Nx.Tensor<
  s64
  
  Nx.Defn.Expr
  parameter a:1   s64
> * 1
    (erts 15.0) :erlang.*(#Nx.Tensor<
  s64
  
  Nx.Defn.Expr
  parameter a:1   s64
>, 1)
    (elixir 1.17.1) lib/tuple.ex:167: Tuple.product/2
    (elixir 1.17.1) lib/tuple.ex:167: Tuple.product/2
    (nx 0.7.2) lib/nx/random.ex:189: Nx.Random."__defn:threefry2x32__"/2
    (nx 0.7.2) lib/nx/random.ex:128: Nx.Random."__defn:split__"/2
    (nx 0.7.2) lib/nx/defn/compiler.ex:218: Nx.Defn.Compiler.__remote__/4
    #cell:452eidny2v67fxva:7: Foo."__defn:bar"/2
...
```

I noticed that replacing parts: d with parts: 1 also works. This made me think maybe d has to be an Nx.Tensor, but calling Foo.bar(key, Nx.tensor(1)) still causes the same error.

Does Nx.Random.split just not work inside of defn?

I think I found the answer by looking at the Nx code.

It looks like it’s because Nx.Random.split expects the parts option to be a plain BEAM number (an integer):

… but defn converts its arguments to tensors?

When numbers are given as arguments, they are always immediately converted to tensors on invocation. If you want to keep numbers as is or if you want to pass any other value to numerical definitions, they must be given as keyword lists.

From: Nx.Defn — Nx v0.7.3

So I’m guessing I’m supposed to pass d in a keyword list, similar to how split takes parts in the keyword list. Is this right?
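
A minimal sketch of that fix, assuming Nx ~> 0.7: Foo.bar/2 from the question now takes parts through a keyword list, validated with keyword!/2. The default of 2 is an arbitrary choice for this sketch.

```elixir
# Assumes Nx ~> 0.7.
Mix.install([{:nx, "~> 0.7"}])

defmodule Foo do
  import Nx.Defn

  # `parts` arrives inside the options keyword list, so defn leaves it as a
  # plain integer instead of turning it into a tensor parameter.
  defn bar(key, opts \\ []) do
    opts = keyword!(opts, parts: 2)
    Nx.Random.split(key, parts: opts[:parts])
  end
end

key = Nx.Random.key(42)
# Returns a {3, 2} tensor of new keys.
Foo.bar(key, parts: 3)
```

Because parts decides the output shape, it has to be known when the numerical definition is built, which is exactly what passing it as an option guarantees.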

Yes, exactly. Generally speaking, any value that an Nx function or defn takes as a keyword option must also reach your own defn as an option. One way around that is to use deftransform, but ultimately you still need to pass the value down as an option.
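
Here is roughly how the deftransform route could look. This is a sketch that assumes Nx ~> 0.7. The public bar/2 keeps a positional d. Because deftransform bodies run as regular Elixir code, d stays an integer and is forwarded as the parts option to a private defn. The helper name split_keys is hypothetical.

```elixir
# Assumes Nx ~> 0.7.
Mix.install([{:nx, "~> 0.7"}])

defmodule Foo do
  import Nx.Defn

  # deftransform bodies are plain Elixir, so `d` is never converted to a tensor;
  # it is handed to the numerical definition as an option.
  deftransform bar(key, d) do
    split_keys(key, parts: d)
  end

  # Hypothetical private helper that does the actual split inside defn.
  defnp split_keys(key, opts \\ []) do
    opts = keyword!(opts, parts: 2)
    Nx.Random.split(key, parts: opts[:parts])
  end
end
```

The caller still has to supply an integer. If bar/2 is called from another defn with d taken from that defn's own arguments, d will already be a tensor there, and the same error should come back.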
