Skip to content

Commit 271a6f8

Browse files
authored
[ENH] Add gradient computation control and improve memory management (#31)
### TL;DR

Added memory optimization features to improve performance when evaluating large models.

### What changed?

- Added a `grads` parameter to the backend tensor configuration to control PyTorch gradient computation
- Implemented proper context management for PyTorch's `no_grad()` mode
- Added garbage collection calls during chunked evaluation to prevent memory buildup
- Optimized memory usage in the evaluation kernel by immediately deleting tensors after use
- Improved error handling in the evaluation function

### How to test?

1. Test with large models that previously caused memory issues
2. Compare memory usage before and after these changes
3. Verify that model evaluation still produces correct results
4. Test with both gradient computation enabled and disabled

### Why make this change?

These optimizations address memory leaks and excessive memory usage during model evaluation, particularly for large models. By properly managing PyTorch's gradient computation and implementing strategic garbage collection, we can significantly reduce memory footprint without sacrificing performance. The immediate cleanup of tensors after use prevents memory buildup during evaluation of large datasets.
2 parents e2acfa7 + a3853c9 commit 271a6f8

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

gempy_engine/core/backend_tensor.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,14 @@ def get_backend_string(cls) -> str:
4646
return "CPU"
4747

4848
@classmethod
49-
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False, dtype: Optional[str] = None):
50-
cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype)
49+
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = False,
50+
dtype: Optional[str] = None, grads:bool = False):
51+
cls._change_backend(engine_backend, use_pykeops=PYKEOPS, use_gpu=use_gpu, dtype=dtype,
52+
grads=grads)
5153

5254
@classmethod
53-
def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
55+
def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool = False,
56+
use_gpu: bool = True, dtype: Optional[str] = None, grads:bool = False):
5457
cls.dtype = DEFAULT_TENSOR_DTYPE if dtype is None else dtype
5558
cls.dtype_obj = cls.dtype
5659
match engine_backend:
@@ -99,6 +102,21 @@ def _change_backend(cls, engine_backend: AvailableBackends, use_pykeops: bool =
99102
cls.dtype_obj = pytorch_copy.float32 if cls.dtype == "float32" else pytorch_copy.float64
100103
cls.tensor_types = pytorch_copy.Tensor
101104

105+
torch.set_num_threads(torch.get_num_threads()) # Use all available threads
106+
cls.COMPUTE_GRADS = grads # Store the grads setting
107+
if grads is False:
108+
cls._torch_no_grad_context = torch.no_grad()
109+
cls._torch_no_grad_context.__enter__()
110+
else:
111+
# If there was a previous context, exit it first
112+
if hasattr(cls, '_torch_no_grad_context') and cls._torch_no_grad_context is not None:
113+
try:
114+
cls._torch_no_grad_context.__exit__(None, None, None)
115+
except:
116+
pass # Context might already be exited
117+
cls._torch_no_grad_context = None
118+
torch.set_grad_enabled(True)
119+
102120
cls.use_pykeops = use_pykeops # TODO: Make this compatible with pykeops
103121
if (use_pykeops):
104122
import pykeops

gempy_engine/modules/evaluator/generic_evaluator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import gc
23
from typing import Optional
34

45
from gempy_engine.core.backend_tensor import BackendTensor
@@ -57,6 +58,10 @@ def generic_evaluator(
5758
if gz_field is not None:
5859
gz_field[slice_array] = gz_chunk # type: ignore
5960

61+
# Force garbage collection every few chunks to prevent memory buildup
62+
if (i + 1) % 5 == 0 or i == n_chunks - 1:
63+
gc.collect()
64+
6065
if n_chunks > 5:
6166
print(f"Chunking done: {n_chunks} chunks")
6267

@@ -75,7 +80,9 @@ def _eval_on(
7580
try:
7681
scalar_field = (eval_kernel.T @ weights).reshape(-1)
7782
except ValueError:
78-
pass
83+
scalar_field = None
84+
85+
del eval_kernel
7986

8087
gx_field: Optional[np.ndarray] = None
8188
gy_field: Optional[np.ndarray] = None
@@ -85,17 +92,21 @@ def _eval_on(
8592
eval_gx = yield_evaluation_grad_kernel(
8693
solver_input, options.kernel_options, axis=0, slice_array=slice_array
8794
)
95+
gx_field = (eval_gx.T @ weights).reshape(-1) # Use BEFORE deleting
96+
del eval_gx # Clean up immediately after use
97+
8898
eval_gy = yield_evaluation_grad_kernel(
8999
solver_input, options.kernel_options, axis=1, slice_array=slice_array
90100
)
91-
gx_field = (eval_gx.T @ weights).reshape(-1)
92-
gy_field = (eval_gy.T @ weights).reshape(-1)
101+
gy_field = (eval_gy.T @ weights).reshape(-1) # Use BEFORE deleting
102+
del eval_gy # Clean up immediately after use
93103

94104
if options.number_dimensions == 3:
95105
eval_gz = yield_evaluation_grad_kernel(
96106
solver_input, options.kernel_options, axis=2, slice_array=slice_array
97107
)
98-
gz_field = (eval_gz.T @ weights).reshape(-1)
108+
gz_field = (eval_gz.T @ weights).reshape(-1) # Use BEFORE deleting
109+
del eval_gz # Clean up immediately after use
99110
elif options.number_dimensions != 2:
100111
raise ValueError("`number_dimensions` must be 2 or 3")
101112

0 commit comments

Comments (0)