Nx default backend does not use CUDA even when it's set as default client

:wave:

Hi, I’m trying to get my head around Nx and its configuration to run on CUDA.

# Configures Nx default backend and default defn options
config :nx, :default_backend, EXLA.Backend
config :nx, :default_defn_options, compiler: EXLA

I also set the XLA_TARGET environment variable to cuda118 both when building and running my app.

From the following output, everything looks fine (at least to me :sweat_smile:).

iex> System.get_env "XLA_TARGET"
"cuda118"
iex> Application.get_all_env :nx
[default_backend: EXLA.Backend, default_defn_options: [compiler: EXLA]]
iex> EXLA.Client.get_supported_platforms
%{host: 4, cuda: 1, interpreter: 1}
iex> EXLA.Client.default_name
:cuda

From the EXLA.Client outputs, it seems the default client should be :cuda, and it’s also listed in the supported platforms.

When running a serving with this config, I would expect it to pick up my default client, but that’s not the case:

{:ok, model_info} =
  Bumblebee.load_model({:hf, "patrickjohncyh/fashion-clip"},
    module: Bumblebee.Text.ClipText,
    architecture: :for_embedding
  )

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "patrickjohncyh/fashion-clip"})

serving = Bumblebee.Text.text_embedding(model_info, tokenizer, output_attribute: :embedding)
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<1.26226188/2 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  client_preprocessing: #Function<2.26226188/1 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  client_postprocessing: #Function<3.26226188/2 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  streaming: nil,
  batch_size: nil,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_keys: [:default]],
  defn_options: [compiler: EXLA]
}

The following code does not run the serving on the GPU, but on the CPU.

Nx.Serving.run(serving, "a pair of white Nike Air Max shoes in white")

It seems I have to explicitly set client: :cuda for my serving to run on the GPU, either by configuring the :defn_options for my serving or by setting it globally.

# Per serving: override the defn options on the struct
%{serving | defn_options: [compiler: EXLA, client: :cuda]}
# or globally, for every defn compilation
Nx.Defn.global_default_options(compiler: EXLA, client: :cuda)
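
If I read the Nx.Defn docs right, the global variant could also live in the app config so that it applies at boot. This is only a sketch, assuming :default_defn_options accepts the same keys as Nx.Defn.global_default_options (including :client); the file path is an assumption about the project layout.

```elixir
# config/config.exs (assumed location of the project's config)
import Config

# Default backend for tensors created outside of defn
config :nx, :default_backend, EXLA.Backend

# Default defn options with the client named explicitly
# (assumption: same keys as Nx.Defn.global_default_options/1)
config :nx, :default_defn_options, compiler: EXLA, client: :cuda
```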

From the EXLA.Backend docs I was assuming that servings would use the default XLA client.