Skip to content

Commit a3853c9

Browse files
committed
[BUG] Improve memory management by adding garbage collection and cleanup tasks in evaluators
1 parent e2acfa7 commit a3853c9

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

gempy_engine/core/backend_tensor.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,11 +46,14 @@ def get_backend_string(cls) -> str:
4646
return "CPU"
4747

4848
@classmethod
49-
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False, dtype: Optional[str] = None):
50-
cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype)
49+
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False,
50+
dtype: Optional[str] = None, grads:bool = False):
51+
cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype,
52+
grads=grads)
5153

5254
@classmethod
53-
def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
55+
def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False,
56+
use_gpu: bool = True, dtype: Optional[str] = None, grads:bool = False):
5457
cls.dtype = DEFAULT_TENSOR_DTYPE if dtype is None else dtype
5558
cls.dtype_obj = cls.dtype
5659
match engine_backend:
@@ -99,6 +102,21 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
99102
cls.dtype_obj = pytorch_copy.float32 if cls.dtype == "float32" else pytorch_copy.float64
100103
cls.tensor_types = pytorch_copy.Tensor
101104

105+
torch.set_num_threads(torch.get_num_threads()) # Use all available threads
106+
cls.COMPUTE_GRADS = grads # Store the grads setting
107+
if grads is False:
108+
cls._torch_no_grad_context = torch.no_grad()
109+
cls._torch_no_grad_context.__enter__()
110+
else:
111+
# If there was a previous context, exit it first
112+
if hasattr(cls, '_torch_no_grad_context') and cls._torch_no_grad_context is not None:
113+
try:
114+
cls._torch_no_grad_context.__exit__(None, None, None)
115+
except:
116+
pass # Context might already be exited
117+
cls._torch_no_grad_context = None
118+
torch.set_grad_enabled(True)
119+
102120
cls.use_pykeops = use_pykeops # TODO: Make this compatible with pykeops
103121
if (use_pykeops):
104122
import pykeops

gempy_engine/modules/evaluator/generic_evaluator.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import gc
23
from typing import Optional
34

45
from gempy_engine.core.backend_tensor import BackendTensor
@@ -57,6 +58,10 @@ def generic_evaluator(
5758
if gz_field is not None:
5859
gz_field[slice_array] = gz_chunk # type: ignore
5960

61+
# Force garbage collection every few chunks to prevent memory buildup
62+
if (i + 1) % 5 == 0 or i == n_chunks - 1:
63+
gc.collect()
64+
6065
if n_chunks > 5:
6166
print(f"Chunking done: {n_chunks} chunks")
6267

@@ -75,7 +80,9 @@ def _eval_on(
7580
try:
7681
scalar_field = (eval_kernel.T @ weights).reshape(-1)
7782
except ValueError:
78-
pass
83+
scalar_field = None
84+
85+
del eval_kernel
7986

8087
gx_field: Optional[np.ndarray] = None
8188
gy_field: Optional[np.ndarray] = None
@@ -85,17 +92,21 @@ def _eval_on(
8592
eval_gx = yield_evaluation_grad_kernel(
8693
solver_input, options.kernel_options, axis=0, slice_array=slice_array
8794
)
95+
gx_field = (eval_gx.T @ weights).reshape(-1) # Use BEFORE deleting
96+
del eval_gx # Clean up immediately after use
97+
8898
eval_gy = yield_evaluation_grad_kernel(
8999
solver_input, options.kernel_options, axis=1, slice_array=slice_array
90100
)
91-
gx_field = (eval_gx.T @ weights).reshape(-1)
92-
gy_field = (eval_gy.T @ weights).reshape(-1)
101+
gy_field = (eval_gy.T @ weights).reshape(-1) # Use BEFORE deleting
102+
del eval_gy # Clean up immediately after use
93103

94104
if options.number_dimensions == 3:
95105
eval_gz = yield_evaluation_grad_kernel(
96106
solver_input, options.kernel_options, axis=2, slice_array=slice_array
97107
)
98-
gz_field = (eval_gz.T @ weights).reshape(-1)
108+
gz_field = (eval_gz.T @ weights).reshape(-1) # Use BEFORE deleting
109+
del eval_gz # Clean up immediately after use
99110
elif options.number_dimensions != 2:
100111
raise ValueError("`number_dimensions` must be 2 or 3")
101112

0 commit comments

Comments
 (0)