Nx: different behavior between backends

I wrote a small project to get familiar with tensor operations by calculating fractal noise at each coordinate of a grid. However, I noticed it was too slow; for example, computing the noise for a 100x100 grid took over a minute. After investigating, I realized I was using Nx.BinaryBackend, which is pure Elixir.

The problem is that when I switch to the EXLA backend, the same code stops working. Specifically, Nx.gather no longer functions as expected. My question is: should Nx behave the same regardless of the backend?

P.S.: I managed to install ROCm on openSUSE, but I couldn’t compile XLA for Nx, so I went back to using the precompiled binaries for EXLA.

Different backends may return different values, but in general all of them should work. Can you provide more information? What exactly is different?


Hey José, thank you very much for your time!

Forgive me, but I don’t think I expressed myself clearly in my initial post. After setting the backend to EXLA, using Nx.gather in my code results in an error.
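
For reference, this is roughly how I’m setting the default backend (writing it from memory, so the exact lines may differ slightly):

# In config/config.exs:
import Config

config :nx, default_backend: EXLA.Backend

# or at runtime, e.g. in IEx:
Nx.global_default_backend(EXLA.Backend)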

Basically, I have a tensor of gradients:

#Nx.Tensor<
  f32[4][2]
  [
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ]
>

And a tensor (i) with “random” indices for each octave:

#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  u16
  [
    [0, 2, 3, 2, 2, 2, 2, 1]
  ]
>

What I need is to generate a tensor replacing each index with the corresponding gradient:

Nx.gather(gradients, Nx.reshape(i, {1}))

With Nx.BinaryBackend, I get what I need:

#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  f32[2]
  [
    [
      [1.0, 1.0],
      [1.0, -1.0],
      [-1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [1.0, -1.0],
      [-1.0, 1.0]
    ]
  ]
>
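
As a sanity check (this is not part of my actual code, just an illustration of what I expect gather to do), a plain, non-vectorized lookup into the same gradients tensor also behaves as expected on BinaryBackend:

Nx.gather(gradients, Nx.tensor([[0], [2], [3]]))
# => f32[3][2]: rows 0, 2 and 3 of gradients, i.e.
#    [[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]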

But when I set the default_backend to EXLA.Backend, the same code no longer works:

#Nx.Tensor<
  f32[4][2]
  EXLA.Backend<host:0, 0.3185043492.1851392032.250556>
  [
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ]
>

#Nx.Tensor<
  vectorized[x: 1][octaves: 8]
  u16
  EXLA.Backend<host:0, 0.3185043492.1851392024.250510>
  [
    [0, 2, 3, 2, 2, 2, 2, 1]
  ]
>

Then executing:

Nx.gather(gradients, Nx.reshape(i, {1}))

results in the following error:

"builtin.module"() ({
    "func.func"() <{function_type = (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>, sym_name = "main", sym_visibility = "public"}> ({
        ^bb0(%arg0: tensor<1x8x4x2xf32>, %arg1: tensor<1x8x3xi32>):
        %0 = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [0, 1, 2], start_index_map = [0, 1, 2], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 1, 2]> : tensor<4xi64>}> : (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>
        "func.return"(%0) : (tensor<1x8x2xf32>) -> ()
    }) : () -> ()
}) : () -> ()

** (RuntimeError) <unknown>:0: error: 'mhlo.gather' op inferred type(s) 'tensor<2x1x8xf32>' are incompatible with return type(s) of operation 'tensor<1x8x2xf32>'
<unknown>:0: error: 'mhlo.gather' op failed to infer returned types
<unknown>:0: note: see current operation: %0 = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather<offset_dims = [0], collapsed_slice_dims = [0, 1, 2], start_index_map = [0, 1, 2], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<[1, 1, 1, 2]> : tensor<4xi64>}> : (tensor<1x8x4x2xf32>, tensor<1x8x3xi32>) -> tensor<1x8x2xf32>

    (exla 0.9.2) lib/exla/mlir/module.ex:147: EXLA.MLIR.Module.unwrap!/1
    (exla 0.9.2) lib/exla/mlir/module.ex:124: EXLA.MLIR.Module.compile/5
    (stdlib 6.2) timer.erl:595: :timer.tc/2
    (exla 0.9.2) lib/exla/defn.ex:432: anonymous fn/14 in EXLA.Defn.compile/8
    (exla 0.9.2) lib/exla/mlir/context_pool.ex:10: anonymous fn/3 in EXLA.MLIR.ContextPool.checkout/1
    (nimble_pool 1.1.0) lib/nimble_pool.ex:462: NimblePool.checkout!/4
    (exla 0.9.2) lib/exla/defn/locked_cache.ex:36: EXLA.Defn.LockedCache.run/2
    (stdlib 6.2) timer.erl:595: :timer.tc/2
    (exla 0.9.2) lib/exla/defn.ex:383: anonymous fn/15 in EXLA.Defn.compile/8
    (exla 0.9.2) lib/exla/defn.ex:229: EXLA.Defn.__compile__/4
    (exla 0.9.2) lib/exla/defn.ex:219: EXLA.Defn.__jit__/5
    (nx 0.9.2) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
    (nx 0.9.2) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
    (nx 0.9.2) lib/nx/defn/evaluator.ex:256: Nx.Defn.Evaluator.eval/3
    (nx 0.9.2) lib/nx/defn/evaluator.ex:70: anonymous fn/5 in Nx.Defn.Evaluator.__compile__/4
    (nx 0.9.2) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
    iex:5: (file)

This is 100% a bug. Either the operation is invalid and we should generate a better exception message, or we should make it work. Please open up a bug report!

Sure thing! I’ve just opened the issue here:
https://github.com/elixir-nx/nx/issues/1593
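
For anyone finding this thread later, a minimal way to reproduce it should look something like this (the tensors are reconstructed from the inspect output above, so the construction may differ slightly from my real code):

Nx.global_default_backend(EXLA.Backend)

gradients =
  Nx.tensor([
    [1.0, 1.0],
    [-1.0, 1.0],
    [1.0, -1.0],
    [-1.0, -1.0]
  ])

# "Random" gradient indices, one per octave, as a vectorized tensor
i =
  Nx.tensor([[0, 2, 3, 2, 2, 2, 2, 1]], type: :u16)
  |> Nx.vectorize(:x)
  |> Nx.vectorize(:octaves)

# Works on Nx.BinaryBackend, raises the mhlo.gather error on EXLA.Backend
Nx.gather(gradients, Nx.reshape(i, {1}))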