Skip to content

Commit 4aafa01

Browse files
committed
[BUG/CLN] Refactor tensor handling and cleanup lazy weight logic
Updated tensor operations in `symbolic_evaluator.py` to use `lazy_weights` consistently, replacing redundant `LazyTensor` calls. Adjusted dtype handling in `backend_tensor.py` for better compatibility with PyTorch.
1 parent 1a2c04b commit 4aafa01

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

gempy_engine/core/backend_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def _sum(tensor, axis=None, dtype=None, keepdims=False):
209209
case pykeops.numpy.LazyTensor() | pykeops.torch.LazyTensor():
210210
return tensor.sum(axis)
211211
case torch.Tensor() if torch_available:
212+
if isinstance(dtype, str):
213+
dtype = getattr(torch, dtype)
212214
return tensor.sum(axis, keepdims=keepdims, dtype=dtype)
213215
case _:
214216
raise TypeError("Unsupported tensor type")

gempy_engine/modules/evaluator/symbolic_evaluator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ def symbolic_evaluator(solver_input: SolverInput, weights: np.ndarray, options:
3434
eval_gx_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=0)
3535
eval_gy_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=1)
3636

37-
gx_field = (eval_gx_kernel.T * LazyTensor(weights, axis=1)).sum(axis=1, backend=backend_string).reshape(-1)
38-
gy_field = (eval_gy_kernel.T * LazyTensor(weights, axis=1)).sum(axis=1, backend=backend_string).reshape(-1)
37+
gx_field = (eval_gx_kernel.T * lazy_weights).sum(axis=1, backend=backend_string).reshape(-1)
38+
gy_field = (eval_gy_kernel.T * lazy_weights).sum(axis=1, backend=backend_string).reshape(-1)
3939

4040
if options.number_dimensions == 3:
4141
eval_gz_kernel = yield_evaluation_grad_kernel(solver_input, options.kernel_options, axis=2)
42-
gz_field = (eval_gz_kernel.T * LazyTensor(weights, axis=1)).sum(axis=1, backend=backend_string).reshape(-1)
42+
gz_field = (eval_gz_kernel.T * lazy_weights).sum(axis=1, backend=backend_string).reshape(-1)
4343
elif options.number_dimensions == 2:
4444
gz_field = None
4545
else:

0 commit comments

Comments
 (0)