2323def _cross_squared_distance_matrix (x : TensorLike , y : TensorLike ) -> tf .Tensor :
2424 """Pairwise squared distance between two (batch) matrices' rows (2nd dim).
2525
26- Computes the pairwise distances between rows of x and rows of y
26+ Computes the pairwise distances between rows of x and rows of y.
27+
2728 Args:
28- x: [batch_size, n, d] float `Tensor`
29- y: [batch_size, m, d] float `Tensor`
29+ x: ` [batch_size, n, d]` float `Tensor`.
30+ y: ` [batch_size, m, d]` float `Tensor`.
3031
3132 Returns:
32- squared_dists: [batch_size, n, m] float `Tensor`, where
33- squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
33+ squared_dists: ` [batch_size, n, m]` float `Tensor`, where
34+ ` squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2`.
3435 """
3536 x_norm_squared = tf .reduce_sum (tf .square (x ), 2 )
3637 y_norm_squared = tf .reduce_sum (tf .square (y ), 2 )
@@ -52,14 +53,14 @@ def _pairwise_squared_distance_matrix(x: TensorLike) -> tf.Tensor:
5253 """Pairwise squared distance among a (batch) matrix's rows (2nd dim).
5354
5455 This saves a bit of computation vs. using
55- _cross_squared_distance_matrix(x,x)
56+ ` _cross_squared_distance_matrix(x, x)`
5657
5758 Args:
58- x: `[batch_size, n, d]` float `Tensor`
59+ x: `[batch_size, n, d]` float `Tensor`.
5960
6061 Returns:
6162 squared_dists: `[batch_size, n, n]` float `Tensor`, where
62- squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
63+ ` squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2`.
6364 """
6465
6566 x_x_transpose = tf .matmul (x , x , adjoint_b = True )
@@ -83,17 +84,17 @@ def _solve_interpolation(
8384 order : int ,
8485 regularization_weight : FloatTensorLike ,
8586) -> TensorLike :
86- """Solve for interpolation coefficients.
87+ r """Solve for interpolation coefficients.
8788
8889 Computes the coefficients of the polyharmonic interpolant for the
89- 'training' data defined by (train_points, train_values) using the kernel
90- phi.
90+ 'training' data defined by ` (train_points, train_values)` using the kernel
91+ $\ phi$ .
9192
9293 Args:
93- train_points: `[b, n, d]` interpolation centers
94- train_values: `[b, n, k]` function values
95- order: order of the interpolation
96- regularization_weight: weight to place on smoothness regularization term
94+ train_points: `[b, n, d]` interpolation centers.
95+ train_values: `[b, n, k]` function values.
96+ order: order of the interpolation.
97+ regularization_weight: weight to place on smoothness regularization term.
9798
9899 Returns:
99100 w: `[b, n, k]` weights on each interpolation center
@@ -173,15 +174,15 @@ def _apply_interpolation(
173174 interpolated function values at query_points.
174175
175176 Args:
176- query_points: `[b, m, d]` x values to evaluate the interpolation at
177+ query_points: `[b, m, d]` x values to evaluate the interpolation at.
177178 train_points: `[b, n, d]` x values that act as the interpolation centers
178- ( the c variables in the wikipedia article)
179- w: `[b, n, k]` weights on each interpolation center
180- v: `[b, d, k]` weights on each input dimension
181- order: order of the interpolation
179+ ( the c variables in the wikipedia article).
180+ w: `[b, n, k]` weights on each interpolation center.
181+ v: `[b, d, k]` weights on each input dimension.
182+ order: order of the interpolation.
182183
183184 Returns:
184- Polyharmonic interpolation evaluated at points defined in query_points.
185+ Polyharmonic interpolation evaluated at points defined in ` query_points` .
185186 """
186187
187188 # First, compute the contribution from the rbf term.
@@ -207,11 +208,11 @@ def _phi(r: FloatTensorLike, order: int) -> FloatTensorLike:
207208 See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
208209
209210 Args:
210- r: input op
211- order: interpolation order
211+ r: input op.
212+ order: interpolation order.
212213
213214 Returns:
214- phi_k evaluated coordinate-wise on r , for k = r
215+ ` phi_k` evaluated coordinate-wise on `r` , for ` k = r`.
215216 """
216217
217218 # using EPSILON prevents log(0), sqrt0), etc.
0 commit comments