Nx.Defn.compile: error when using Nx.gather with a tensor passed as a compile-time option

According to the Nx documentation, it seems possible to store a “constant” tensor in a compiled anonymous function by passing it as an option at compile time. I realize this is often a bad idea because of the caching mechanisms, but in this case the tensor would be used as a configurable lookup table. In some cases “storing” tensors in the function works fine, but Nx.gather raises an error at compile time. Adding two tensors, one of which is passed as an option, works fine:

```elixir
defn anon_fn1(arg, opts) do
  t2 = opts[:t2]
  arg + t2
end

def create_anon_fn1() do
  t2 = Nx.tensor([1, 3, 5], type: :f64) |> Nx.backend_copy()

  Nx.Defn.compile(
    &anon_fn1/2,
    [
      Nx.template({1}, :f64),
      [t2: t2]
    ],
    compiler: EXLA
  )
end
```

```elixir
iex(12)> fn1 = create_anon_fn1()
#Function<135.35555145/1 in Nx.Defn.Compiler.fun/2>
iex(13)> fn1.(Nx.tensor([2], type: :f64))
#Nx.Tensor<
  f64[3]
  EXLA.Backend<host:0, 0.1447861285.840040468.11335>
  [3.0, 5.0, 7.0]
>
```

However, using Nx.gather raises an error at compile time:

```elixir
defn anon_fn2(arg, opts) do
  t2 = opts[:t2]
  Nx.gather(t2, arg)
end

def create_anon_fn2() do
  t2 = Nx.iota({10}, type: :f64) |> Nx.backend_copy()

  Nx.Defn.compile(
    &anon_fn2/2,
    [
      Nx.template({3, 1}, :s32),
      [t2: t2]
    ],
    compiler: EXLA
  )
end
```

```
iex(16)> fn2 = create_anon_fn2()
** (FunctionClauseError) no function clause matching in Nx.BinaryBackend.to_binary/1

    The following arguments were given to Nx.BinaryBackend.to_binary/1:

        # 1
        #Nx.Tensor<
          s32[3][1]

          Nx.Defn.Expr
          parameter a:0   s32[3][1]
        >

    Attempted function clauses (showing 1 out of 1):

        defp to_binary(%Nx.Tensor{data: %{state: data}})

    (nx 0.6.4) Nx.BinaryBackend.to_binary/1
    (nx 0.6.4) lib/nx/binary_backend.ex:2120: Nx.BinaryBackend.gather/3
    (nx 0.6.4) lib/nx.ex:14575: Nx.gather/3
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.6.4) lib/exla/defn.ex:387: anonymous fn/4 in EXLA.Defn.compile/8
    (exla 0.6.4) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 5.0.2) timer.erl:270: :timer.tc/2
    iex:16: (file)
iex(16)>
```

Using IEx to verify that Nx.gather gives the expected result with these tensors outside of defn:

```elixir
iex(18)> t2 = Nx.iota({10}, type: :f64)
#Nx.Tensor<
  f64[10]
  EXLA.Backend<host:0, 0.1447861285.840040468.11339>
  [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
>
iex(19)> arg = Nx.tensor([[3], [2], [7]], type: :s32)
#Nx.Tensor<
  s32[3][1]
  EXLA.Backend<host:0, 0.1447861285.840040472.11967>
  [
    [3],
    [2],
    [7]
  ]
>
iex(20)> Nx.gather(t2, arg)
#Nx.Tensor<
  f64[3]
  EXLA.Backend<host:0, 0.1447861285.840040468.11340>
  [3.0, 2.0, 7.0]
>
```

Any thoughts on what is going on would be appreciated!

Thanks!!!

Environment:
Nx 0.6.4, EXLA 0.6.4
Erlang/OTP 26 [erts-14.0.2] [source] [64-bit] [smp:10:10] [ds:10:10:10] [async-threads:1] [jit] [dtrace]
Elixir (1.15.7)

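The docs for Nx.Defn explain this in more detail, but you cannot pass tensors as options. Options are not abstracted away into expressions, and we need to convert all tensors into expressions in order to JIT them. In your second example, the stacktrace shows Nx.gather dispatching to Nx.BinaryBackend (the backend of the t2 option, since Nx.backend_copy/1 copies to Nx.BinaryBackend by default), which then receives the symbolic index expression and cannot read it as binary data.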
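A minimal sketch of the workaround, assuming the same Nx 0.6.4 / EXLA 0.6.4 setup as in the question: pass the lookup table as a regular tensor argument, so defn turns it into a parameter expression like the indices, and close over it in a plain Elixir function. The module `LookupExample` and the functions `lookup/2` and `create_lookup_fn/1` are hypothetical names, and the index shape `{3, 1}` is fixed at compile time exactly as in the original example.

```elixir
# Assumes {:nx, "~> 0.6"} and {:exla, "~> 0.6"} are available.
defmodule LookupExample do
  import Nx.Defn

  # Both the table and the indices are regular arguments, so defn
  # converts them into parameter expressions before JIT compilation.
  defn lookup(table, idx) do
    Nx.gather(table, idx)
  end

  # Compiles lookup/2 once for the table's shape and type plus a {3, 1}
  # s32 index tensor, then closes over the table so callers pass only indices.
  def create_lookup_fn(table) do
    compiled =
      Nx.Defn.compile(
        &lookup/2,
        [Nx.template(Nx.shape(table), Nx.type(table)), Nx.template({3, 1}, :s32)],
        compiler: EXLA
      )

    fn idx -> compiled.(table, idx) end
  end
end

fun = LookupExample.create_lookup_fn(Nx.iota({10}, type: :f64))

fun.(Nx.tensor([[3], [2], [7]], type: :s32)) |> Nx.to_flat_list()
# => [3.0, 2.0, 7.0]
```

The closure keeps the "configurable lookup table" behaviour from the question while the table itself stays a real input to the compiled program rather than an option. Since the function is compiled against templates, looking up a different number of indices would need a different index template.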
