Hi, I am trying to evaluate a trained model by calling `Axon.build` and then calling the resulting `predict_fn` function multiple times (once for each evaluation step). Each step involves generating predictions for some number of objects and, according to the documentation, looks like this:
prediction = predict_fn.(state, tensor)
After the predictions are generated, I update the metrics and generate predictions again for other objects. The problem is that after ~1,800 of ~10,000 evaluation steps the program crashes because XLA runs out of memory. Apparently my data is written to the GPU, but the memory required for computing the predictions is not released afterwards. I've already tried setting `preallocate` to `false` in the config; that doesn't help, though it is now clearer that there is a memory leak. What should I do in this case? Is there any solution?
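For reference, disabling preallocation looks roughly like the fragment below. This is a minimal sketch, not the exact project config: the client name `cuda` is an assumption, and the client may be named differently depending on the EXLA version and project setup.

```elixir
# config/config.exs
import Config

# Assumption: the GPU client is registered under the name `cuda`.
# `preallocate: false` asks EXLA not to reserve GPU memory up front,
# so memory growth across evaluation steps becomes visible.
config :exla, :clients,
  cuda: [platform: :cuda, preallocate: false]
```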
My setup:
Erlang/OTP 24 [erts-12.2.1] [source] [64-bit] [smp:12:12] [ds:12:12:10] [async-threads:1] [jit]
Elixir 1.14.1 (compiled with Erlang/OTP 24)
:axon, "~> 0.2"
:nx, "~> 0.2"
:exla, "~> 0.2"
The error message:
** (RuntimeError) Out of memory while trying to allocate 31129600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 15.61MiB
constant allocation: 0B
maybe_live_out allocation: 14.84MiB
preallocated temp allocation: 29.69MiB
preallocated temp fragmentation: 0B (0.00%)
total allocation: 60.15MiB
total fragmentation: 8B (0.00%)
Peak buffers:
Buffer 1:
Size: 29.69MiB
XLA Label: fusion
Shape: f32[38,1024,2,100]
==========================
Buffer 2:
Size: 14.84MiB
XLA Label: fusion
Shape: f32[38,1024,1,100]
==========================
Buffer 3:
Size: 14.72MiB
Entry Parameter Subshape: f32[38588,100]
==========================
Buffer 4:
Size: 608.0KiB
Entry Parameter Subshape: s64[38,1024,2]
==========================
Buffer 5:
Size: 304.0KiB
Entry Parameter Subshape: s64[38,1024,1]
==========================
Buffer 6:
Size: 4.3KiB
Entry Parameter Subshape: f32[11,100]
==========================
Buffer 7:
Size: 8B
XLA Label: tuple
Shape: (f32[38,1024,1,100])
==========================
(exla 0.4.0) lib/exla/executable.ex:56: EXLA.Executable.unwrap!/1
(exla 0.4.0) lib/exla/executable.ex:19: EXLA.Executable.run/3
(exla 0.4.0) lib/exla/defn.ex:308: EXLA.Defn.maybe_outfeed/7
(nx 0.4.0) lib/nx/defn.ex:442: Nx.Defn.do_jit_apply/3
(grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:26: Grapex.Models.Testers.EntityBased.generate_predictions_for_testing/5
(grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:84: Grapex.Models.Testers.EntityBased.test_one_triple/6
(grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:114: Grapex.Models.Testers.EntityBased.evaluate/3