1+ import torch
12from typing import Tuple
23
34from gempy_engine .config import AvailableBackends
45from ...core .backend_tensor import BackendTensor
56from ...core .data .dual_contouring_data import DualContouringData
6- import numpy as np
77
88
9- def find_intersection_on_edge (_xyz_corners : np . ndarray , scalar_field_on_corners : np . ndarray ,
10- scalar_at_sp : np . ndarray , masking = None ) -> Tuple [ np . ndarray , np . ndarray ] :
9+ def find_intersection_on_edge (_xyz_corners , scalar_field_on_corners ,
10+ scalar_at_sp , masking = None ) -> Tuple :
1111 """This function finds all the intersections for multiple layers per series
1212
1313 - The shape of valid edges is n_surfaces * xyz_corners. Where xyz_corners is 8 * the octree leaf
@@ -27,13 +27,13 @@ def find_intersection_on_edge(_xyz_corners: np.ndarray, scalar_field_on_corners:
2727 scalar_at_sp = scalar_at_sp .reshape ((- 1 , 1 , 1 ))
2828
2929 n_isosurface = scalar_at_sp .shape [0 ]
30- xyz_8 = BackendTensor .t .tile (xyz_8 , (n_isosurface , 1 , 1 )) # TODO: Generalize
30+ xyz_8 = BackendTensor .tfnp .tile (xyz_8 , (n_isosurface , 1 , 1 )) # TODO: Generalize
3131
3232 # Compute distance of scalar field on the corners
33- scalar_dx = scalar_8 [:, :, :4 ] - scalar_8 [:, :, 4 :]
33+ scalar_dx = scalar_8 [:, :, :4 ] - scalar_8 [:, :, 4 :]
3434 scalar_d_y = scalar_8 [:, :, [0 , 1 , 4 , 5 ]] - scalar_8 [:, :, [2 , 3 , 6 , 7 ]]
3535 scalar_d_z = scalar_8 [:, :, ::2 ] - scalar_8 [:, :, 1 ::2 ]
36-
36+
3737 # Add a tiny value to avoid division by zero
3838 scalar_dx += 1e-10
3939 scalar_d_y += 1e-10
@@ -53,16 +53,16 @@ def find_intersection_on_edge(_xyz_corners: np.ndarray, scalar_field_on_corners:
5353 intersect_dx = d_x [:, :, :] * weight_x [:, :, :]
5454 intersect_dy = d_y [:, :, :] * weight_y [:, :, :]
5555 intersect_dz = d_z [:, :, :] * weight_z [:, :, :]
56-
56+
5757 # Mask invalid edges
58- valid_edge_x = np .logical_and (weight_x > - 0.01 , weight_x < 1.01 )
59- valid_edge_y = np .logical_and (weight_y > - 0.01 , weight_y < 1.01 )
60- valid_edge_z = np .logical_and (weight_z > - 0.01 , weight_z < 1.01 )
58+ valid_edge_x = BackendTensor . tfnp .logical_and (weight_x > - 0.01 , weight_x < 1.01 )
59+ valid_edge_y = BackendTensor . tfnp .logical_and (weight_y > - 0.01 , weight_y < 1.01 )
60+ valid_edge_z = BackendTensor . tfnp .logical_and (weight_z > - 0.01 , weight_z < 1.01 )
6161
6262 # * Note(miguel) From this point on the arrays become sparse
63- xyz_8_edges = BackendTensor .t .hstack ([xyz_8 [:, 4 :], xyz_8 [:, [2 , 3 , 6 , 7 ]], xyz_8 [:, 1 ::2 ]])
64- intersect_segment = BackendTensor .t .hstack ([intersect_dx , intersect_dy , intersect_dz ])
65- valid_edges = BackendTensor .t .hstack ([valid_edge_x , valid_edge_y , valid_edge_z ])[:, :, 0 ]
63+ xyz_8_edges = BackendTensor .tfnp .hstack ([xyz_8 [:, 4 :], xyz_8 [:, [2 , 3 , 6 , 7 ]], xyz_8 [:, 1 ::2 ]])
64+ intersect_segment = BackendTensor .tfnp .hstack ([intersect_dx , intersect_dy , intersect_dz ])
65+ valid_edges = BackendTensor .tfnp .hstack ([valid_edge_x , valid_edge_y , valid_edge_z ])[:, :, 0 ]
6666 valid_edges = valid_edges > 0
6767
6868 intersection_xyz = xyz_8_edges [valid_edges ] + intersect_segment [valid_edges ]
@@ -88,13 +88,13 @@ def triangulate_dual_contouring(dc_data_per_surface: DualContouringData):
8888 x_2 = centers_xyz [valid_voxels ][None , :, :]
8989
9090 manhattan = x_1 - x_2
91- zeros = np .isclose (manhattan [:, :, :], 0 , .00001 )
92- x_direction_neighbour = np .isclose (manhattan [:, :, 0 ], dx , .00001 )
93- nx_direction_neighbour = np .isclose (manhattan [:, :, 0 ], - dx , .00001 )
94- y_direction_neighbour = np .isclose (manhattan [:, :, 1 ], dy , .00001 )
95- ny_direction_neighbour = np .isclose (manhattan [:, :, 1 ], - dy , .00001 )
96- z_direction_neighbour = np .isclose (manhattan [:, :, 2 ], dz , .00001 )
97- nz_direction_neighbour = np .isclose (manhattan [:, :, 2 ], - dz , .00001 )
91+ zeros = BackendTensor . tfnp .isclose (manhattan [:, :, :], 0 , .00001 )
92+ x_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 0 ], dx , .00001 )
93+ nx_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 0 ], - dx , .00001 )
94+ y_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 1 ], dy , .00001 )
95+ ny_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 1 ], - dy , .00001 )
96+ z_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 2 ], dz , .00001 )
97+ nz_direction_neighbour = BackendTensor . tfnp .isclose (manhattan [:, :, 2 ], - dz , .00001 )
9898
9999 x_direction = x_direction_neighbour * zeros [:, :, 1 ] * zeros [:, :, 2 ]
100100 nx_direction = nx_direction_neighbour * zeros [:, :, 1 ] * zeros [:, :, 2 ]
@@ -103,12 +103,12 @@ def triangulate_dual_contouring(dc_data_per_surface: DualContouringData):
103103 z_direction = z_direction_neighbour * zeros [:, :, 0 ] * zeros [:, :, 1 ]
104104 nz_direction = nz_direction_neighbour * zeros [:, :, 0 ] * zeros [:, :, 1 ]
105105
106- np .fill_diagonal (x_direction , True )
107- np .fill_diagonal (nx_direction , True )
108- np .fill_diagonal (y_direction , True )
109- np . fill_diagonal (nx_direction , True )
110- np .fill_diagonal (z_direction , True )
111- np .fill_diagonal (nz_direction , True )
106+ BackendTensor . tfnp .fill_diagonal (x_direction , True )
107+ BackendTensor . tfnp .fill_diagonal (nx_direction , True )
108+ BackendTensor . tfnp .fill_diagonal (y_direction , True )
109+ BackendTensor . tfnp . fill_diagonal (ny_direction , True )
110+ BackendTensor . tfnp .fill_diagonal (z_direction , True )
111+ BackendTensor . tfnp .fill_diagonal (nz_direction , True )
112112
113113 # X edges
114114 nynz_direction = ny_direction + nz_direction
@@ -129,72 +129,81 @@ def triangulate_dual_contouring(dc_data_per_surface: DualContouringData):
129129 xy_direction = x_direction + y_direction
130130
131131 # Stack all 12 directions
132- directions = np .dstack ([nynz_direction , nyz_direction , ynz_direction , yz_direction ,
133- nxnz_direction , xnz_direction , nxz_direction , xz_direction ,
134- nxny_direction , nxy_direction , xny_direction , xy_direction ])
132+ directions = BackendTensor . tfnp .dstack ([nynz_direction , nyz_direction , ynz_direction , yz_direction ,
133+ nxnz_direction , xnz_direction , nxz_direction , xz_direction ,
134+ nxny_direction , nxy_direction , xny_direction , xy_direction ])
135135
136136 # endregion
137137
138138 valid_edg = valid_edges [valid_voxels ][:, :]
139- valid_edg = BackendTensor .t .to_numpy (valid_edg )
140-
139+ valid_edg = BackendTensor .tfnp .to_numpy (valid_edg )
140+
141141 direction_each_edge = (directions * valid_edg )
142142
143143 # Pick only edges with more than 2 voxels nearby
144144 three_neighbours = (directions * valid_edg ).sum (axis = 0 ) == 3
145- matrix_to_right_C_order = np .transpose ((direction_each_edge * three_neighbours ), (1 , 2 , 0 ))
146- indices = np .where (matrix_to_right_C_order )[2 ].reshape (- 1 , 3 )
145+ matrix_to_right_C_order = BackendTensor . tfnp .transpose ((direction_each_edge * three_neighbours ), (1 , 2 , 0 ))
146+ indices = BackendTensor . tfnp .where (matrix_to_right_C_order )[2 ].reshape (- 1 , 3 )
147147
148148 indices_shift = indices
149149 indices_arrays .append (indices_shift )
150- indices_arrays_f = np .vstack (indices_arrays )
150+ indices_arrays_f = BackendTensor . tfnp .vstack (indices_arrays )
151151
152152 return indices_arrays_f
153153
154154
155- def generate_dual_contouring_vertices (dc_data_per_stack : DualContouringData , slice_surface : slice , debug : bool = False ) -> np . ndarray :
155+ def generate_dual_contouring_vertices (dc_data_per_stack : DualContouringData , slice_surface : slice , debug : bool = False ):
156156 # @off
157- n_edges = dc_data_per_stack .n_edges
158- valid_edges = dc_data_per_stack .valid_edges
157+ n_edges = dc_data_per_stack .n_edges
158+ valid_edges = dc_data_per_stack .valid_edges
159159 valid_voxels = dc_data_per_stack .valid_voxels
160- xyz_on_edge = dc_data_per_stack .xyz_on_edge [slice_surface ]
161- gradients = dc_data_per_stack .gradients [slice_surface ]
160+ xyz_on_edge = dc_data_per_stack .xyz_on_edge [slice_surface ]
161+ gradients = dc_data_per_stack .gradients [slice_surface ]
162162 # @on
163163
164164 # * Coordinates for all posible edges (12) and 3 dummy edges_normals in the center
165- edges_xyz = BackendTensor .t .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
165+ edges_xyz = BackendTensor .tfnp .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
166166 valid_edges = valid_edges > 0
167167 edges_xyz [:, :12 ][valid_edges ] = xyz_on_edge
168168
169169 # Normals
170- edges_normals = BackendTensor .t .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
170+ edges_normals = BackendTensor .tfnp .zeros ((n_edges , 15 , 3 ), dtype = BackendTensor .dtype_obj )
171171 edges_normals [:, :12 ][valid_edges ] = gradients
172172
173- if OLD_METHOD := False :
173+ if OLD_METHOD := False :
174174 # ! Moureze model does not seems to work with the new method
175175 # ! This branch is all nans at least with ch1_1 model
176- bias_xyz = np .copy (edges_xyz [:, :12 ])
177- isclose = np .isclose (bias_xyz , 0 )
178- bias_xyz [isclose ] = np . nan # np zero values to nans
179- mass_points = np .nanmean (bias_xyz , axis = 1 ) # Mean ignoring nans
176+ bias_xyz = BackendTensor . tfnp .copy (edges_xyz [:, :12 ])
177+ isclose = BackendTensor . tfnp .isclose (bias_xyz , 0 )
178+ bias_xyz [isclose ] = BackendTensor . tfnp . nan # zero values to nans
179+ mass_points = BackendTensor . tfnp .nanmean (bias_xyz , axis = 1 ) # Mean ignoring nans
180180 else : # ? This is actually doing something
181- bias_xyz = BackendTensor .t .copy (edges_xyz [:, :12 ])
182- bias_xyz = BackendTensor .t .to_numpy (bias_xyz )
183- mask = bias_xyz == 0
184- masked_arr = np .ma .masked_array (bias_xyz , mask )
185- mass_points = masked_arr .mean (axis = 1 )
186- mass_points = BackendTensor .t .array (mass_points )
181+ bias_xyz = BackendTensor .tfnp .copy (edges_xyz [:, :12 ])
182+ if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
183+ # PyTorch doesn't have masked arrays, so we'll use a different approach
184+ mask = bias_xyz == 0
185+ # Replace zeros with NaN for mean calculation
186+ bias_xyz_masked = BackendTensor .tfnp .where (mask , float ('nan' ), bias_xyz )
187+ mass_points = BackendTensor .tfnp .nanmean (bias_xyz_masked , axis = 1 )
188+ else :
189+ # NumPy approach with masked arrays
190+ bias_xyz = BackendTensor .tfnp .to_numpy (bias_xyz )
191+ import numpy as np
192+ mask = bias_xyz == 0
193+ masked_arr = np .ma .masked_array (bias_xyz , mask )
194+ mass_points = masked_arr .mean (axis = 1 )
195+ mass_points = BackendTensor .tfnp .array (mass_points )
187196
188197 edges_xyz [:, 12 ] = mass_points
189198 edges_xyz [:, 13 ] = mass_points
190199 edges_xyz [:, 14 ] = mass_points
191200
192201 BIAS_STRENGTH = 1
193-
194- bias_x = BackendTensor .t .array ([BIAS_STRENGTH , 0 , 0 ], dtype = BackendTensor .dtype_obj )
195- bias_y = BackendTensor .t .array ([0 , BIAS_STRENGTH , 0 ], dtype = BackendTensor .dtype_obj )
196- bias_z = BackendTensor .t .array ([0 , 0 , BIAS_STRENGTH ], dtype = BackendTensor .dtype_obj )
197-
202+
203+ bias_x = BackendTensor .tfnp .array ([BIAS_STRENGTH , 0 , 0 ], dtype = BackendTensor .dtype_obj )
204+ bias_y = BackendTensor .tfnp .array ([0 , BIAS_STRENGTH , 0 ], dtype = BackendTensor .dtype_obj )
205+ bias_z = BackendTensor .tfnp .array ([0 , 0 , BIAS_STRENGTH ], dtype = BackendTensor .dtype_obj )
206+
198207 edges_normals [:, 12 ] = bias_x
199208 edges_normals [:, 13 ] = bias_y
200209 edges_normals [:, 14 ] = bias_z
@@ -208,14 +217,15 @@ def generate_dual_contouring_vertices(dc_data_per_stack: DualContouringData, sli
208217 b = (A * edges_xyz ).sum (axis = 2 )
209218
210219 if BackendTensor .engine_backend == AvailableBackends .PYTORCH :
211- transpose_shape = (2 , 1 )
220+ transpose_shape = (2 , 1 , 0 ) # For PyTorch: (batch, dim2, dim1 )
212221 else :
213- transpose_shape = (0 , 2 ,1 )
214-
215- term1 = BackendTensor .t .einsum ("ijk, ilj->ikl" , A , BackendTensor .t .transpose (A , transpose_shape ))
216- term2 = BackendTensor .t .linalg .inv (term1 )
217- term3 = BackendTensor .t .einsum ("ijk,ik->ij" , BackendTensor .t .transpose (A , transpose_shape ), b )
218- vertices = BackendTensor .t .einsum ("ijk, ij->ik" , term2 , term3 )
222+ transpose_shape = (0 , 2 , 1 ) # For NumPy: (batch, dim2, dim1)
223+
224+ import torch
225+ term1 = BackendTensor .tfnp .einsum ("ijk, ilj->ikl" , A , BackendTensor .tfnp .transpose (A , transpose_shape ))
226+ term2 = BackendTensor .tfnp .linalg .inv (term1 )
227+ term3 = BackendTensor .tfnp .einsum ("ijk,ik->ij" , BackendTensor .tfnp .transpose (A , transpose_shape ), b )
228+ vertices = BackendTensor .tfnp .einsum ("ijk, ij->ik" , term2 , term3 )
219229
220230 if debug :
221231 dc_data_per_stack .bias_center_mass = edges_xyz [:, 12 :].reshape (- 1 , 3 )
@@ -237,8 +247,8 @@ def evaluate(self, x):
237247 """Evaluates the function at a given point.
238248 This is what the solve method is trying to minimize.
239249 NB: Doesn't work with fixed axes."""
240- x = np .array (x )
241- return np . linalg .norm (np .matmul (self .A , x ) - self .b )
250+ x = BackendTensor . tfnp .array (x )
251+ return BackendTensor . tfnp . linalg .norm (BackendTensor . tfnp .matmul (self .A , x ) - self .b )
242252
243253 def eval_with_pos (self , x ):
244254 """Evaluates the QEF at a position, returning the same format solve does."""
@@ -248,15 +258,15 @@ def eval_with_pos(self, x):
248258 def make_3d (positions , normals ):
249259 """Returns a QEF that measures the the error from a bunch of normals, each emanating
250260 from given positions"""
251- A = np .array (normals )
261+ A = BackendTensor . tfnp .array (normals )
252262 b = [v [0 ] * n [0 ] + v [1 ] * n [1 ] + v [2 ] * n [2 ] for v , n in zip (positions , normals )]
253263 fixed_values = [None ] * A .shape [1 ]
254264 return QEF (A , b , fixed_values )
255265
256266 def solve (self ):
257267 """Finds the point that minimizes the error of this QEF,
258268 and returns a tuple of the error squared and the point itself"""
259- result , residual , rank , s = np .linalg .lstsq (self .A , self .b )
269+ result , residual , rank , s = BackendTensor . tfnp .linalg .lstsq (self .A , self .b )
260270 if len (residual ) == 0 :
261271 residual = self .evaluate (result )
262272 else :
0 commit comments