XLA runs out of memory during inference

Hi, I am trying to evaluate a trained model by calling Axon.build and then calling the resulting predict_fn function multiple times (once per evaluation step). Each step generates predictions for a number of objects and, following the documentation, looks like this:

prediction = predict_fn.(state, tensor)

After the predictions are generated, I update the metrics and generate predictions for the next set of objects. The problem is that after ~1,800 of ~10,000 evaluation steps the program crashes because XLA runs out of memory. Apparently my data is written to the GPU, but the memory used for the prediction computations is never released afterwards. I've already tried setting preallocate to false in the config. That doesn't help, though it makes it clearer that there is a memory leak. What should I do in this case? Is there any solution?
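For readers unfamiliar with the preallocate option: a minimal sketch of what disabling it typically looks like in config/config.exs. It assumes EXLA selects the `:cuda` client on this machine; the client name is an assumption about this setup, not something stated in the post.

```elixir
# config/config.exs (sketch; assumes EXLA uses the :cuda client here)
import Config

# Ask XLA to allocate GPU memory on demand instead of
# reserving most of the device memory up front
config :exla, :clients,
  cuda: [platform: :cuda, preallocate: false]
```

With preallocation off, device memory usage grows as buffers are actually allocated, which is why a leak becomes easier to observe (e.g. with nvidia-smi) even though it does not prevent the crash.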

My setup:

Erlang/OTP 24 [erts-12.2.1] [source] [64-bit] [smp:12:12] [ds:12:12:10] [async-threads:1] [jit]
Elixir 1.14.1 (compiled with Erlang/OTP 24)

:axon, "~> 0.2"
:nx, "~> 0.2"
:exla, "~> 0.2"

The error message:

** (RuntimeError) Out of memory while trying to allocate 31129600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   15.61MiB
              constant allocation:         0B
        maybe_live_out allocation:   14.84MiB
     preallocated temp allocation:   29.69MiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:   60.15MiB
              total fragmentation:         8B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 29.69MiB
		XLA Label: fusion
		Shape: f32[38,1024,2,100]
		==========================

	Buffer 2:
		Size: 14.84MiB
		XLA Label: fusion
		Shape: f32[38,1024,1,100]
		==========================

	Buffer 3:
		Size: 14.72MiB
		Entry Parameter Subshape: f32[38588,100]
		==========================

	Buffer 4:
		Size: 608.0KiB
		Entry Parameter Subshape: s64[38,1024,2]
		==========================

	Buffer 5:
		Size: 304.0KiB
		Entry Parameter Subshape: s64[38,1024,1]
		==========================

	Buffer 6:
		Size: 4.3KiB
		Entry Parameter Subshape: f32[11,100]
		==========================

	Buffer 7:
		Size: 8B
		XLA Label: tuple
		Shape: (f32[38,1024,1,100])
		==========================


    (exla 0.4.0) lib/exla/executable.ex:56: EXLA.Executable.unwrap!/1
    (exla 0.4.0) lib/exla/executable.ex:19: EXLA.Executable.run/3
    (exla 0.4.0) lib/exla/defn.ex:308: EXLA.Defn.maybe_outfeed/7
    (nx 0.4.0) lib/nx/defn.ex:442: Nx.Defn.do_jit_apply/3
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:26: Grapex.Models.Testers.EntityBased.generate_predictions_for_testing/5
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:84: Grapex.Models.Testers.EntityBased.test_one_triple/6
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:114: Grapex.Models.Testers.EntityBased.evaluate/3
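For reference, the failed allocation of 31129600 bytes is exactly the size of Peak buffer 1. Checking the largest buffers against their shapes (4 bytes per f32 element):

```latex
\begin{aligned}
\text{Buffer 1: } & 38 \times 1024 \times 2 \times 100 \times 4\ \text{B} = 31{,}129{,}600\ \text{B} \approx 29.69\ \text{MiB} \\
\text{Buffer 2: } & 38 \times 1024 \times 1 \times 100 \times 4\ \text{B} = 15{,}564{,}800\ \text{B} \approx 14.84\ \text{MiB} \\
\text{Buffer 3: } & 38588 \times 100 \times 4\ \text{B} = 15{,}435{,}200\ \text{B} \approx 14.72\ \text{MiB}
\end{aligned}
```

So a single step needs only about 60 MiB in total ("total allocation" above). A request of that size failing after ~1,800 successful steps suggests memory is accumulating across steps rather than one step being too large.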

Can you try running EXLA from main? An EXLA OOM issue was recently fixed: Refactor EXLA execution and fix a memory leak (#993) · elixir-nx/nx@45fa92b · GitHub
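One way to try EXLA from main is to point both nx and exla at the elixir-nx/nx monorepo in mix.exs. This is a sketch using Mix's standard Git dependency options (`github:`, `branch:`, `sparse:`, `override:`); the assumption is that exla on main expects nx from the same commit, hence overriding both.

```elixir
# mix.exs, deps/0 only (sketch)
defp deps do
  [
    # ... other deps such as :axon stay as they are ...

    # Both packages live in the elixir-nx/nx repository;
    # `sparse:` checks out only the named subdirectory
    {:nx, github: "elixir-nx/nx", branch: "main", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", branch: "main", sparse: "exla", override: true}
  ]
end
```

After editing the deps, run `mix deps.get` and recompile.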


Unfortunately, the version of EXLA from main also crashes, with a segmentation fault at the very beginning, so the model fails to even train:

18:27:42.611 [info] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero

18:27:42.617 [info] XLA service 0x7f82a81674a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:

18:27:42.617 [info]   StreamExecutor device (0): NVIDIA GeForce GTX 1650, Compute Capability 7.5

18:27:42.617 [info] Using BFC allocator.

18:27:42.617 [info] XLA backend will use up to 399468134 bytes on device 0 for BFCAllocator.
Segmentation fault (core dumped)

Ubuntu prints the following stacktrace in the crash report:

 #0  0x00007fa045ae0111 in xla::PjRtStreamExecutorDevice::GetLocalDeviceState() const () from /home/zeio/grapex/_build/dev/lib/exla/priv/lib/libxla_extension.so
 No symbol table info available.
 #1  0x00007fa045afdbf7 in xla::PjRtStreamExecutorClient::BufferFromHostBuffer(void const*, xla::PrimitiveType, absl::lts_20220623::Span<long const>, std::optional<absl::lts_20220623::Span<long const> >, xla::PjRtClient::HostBufferSemantics, std::function<void ()>, xla::PjRtDevice*) () from /home/zeio/grapex/_build/dev/lib/exla/priv/lib/libxla_extension.so
 No symbol table info available.
 #2  0x00007fa0ec06dc73 in exla::PjRtBufferFromBinary(xla::PjRtClient*, enif_environment_t*, unsigned long, xla::Shape const&, int) () from /home/zeio/grapex/_build/dev/lib/exla/priv/libexla.so
 No symbol table info available.
 #3  0x00007fa0ec06f28e in exla::UnpackReplicaArguments(enif_environment_t*, unsigned long, exla::ExlaClient*, int) () from /home/zeio/grapex/_build/dev/lib/exla/priv/libexla.so
 No symbol table info available.
 #4  0x00007fa0ec06fa25 in exla::UnpackRunArguments(enif_environment_t*, unsigned long, exla::ExlaClient*, xla::DeviceAssignment, int) () from /home/zeio/grapex/_build/dev/lib/exla/priv/libexla.so
 No symbol table info available.
 #5  0x00007fa0ec0709ff in exla::ExlaExecutable::Run(enif_environment_t*, unsigned long, int) () from /home/zeio/grapex/_build/dev/lib/exla/priv/libexla.so
 No symbol table info available.
 #6  0x00007fa0ec05c8ad in run(enif_environment_t*, int, unsigned long const*) () from /home/zeio/grapex/_build/dev/lib/exla/priv/libexla.so
 No symbol table info available.
 #7  0x000055dd9645e15d in erts_call_dirty_nif ()
 No symbol table info available.
 #8  0x000055dd9631da16 in erts_dirty_process_main ()
 No symbol table info available.
 #9  0x000055dd962a1944 in ?? ()
 No symbol table info available.
 #10 0x000055dd96513840 in ?? ()
 No symbol table info available.
 #11 0x00007fa138217b43 in start_thread (arg=<optimized out>) at ./nptl/pthread_create.c:442
         ret = <optimized out>
         pd = <optimized out>
         out = <optimized out>
         unwind_buf = {cancel_jmp_buf = {{jmp_buf = {140734819911200, -6401769154171084895, 140329176364608, 0, 140330408179792, 140734819911552, 6370770162799632289, 6370949280816324513}, mask_was_saved = 0}}, priv = {pad = {0x0, 0x0, 0x0, 0x0}, data = {prev = 0x0, cleanup = 0x0, canceltype = 0}}}
         not_first_call = <optimized out>
 #12 0x00007fa1382a9a00 in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:81
 No locals.

Created an issue on GitHub.
