Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions jax_privacy/matrix_factorization/toeplitz.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,14 @@ def materialize_lower_triangular(
) -> jax.Array:
"""Creates a lower-triangular Toeplitz matrix.

Example: If `coef = [a, b, c]` and `n = 6`, then this method returns:
Example: If ``coef = [a, b, c]`` and ``n = 6``, then this method returns::

```
[a 0 0 0 0 0]
[b a 0 0 0 0]
[c b a 0 0 0]
[0 c b a 0 0]
[0 0 c b a 0]
[0 0 0 c b a]
```
[a 0 0 0 0 0]
[b a 0 0 0 0]
[c b a 0 0 0]
[0 c b a 0 0]
[0 0 c b a 0]
[0 0 0 c b a]

Args:
coef: The nonzero coefficients of a lower-triangular Toeplitz matrix C, that
Expand All @@ -211,16 +209,15 @@ def solve_banded(coef: jax.Array, rhs: jax.Array) -> jax.Array:
Note we want to be able to back-propagate gradients through this function,
hence we cannot use scipy.linalg.solve_toeplitz.

Example: coef = [a, b, c], rhs = [1, 1, 1, 1, 1, 1], we solve the following
system for x
```
[a 0 0 0 0 0] [x_0] [1]
[b a 0 0 0 0] [x_1] [1]
[c b a 0 0 0] [x_2] = [1]
[0 c b a 0 0] [x_3] [1]
[0 0 c b a 0] [x_4] [1]
[0 0 0 c b a] [x_5] [1]
```
Example: ``coef = [a, b, c]``, ``rhs = [1, 1, 1, 1, 1, 1]``, we solve the
following system for x::

[a 0 0 0 0 0] [x_0] [1]
[b a 0 0 0 0] [x_1] [1]
[c b a 0 0 0] [x_2] = [1]
[0 c b a 0 0] [x_3] [1]
[0 0 c b a 0] [x_4] [1]
[0 0 0 c b a] [x_5] [1]

Args:
coef: The nonzero coefficients of a lower-triangular Toeplitz matrix C, that
Expand Down
Loading