Hi, I’m trying to get my head around Nx and its configuration to run on CUDA.
# Configures Nx default backend and default defn options
config :nx, :default_backend, EXLA.Backend
config :nx, :default_defn_options, compiler: EXLA
I also set the XLA_TARGET environment variable to cuda118, both when building and when running my app.
From the following output, everything looks fine (at least to me):
iex> System.get_env "XLA_TARGET"
"cuda118"
iex> Application.get_all_env :nx
[default_backend: EXLA.Backend, default_defn_options: [compiler: EXLA]]
iex> EXLA.Client.get_supported_platforms
%{host: 4, cuda: 1, interpreter: 1}
iex> EXLA.Client.default_name
:cuda
From the EXLA.Client output, it seems the default client should be :cuda; it’s also listed in the supported platforms.
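To double-check the cuda client itself, I also inspected it directly. This is just a quick sketch; I’m assuming EXLA.Client.fetch!/1 returns an %EXLA.Client{} struct with :platform and :device_count fields.
# fetch the cuda client and look at its platform and device count
# (assumed struct fields, based on what iex prints for %EXLA.Client{})
client = EXLA.Client.fetch!(:cuda)
IO.inspect({client.platform, client.device_count})
# e.g. {:cuda, 1}, matching the cuda: 1 entry in the supported platforms above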
When running a serving with this config, I would assume that it picks up my default client. But that’s not the case:
{:ok, model_info} =
  Bumblebee.load_model({:hf, "patrickjohncyh/fashion-clip"},
    module: Bumblebee.Text.ClipText,
    architecture: :for_embedding
  )
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "patrickjohncyh/fashion-clip"})
serving = Bumblebee.Text.text_embedding(model_info, tokenizer, output_attribute: :embedding)
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<1.26226188/2 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  client_preprocessing: #Function<2.26226188/1 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  client_postprocessing: #Function<3.26226188/2 in Bumblebee.Text.TextEmbedding.text_embedding/3>,
  streaming: nil,
  batch_size: nil,
  distributed_postprocessing: &Function.identity/1,
  process_options: [batch_keys: [:default]],
  defn_options: [compiler: EXLA]
}
The following code does run the serving, but on the CPU instead of the GPU.
Nx.Serving.run serving, "a pair of white Nike Air Max shoes in white"
It seems I have to explicitly set client: :cuda for my serving to run on the GPU, either by configuring the :defn_options for my serving or by setting it globally.
%{serving | defn_options: [compiler: EXLA, client: :cuda]}
# or
Nx.Defn.global_default_options [compiler: EXLA, client: :cuda]
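For the global route, I think the config equivalent would be the following (my assumption, mirroring the Nx.Defn.global_default_options call above):
# assumed config equivalent of the global call above, e.g. in config/config.exs
config :nx, :default_defn_options, compiler: EXLA, client: :cuda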
From the EXLA.Backend docs, I assumed that servings would use the default XLA client.
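For what it’s worth, here is the sanity check I’ve been using to see which client a fresh tensor lands on (a sketch; I’m assuming EXLA.Backend accepts a :client option, as its docs suggest):
# point the default backend at the cuda client explicitly (assumed :client option)
Nx.default_backend({EXLA.Backend, client: :cuda})
t = Nx.tensor([1.0, 2.0, 3.0])
IO.inspect(t)
# the inspect header should read EXLA.Backend<cuda:0, ...> when the GPU client is used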