Hey José, thank you very much for your time!
Forgive me, but I’m pretty sure I didn’t express myself correctly in my initial post. After setting the backend to EXLA, calling Nx.gather in my code results in an error.
Basically, I have a tensor of gradients:
#Nx.Tensor<
  f32[4][2]
  [
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ]
>
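For reference, that tensor can be rebuilt literally for a minimal repro (the values are copied straight from the inspect above):

gradients =
  # f32[4][2], same values as the inspect above
  Nx.tensor([
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ])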
And a tensor (i) with “random” indices for each octave:
#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  u16
  [
    [0, 2, 3, 2, 2, 2, 2, 1]
  ]
>
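The indices are “random” in my actual code, but for a minimal repro I believe the same vectorized tensor can be rebuilt by hand, vectorizing the two leading axes one at a time:

i =
  Nx.tensor([[0, 2, 3, 2, 2, 2, 2, 1]], type: :u16)
  # vectorize the leading axes, giving vectorized[x: 1][octaves: 8] u16
  |> Nx.vectorize(:x)
  |> Nx.vectorize(:octaves)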
What I need is to generate a tensor where each index is replaced by the corresponding row of gradients:
Nx.gather(gradients, Nx.reshape(i, {1}))
With Nx.BinaryBackend, I get what I need:
#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  f32[2]
  [
    [
      [1.0, 1.0],
      [1.0, -1.0],
      [-1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [-1.0, 1.0]
    ]
  ]
>
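So the working path, assuming Nx.default_backend/1 is how the backend gets picked, boils down to:

Nx.default_backend(Nx.BinaryBackend)
# with gradients and i built as above, they live on Nx.BinaryBackend
Nx.gather(gradients, Nx.reshape(i, {1}))
# => the vectorized[x: 1][octaves: 8] f32[2] result shown above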
But when I set the default_backend to EXLA.Backend, the same code won’t work. These are the same two tensors, now on the EXLA backend:
#Nx.Tensor<
  f32[4][2]
  EXLA.Backend<host:0, 0.3185043492.1851392032.250556>
  [
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ]
>
#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  u16
  EXLA.Backend<host:0, 0.3185043492.1851392024.250510>
  [
    [0, 2, 3, 2, 2, 2, 2, 1]
  ]
>
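As far as I can tell, the only difference from the working case is that the default backend is switched before the tensors are created, roughly:

Nx.default_backend(EXLA.Backend)
# gradients and i rebuilt as above; they now land on EXLA.Backend,
# matching the inspects shown here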
Then executing:
Nx.gather(gradients, Nx.reshape(i, {1}))
results in the following error:
"builtin.module"() ({
"func.func"() <{function_type = (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>, sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<1x8x4x2xf32>, %arg1: tensor<1x8x3xi32>):
%0 = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [0, 1, 2], start_index_map = [0, 1, 2], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 1, 2]> : tensor<4xi64>}> : (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>
"func.return"(%0) : (tensor<1x8x2xf32>) -> ()
}) : () -> ()
}) : () -> ()
** (RuntimeError) <unknown>:0: error: 'mhlo.gather' op inferred type(s) 'tensor<2x1x8xf32>' are incompatible with return type(s) of operation 'tensor<1x8x2xf32>'
<unknown>:0: error: 'mhlo.gather' op failed to infer returned types
<unknown>:0: note: see current operation: %0 = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [0, 1, 2], start_index_map = [0, 1, 2], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 1, 2]> : tensor<4xi64>}> : (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>
(exla 0.9.2) lib/exla/mlir/module.ex:147: EXLA.MLIR.Module.unwrap!/1
(exla 0.9.2) lib/exla/mlir/module.ex:124: EXLA.MLIR.Module.compile/5
(stdlib 6.2) timer.erl:595: :timer.tc/2
(exla 0.9.2) lib/exla/defn.ex:432: anonymous fn/14 in EXLA.Defn.compile/8
(exla 0.9.2) lib/exla/mlir/context_pool.ex:10: anonymous fn/3 in EXLA.MLIR.ContextPool.checkout/1
(nimble_pool 1.1.0) lib/nimble_pool.ex:462: NimblePool.checkout!/4
(exla 0.9.2) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
(stdlib 6.2) timer.erl:595: :timer.tc/2
(exla 0.9.2) lib/exla/defn.ex:383: anonymous fn/15 in EXLA.Defn.compile/8
(exla 0.9.2) lib/exla/defn.ex:229: EXLA.Defn.__compile__/4
(exla 0.9.2) lib/exla/defn.ex:219: EXLA.Defn.__jit__/5
(nx 0.9.2) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
(nx 0.9.2) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
(nx 0.9.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
(nx 0.9.2) lib/nx/defn/evaluator.ex:70: anonymous fn/5 in Nx.Defn.Evaluator.__compile__/4
(nx 0.9.2) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
iex:5: (file)