Skip to content

Commit 0d53085

Browse files
authored
[ENH] Improve CG solver and add weights caching for better performance (#17)
# Enhanced Conjugate Gradient Solver with Improved Stability and GPU Support This PR significantly improves the robustness and performance of our solver system with several key enhancements: ## Solver Improvements - Completely refactored the `ConjugateGradientSolver` with enhanced stability features for ill-conditioned matrices - Added adaptive tolerance, regularization support, and robust convergence criteria - Implemented residual monitoring and restart capability for better convergence - Added detailed diagnostics and warnings for numerical issues ## GPU Acceleration - Added GPU support for PyKeOps solver in the torch backend - Properly configured backend tensor handling for GPU operations - Fixed device selection logic in solver interface ## Caching and Weights Management - Added ability to pass initial weights through the interpolation input - Implemented `weights_x0` in solver input for warm starts - Added `clear_cache` method to `WeightCache` for better memory management ## Numerical Stability - Fixed condition number computation and made plotting optional - Properly store condition number in kernel options during operations - Added safeguards against numerical breakdown in denominator calculations - Implemented stagnation detection for ill-conditioned systems ## Experimental Features - Added Nyström preconditioner implementation (currently disabled by default) - Included configurable pivoting strategies for matrix approximation These changes significantly improve the solver's ability to handle challenging numerical problems while maintaining performance.
2 parents 99c9d7e + 1a2c04b commit 0d53085

File tree

10 files changed

+774
-84
lines changed

10 files changed

+774
-84
lines changed

gempy_engine/API/interp_single/_interp_scalar_field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def interpolate_scalar_field(solver_input: SolverInput, options: InterpolationOp
4747

4848
match weights_cached:
4949
case None:
50+
foo = solver_input.weights_x0
5051
weights = _solve_and_store_weights(
5152
solver_input=solver_input,
5253
kernel_options=options.kernel_options,
@@ -87,7 +88,6 @@ def _solve_interpolation(interp_input: SolverInput, kernel_options: KernelOption
8788
if kernel_options.optimizing_condition_number:
8889
_optimize_nuggets_against_condition_number(A_matrix, interp_input, kernel_options)
8990

90-
# TODO: Smooth should be taken from options
9191
weights = solver_interface.kernel_reduction(
9292
cov=A_matrix,
9393
b=b_vector,

gempy_engine/API/interp_single/_interp_single_feature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def input_preprocess(data_shape: TensorsStructure, interpolation_input: Interpol
108108
xyz_to_interpolate=grid_internal,
109109
fault_internal=fault_values
110110
)
111+
solver_input.weights_x0 = interpolation_input.weights
111112

112113
return solver_input
113114

gempy_engine/core/backend_tensor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
if is_pytorch_installed:
1414
import torch
15+
16+
PYKEOPS= DEFAULT_PYKEOPS
1517

1618
# * Import a copy of numpy as tfnp
1719
from importlib.util import find_spec, module_from_spec
@@ -44,7 +46,7 @@ def get_backend_string(cls) -> str:
4446

4547
@classmethod
4648
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = True, dtype: Optional[str] = None):
47-
cls._change_backend(engine_backend, pykeops_enabled=DEFAULT_PYKEOPS, use_gpu=use_gpu, dtype=dtype)
49+
cls._change_backend(engine_backend, pykeops_enabled=PYKEOPS, use_gpu=use_gpu, dtype=dtype)
4850

4951
@classmethod
5052
def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
@@ -100,6 +102,12 @@ def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: boo
100102
if (pykeops_enabled):
101103
import pykeops
102104
cls._wrap_pykeops_functions()
105+
106+
if (use_gpu):
107+
cls.use_gpu = True
108+
# cls.tensor_backend_pointer['active_backend'].set_default_device("cuda")
109+
else:
110+
cls.use_gpu = False
103111

104112
case (_):
105113
raise AttributeError(

gempy_engine/core/data/interpolation_input.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pprint
2-
from dataclasses import dataclass
3-
from typing import Optional
2+
from dataclasses import dataclass, field
3+
from typing import Optional, Union
44

55
import numpy as np
66

@@ -21,23 +21,9 @@ class InterpolationInput:
2121
surface_points: SurfacePoints
2222
orientations: Orientations
2323
_original_grid: EngineGrid
24-
25-
@property
26-
def original_grid(self):
27-
return self._original_grid
28-
29-
def set_grid_to_original(self):
30-
self._grid = self._original_grid
31-
32-
3324
_grid: EngineGrid
34-
@property
35-
def grid(self):
36-
return self._grid
37-
38-
def set_temp_grid(self, value):
39-
self._grid = value
4025

26+
weights: Union[list[np.ndarray] | np.ndarray] = field(default_factory=lambda: [])
4127
_unit_values: Optional[np.ndarray] = None
4228
segmentation_function: Optional[callable] = None # * From scalar field to values
4329

@@ -52,14 +38,18 @@ def set_temp_grid(self, value):
5238

5339
def __init__(self, surface_points: SurfacePoints, orientations: Orientations, grid: EngineGrid,
5440
unit_values: Optional[np.ndarray] = None, segmentation_function: Optional[callable] = None,
55-
stack_relation: StackRelationType = StackRelationType.ERODE):
41+
stack_relation: StackRelationType = StackRelationType.ERODE, weights: list[np.ndarray] = None):
42+
if weights is None:
43+
weights = []
44+
5645
self.surface_points = surface_points
5746
self._original_grid = grid
5847
self._grid = grid
5948
self.orientations = orientations
6049
self.unit_values = unit_values
6150
self.segmentation_function = segmentation_function
6251
self.stack_relation = stack_relation
52+
self.weights = weights
6353

6454
# @ on
6555

@@ -92,6 +82,7 @@ def from_interpolation_input_subset(cls, all_interpolation_input: "Interpolation
9282
grid=grid,
9383
unit_values=unit_values,
9484
stack_relation=stack_structure.active_masking_descriptor,
85+
weights=(all_interpolation_input.weights[stack_number] if stack_number < len(all_interpolation_input.weights) else None)
9586
)
9687

9788
# ! Setting this on the constructor does not work with data classes.
@@ -116,6 +107,19 @@ def from_schema(cls, schema: InterpolationInputSchema) -> "InterpolationInput":
116107
grid=grid
117108
)
118109

110+
@property
111+
def original_grid(self):
112+
return self._original_grid
113+
114+
def set_grid_to_original(self):
115+
self._grid = self._original_grid
116+
117+
@property
118+
def grid(self):
119+
return self._grid
120+
121+
def set_temp_grid(self, value):
122+
self._grid = value
119123

120124
@property
121125
def slice_feature(self):

0 commit comments

Comments
 (0)