According to the Nx documentation, it seems possible to store a “constant” tensor in a compiled anonymous function by passing it as an option at compile time. I realize that in many cases this is a bad idea due to the caching mechanisms, but in this case the tensor would be used as a configurable lookup table. In some cases, “storing” tensors in the function works fine, but Nx.gather raises an error at compile time. Adding two tensors, one of which is stored as an option, works fine:
defn anon_fn1(arg, opts) do
  t2 = opts[:t2]
  arg + t2
end

def create_anon_fn1() do
  t2 = Nx.tensor([1, 3, 5], type: :f64) |> Nx.backend_copy()

  Nx.Defn.compile(
    &anon_fn1/2,
    [
      Nx.template({1}, :f64),
      [t2: t2]
    ],
    compiler: EXLA
  )
end
iex(12)> fn1 = create_anon_fn1()
#Function<135.35555145/1 in Nx.Defn.Compiler.fun/2>
iex(13)> fn1.(Nx.tensor([2], type: :f64))
#Nx.Tensor<
  f64[3]
  EXLA.Backend<host:0, 0.1447861285.840040468.11335>
  [3.0, 5.0, 7.0]
>
However, using Nx.gather with the same pattern fails when the function is compiled:
defn anon_fn2(arg, opts) do
  t2 = opts[:t2]
  Nx.gather(t2, arg)
end

def create_anon_fn2() do
  t2 = Nx.iota({10}, type: :f64) |> Nx.backend_copy()

  Nx.Defn.compile(
    &anon_fn2/2,
    [
      Nx.template({3, 1}, :s32),
      [t2: t2]
    ],
    compiler: EXLA
  )
end
iex(16)> fn2 = create_anon_fn2()
** (FunctionClauseError) no function clause matching in Nx.BinaryBackend.to_binary/1
    The following arguments were given to Nx.BinaryBackend.to_binary/1:
        # 1
        #Nx.Tensor<
          s32[3][1]
          Nx.Defn.Expr
          parameter a:0 s32[3][1]
        >
    Attempted function clauses (showing 1 out of 1):
        defp to_binary(%Nx.Tensor{data: %{state: data}})
    (nx 0.6.4) Nx.BinaryBackend.to_binary/1
    (nx 0.6.4) lib/nx/binary_backend.ex:2120: Nx.BinaryBackend.gather/3
    (nx 0.6.4) lib/nx.ex:14575: Nx.gather/3
    (nx 0.6.4) lib/nx/defn/compiler.ex:158: Nx.Defn.Compiler.runtime_fun/3
    (exla 0.6.4) lib/exla/defn.ex:387: anonymous fn/4 in EXLA.Defn.compile/8
    (exla 0.6.4) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 5.0.2) timer.erl:270: :timer.tc/2
    iex:16: (file)
iex(16)>
Calling Nx.gather directly in IEx with tensors of the same shapes and types works as expected:
iex(18)> t2 = Nx.iota({10}, type: :f64)
#Nx.Tensor<
  f64[10]
  EXLA.Backend<host:0, 0.1447861285.840040468.11339>
  [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
>
iex(19)> arg = Nx.tensor([[3], [2], [7]], type: :s32)
#Nx.Tensor<
  s32[3][1]
  EXLA.Backend<host:0, 0.1447861285.840040472.11967>
  [
    [3],
    [2],
    [7]
  ]
>
iex(20)> Nx.gather(t2, arg)
#Nx.Tensor<
  f64[3]
  EXLA.Backend<host:0, 0.1447861285.840040468.11340>
  [3.0, 2.0, 7.0]
>
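The stack trace shows Nx.gather/3 calling Nx.BinaryBackend.gather/3 while the indices are still an Nx.Defn.Expr parameter. That suggests, though it is not confirmed here, that the concrete table passed as an option is dispatched to its own backend instead of being traced. A possible workaround, as a minimal sketch: pass the lookup table as a regular tensor argument and close over it in an ordinary anonymous function. The module name `LookupSketch` and the function names `gather_from` and `create_lookup_fn` are hypothetical and exist only for this illustration.

```elixir
defmodule LookupSketch do
  import Nx.Defn

  # Both the table and the indices are regular defn arguments here,
  # so both are traced as expressions when the function is compiled.
  defn gather_from(table, idx) do
    Nx.gather(table, idx)
  end

  def create_lookup_fn() do
    # The table must match the template below in shape and type.
    table = Nx.iota({10}, type: :f64)

    compiled =
      Nx.Defn.compile(
        &gather_from/2,
        [Nx.template({10}, :f64), Nx.template({3, 1}, :s32)],
        compiler: EXLA
      )

    # Close over the table so callers only pass the indices.
    fn idx -> compiled.(table, idx) end
  end
end
```

The returned function can then be called with an s32 tensor of shape {3, 1}, for example `LookupSketch.create_lookup_fn().(Nx.tensor([[3], [2], [7]], type: :s32))`. The trade-off is that the table is passed to the compiled function on every call instead of being embedded as a compile-time constant.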
Any thoughts on what is going on would be appreciated!
Thanks!!!
Environment:
Nx 0.6.4, EXLA 0.6.4
Erlang/OTP 26 [erts-14.0.2] [source] [64-bit] [smp:10:10] [ds:10:10:10] [async-threads:1] [jit] [dtrace]
Elixir (1.15.7)