Skip to content

Commit 52f5f70

Browse files
committed
Add caching improvements and enhance solver stability
Introduced a `clear_cache` method in the `WeightCache` class for better memory management. Enhanced the numerical stability of the conjugate gradient solver with improved initialization and adaptive tolerances. Refactored interpolation logic to handle weights more efficiently, and adjusted benchmarking to test expanded solver configurations.
1 parent e9ae973 commit 52f5f70

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

gempy/modules/data_manipulation/_engine_factory.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,18 @@ def interpolation_input_from_structural_frame(geo_model: "gempy.data.GeoModel")
4242
extent_transformed=geo_model.extent_transformed_transformed_by_input
4343
)
4444

45+
weights = []
46+
if geo_model.solutions is not None:
47+
for stack_sol in geo_model.solutions.root_output.outputs_centers:
48+
weights.append(stack_sol.weights)
49+
50+
4551
interpolation_input: InterpolationInput = InterpolationInput(
4652
surface_points=surface_points,
4753
orientations=orientations,
4854
grid=grid,
49-
unit_values=structural_frame.elements_ids # TODO: Here we will need to pass densities etc.
55+
unit_values=structural_frame.elements_ids, # TODO: Here we will need to pass densities etc.
56+
weights=weights
5057
)
5158

5259
return interpolation_input

test/test_modules/test_cg/test_cg_solver.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
import gempy as gp
22
from gempy.core.data.enumerators import ExampleModel
33
from gempy.optional_dependencies import require_gempy_viewer
4+
import gempy_engine.core.backend_tensor as BackendTensor
5+
6+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
47

58
PLOT = True
69

710

8-
def test_generate_greenstone_model():
11+
def test_solve_with_cg():
912
model = gp.generate_example_model(ExampleModel.GREENSTONE, compute_model=False)
1013
print(model.structural_frame)
1114

15+
WeightCache.clear_cache()
16+
BackendTensor.PYKEOPS = True
17+
1218
sol = gp.compute_model(
1319
gempy_model=model,
1420
engine_config=gp.data.GemPyEngineConfig(
@@ -21,3 +27,36 @@ def test_generate_greenstone_model():
2127
if PLOT:
2228
gpv = require_gempy_viewer()
2329
gpv.plot_3d(model, image=True)
30+
31+
32+
def test_save_weights():
33+
model = gp.generate_example_model(ExampleModel.GREENSTONE, compute_model=False)
34+
print(model.structural_frame)
35+
36+
sol = gp.compute_model(
37+
gempy_model=model,
38+
engine_config=gp.data.GemPyEngineConfig(
39+
backend=gp.data.AvailableBackends.PYTORCH,
40+
use_gpu=False,
41+
dtype='float32'
42+
)
43+
)
44+
weights1 = sol.octrees_output[0].outputs_centers[0].weights
45+
weights2 = sol.octrees_output[0].outputs_centers[1].weights
46+
weights3 = sol.octrees_output[0].outputs_centers[2].weights
47+
48+
WeightCache.clear_cache()
49+
BackendTensor.PYKEOPS = True
50+
51+
sol = gp.compute_model(
52+
gempy_model=model,
53+
engine_config=gp.data.GemPyEngineConfig(
54+
backend=gp.data.AvailableBackends.PYTORCH,
55+
use_gpu=False,
56+
dtype='float32'
57+
)
58+
)
59+
60+
if PLOT:
61+
gpv = require_gempy_viewer()
62+
gpv.plot_3d(model, image=True)

0 commit comments

Comments
 (0)