diff --git a/jax_privacy/matrix_factorization/toeplitz.py b/jax_privacy/matrix_factorization/toeplitz.py index 61cb0839..eb530030 100644 --- a/jax_privacy/matrix_factorization/toeplitz.py +++ b/jax_privacy/matrix_factorization/toeplitz.py @@ -177,16 +177,14 @@ def materialize_lower_triangular( ) -> jax.Array: """Creates a lower-triangular Toeplitz matrix. - Example: If `coef = [a, b, c]` and `n = 6`, then this method returns: + Example: If ``coef = [a, b, c]`` and ``n = 6``, then this method returns:: - ``` - [a 0 0 0 0 0] - [b a 0 0 0 0] - [c b a 0 0 0] - [0 c b a 0 0] - [0 0 c b a 0] - [0 0 0 c b a] - ``` + [a 0 0 0 0 0] + [b a 0 0 0 0] + [c b a 0 0 0] + [0 c b a 0 0] + [0 0 c b a 0] + [0 0 0 c b a] Args: coef: The nonzero coefficients of a lower-triangular Toeplitz matrix C, that @@ -211,16 +209,15 @@ def solve_banded(coef: jax.Array, rhs: jax.Array) -> jax.Array: Note we want to be able to back-propagate gradients through this function, hence we cannot use scipy.linalg.solve_toeplitz. - Example: coef = [a, b, c], rhs = [1, 1, 1, 1, 1, 1], we solve the following - system for x - ``` - [a 0 0 0 0 0] [x_0] [1] - [b a 0 0 0 0] [x_1] [1] - [c b a 0 0 0] [x_2] = [1] - [0 c b a 0 0] [x_3] [1] - [0 0 c b a 0] [x_4] [1] - [0 0 0 c b a] [x_5] [1] - ``` + Example: ``coef = [a, b, c]``, ``rhs = [1, 1, 1, 1, 1, 1]``, we solve the + following system for x:: + + [a 0 0 0 0 0] [x_0] [1] + [b a 0 0 0 0] [x_1] [1] + [c b a 0 0 0] [x_2] = [1] + [0 c b a 0 0] [x_3] [1] + [0 0 c b a 0] [x_4] [1] + [0 0 0 c b a] [x_5] [1] Args: coef: The nonzero coefficients of a lower-triangular Toeplitz matrix C, that