Hi, I am trying to evaluate a trained model by calling Axon.build and then calling the resulting predict_fn function multiple times (once for each evaluation step). Each step involves generating predictions for some number of objects and, according to the documentation, looks like this:
prediction = predict_fn.(state, tensor)
After the predictions are generated, I update the metrics and generate predictions again for other objects. The problem is that after about 1,800 of roughly 10,000 evaluation steps the program crashes because XLA runs out of memory. Apparently my data is written to the GPU, but the memory used for computing the predictions is not released afterwards. I've already tried to set
false in the config, but that doesn't help, although it is clearer now that there is a memory leak. What should I do in this case? Is there any solution?
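To make the shape of the loop concrete, here is a minimal sketch of what each evaluation step does. Everything in it is a stand-in: predict_fn is a plain anonymous function rather than the one returned by Axon.build, and the state, batches and metric are made up for illustration, so it shows only the control flow and not the real model or the memory behaviour.

```elixir
# Stand-in for the predict_fn returned by Axon.build (hypothetical: a plain
# function that scores each object; no Axon or EXLA is involved here).
predict_fn = fn state, batch -> Enum.map(batch, fn x -> x * state.weight end) end

# Hypothetical model state and evaluation data (one batch per evaluation step).
state = %{weight: 2}
batches = [[1, 2, 3], [4, 5], [6]]

# One predict_fn call per step; only the running metric is carried forward,
# the prediction itself is dropped once the metric has been updated.
metrics =
  Enum.reduce(batches, %{steps: 0, sum: 0}, fn batch, acc ->
    prediction = predict_fn.(state, batch)
    %{acc | steps: acc.steps + 1, sum: acc.sum + Enum.sum(prediction)}
  end)

IO.inspect(metrics)
```

In the real code the prediction is a tensor on the GPU, and nothing from one step is intentionally kept for the next except the metrics.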
Erlang/OTP 24 [erts-12.2.1] [source] [64-bit] [smp:12:12] [ds:12:12:10] [async-threads:1] [jit]
Elixir 1.14.1 (compiled with Erlang/OTP 24)
:axon, "~> 0.2"
:nx, "~> 0.2"
:exla, "~> 0.2"
The error message:
** (RuntimeError) Out of memory while trying to allocate 31129600 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation: 15.61MiB
    constant allocation: 0B
    maybe_live_out allocation: 14.84MiB
    preallocated temp allocation: 29.69MiB
    preallocated temp fragmentation: 0B (0.00%)
    total allocation: 60.15MiB
    total fragmentation: 8B (0.00%)
Peak buffers:
    Buffer 1:
        Size: 29.69MiB
        XLA Label: fusion
        Shape: f32[38,1024,2,100]
        ==========================
    Buffer 2:
        Size: 14.84MiB
        XLA Label: fusion
        Shape: f32[38,1024,1,100]
        ==========================
    Buffer 3:
        Size: 14.72MiB
        Entry Parameter Subshape: f32[38588,100]
        ==========================
    Buffer 4:
        Size: 608.0KiB
        Entry Parameter Subshape: s64[38,1024,2]
        ==========================
    Buffer 5:
        Size: 304.0KiB
        Entry Parameter Subshape: s64[38,1024,1]
        ==========================
    Buffer 6:
        Size: 4.3KiB
        Entry Parameter Subshape: f32[11,100]
        ==========================
    Buffer 7:
        Size: 8B
        XLA Label: tuple
        Shape: (f32[38,1024,1,100])
        ==========================
    (exla 0.4.0) lib/exla/executable.ex:56: EXLA.Executable.unwrap!/1
    (exla 0.4.0) lib/exla/executable.ex:19: EXLA.Executable.run/3
    (exla 0.4.0) lib/exla/defn.ex:308: EXLA.Defn.maybe_outfeed/7
    (nx 0.4.0) lib/nx/defn.ex:442: Nx.Defn.do_jit_apply/3
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:26: Grapex.Models.Testers.EntityBased.generate_predictions_for_testing/5
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:84: Grapex.Models.Testers.EntityBased.test_one_triple/6
    (grapex 0.1.0) lib/grapex/models/testers/entity_based.ex:114: Grapex.Models.Testers.EntityBased.evaluate/3
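As a sanity check on the report, the failed allocation of 31129600 bytes is exactly the size of Buffer 1, so a single step needs only about 60 MiB in total and the crash is consistent with memory accumulating across steps rather than with one oversized step. The small sketch below reproduces the buffer sizes from the shapes in the log; the BufferSize module is a hypothetical helper written for this post, not part of Nx or EXLA.

```elixir
defmodule BufferSize do
  # Bytes per element for the two element types that appear in the log.
  @bytes_per_element %{f32: 4, s64: 8}

  # Size in bytes of a dense buffer with the given element type and shape.
  def bytes(type, shape) do
    Enum.product(shape) * Map.fetch!(@bytes_per_element, type)
  end
end

# Buffer 1: f32[38,1024,2,100], the allocation that failed.
IO.inspect(BufferSize.bytes(:f32, [38, 1024, 2, 100]))  # 31129600

# Buffer 3: f32[38588,100], reported as 14.72MiB.
IO.inspect(BufferSize.bytes(:f32, [38588, 100]))        # 15435200

# Buffer 4: s64[38,1024,2], reported as 608.0KiB.
IO.inspect(BufferSize.bytes(:s64, [38, 1024, 2]))       # 622592
```

Dividing by 1024 * 1024 gives 29.69 MiB for Buffer 1 and 14.72 MiB for Buffer 3, and dividing Buffer 4 by 1024 gives 608 KiB, all matching the sizes printed by XLA.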