Skip to content

Commit be6a032

Browse files
authored
Adds functions to rewrite cholesky decomposition of identity and diagonal matrices (pymc-devs#925)
* fixed merge conflicts
* fixed failing tests and added rewrite for pt.diag
* minor changes; added test to not apply rewrite
* added test for batched case and more cases of not applying rewrite
* minor changes
1 parent 3e98b9f commit be6a032

File tree

2 files changed

+182
-0
lines changed

2 files changed

+182
-0
lines changed

pytensor/tensor/rewriting/linalg.py

+79
Original file line numberDiff line numberDiff line change
@@ -887,3 +887,82 @@ def rewrite_slogdet_kronecker(fgraph, node):
887887
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
888888

889889
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
890+
891+
892+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_remove_useless_cholesky(fgraph, node):
    """
    Drop a Cholesky decomposition that is applied directly to an identity matrix.

    The Cholesky factor of an identity matrix is that same identity matrix, so
    the Cholesky node can be replaced by its input. An identity matrix is
    recognised here as the output of an Eye Op whose last input (the diagonal
    offset ``k``) exposes a ``data`` attribute holding the value 0.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        Single-element list holding the Eye output that replaces the Cholesky
        output, or None if the rewrite does not apply
    """
    # Only Blockwise nodes that wrap a Cholesky core op are of interest
    if not isinstance(node.op.core_op, Cholesky):
        return None

    candidate = node.inputs[0]
    producer = candidate.owner

    # The input must be produced by an Eye Op
    if not producer or not isinstance(producer.op, Eye):
        return None

    # The ones must lie on the main diagonal: the offset has to expose a
    # concrete value through ``data`` and that value has to be 0
    offset = producer.inputs[-1]
    if not hasattr(offset, "data") or offset.data.item() != 0:
        return None

    return [candidate]
927+
928+
929+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
    """
    Replace the Cholesky decomposition of a diagonal matrix by an elementwise sqrt.

    The Cholesky factor of a diagonal matrix is the diagonal matrix holding the
    square roots of the original diagonal entries. Two ways of building a
    diagonal matrix are recognised:

    1. an AllocDiag Op with zero offset (e.g. ``pt.diag(vector)``);
    2. an elementwise multiplication of an identity matrix with exactly one
       other input, as detected by ``_find_diag_from_eye_mul``.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Find whether cholesky op is being applied
    if not isinstance(node.op.core_op, Cholesky):
        return None

    [matrix] = node.inputs
    matrix_owner = matrix.owner

    # Check for use of pt.diag first
    if (
        matrix_owner
        and isinstance(matrix_owner.op, AllocDiag)
        and AllocDiag.is_offset_zero(matrix_owner)
    ):
        diag_values = matrix_owner.inputs[0]
        # pt.diag only *builds* a diagonal matrix from a 1-D input; a 2-D input
        # makes it extract a diagonal instead and higher ranks are rejected, so
        # rebuilding through pt.diag is only valid for 1-D diagonal values.
        if diag_values.type.ndim != 1:
            return None
        return [pt.diag(diag_values**0.5)]

    # Check if the input is an elemwise multiply with identity matrix -- this
    # also results in a diagonal matrix
    inputs_or_none = _find_diag_from_eye_mul(matrix)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    [non_eye_input] = non_eye_inputs

    # Now, we can simply return the matrix consisting of sqrt values of the
    # original diagonal elements. For a matrix, we have to first extract the
    # diagonal (non-zero values) and then only use those
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
        if eye_input.type.ndim > 2:
            # Re-insert an axis so the extracted diagonal broadcasts against
            # the batched eye along the last dimension
            non_eye_input = pt.shape_padaxis(non_eye_input, -2)

    return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

+103
Original file line numberDiff line numberDiff line change
@@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite():
803803
atol=1e-3 if config.floatX == "float32" else 1e-8,
804804
rtol=1e-3 if config.floatX == "float32" else 1e-8,
805805
)
806+
807+
808+
def test_cholesky_eye_rewrite():
    """Cholesky of ``pt.eye`` compiles without a Cholesky node and matches NumPy."""
    identity = pt.eye(10)
    chol_graph = pt.linalg.cholesky(identity)
    f_rewritten = function([], chol_graph, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # Rewrite Test: no Cholesky Op may remain in the compiled graph
    assert not any(isinstance(apply.op, Cholesky) for apply in nodes)

    # Value Test: compare against NumPy's decomposition of the same matrix
    expected = np.linalg.cholesky(np.eye(10))
    rewritten_val = f_rewritten()

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected, rewritten_val, atol=tol, rtol=tol)
828+
829+
830+
@pytest.mark.parametrize(
    "shape",
    [(), (7,), (7, 7), (5, 7, 7)],
    ids=["scalar", "vector", "matrix", "batched"],
)
def test_cholesky_diag_from_eye_mul(shape):
    """Cholesky of ``eye * x`` is rewritten away for scalar/vector/matrix/batched x."""
    # Symbolic graph: a diagonal matrix built by multiplying x with an identity
    x = pt.tensor("x", shape=shape)
    diag_matrix = pt.eye(7) * x
    z_cholesky = pt.linalg.cholesky(diag_matrix)

    # REWRITE TEST
    f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(apply.op, Cholesky) for apply in nodes)

    # NUMERIC VALUE TEST
    # np.random.rand() without dimensions returns a plain float, so wrap the
    # draw in an array before casting; for non-empty shapes this is a no-op.
    x_test = np.asarray(np.random.rand(*shape)).astype(config.floatX)
    expected = np.linalg.cholesky(np.eye(7) * x_test)
    rewritten_val = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected, rewritten_val, atol=tol, rtol=tol)
864+
865+
866+
def test_cholesky_diag_from_diag():
    """Cholesky of ``pt.diag(vector)`` is rewritten away and matches NumPy."""
    x = pt.dvector("x")
    x_cholesky = pt.linalg.cholesky(pt.diag(x))

    # REWRITE TEST
    f_rewritten = function([x], x_cholesky, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(apply.op, Cholesky) for apply in nodes)

    # NUMERIC VALUE TEST
    x_test = np.random.rand(10)
    expected = np.linalg.cholesky(np.eye(10) * x_test)
    rewritten_cholesky = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected, rewritten_cholesky, atol=tol, rtol=tol)
889+
890+
891+
def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
    """The sqrt-diag rewrite must leave Cholesky in place for these inputs."""

    def cholesky_survives(x, matrix):
        # Compile cholesky(matrix) and report whether a Cholesky Op remains
        z_cholesky = pt.linalg.cholesky(matrix)
        f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
        nodes = f_rewritten.maker.fgraph.apply_nodes
        return any(isinstance(apply.op, Cholesky) for apply in nodes)

    # Case 1 : y is not a diagonal matrix because of k = -1
    x = pt.tensor("x", shape=(7, 7))
    assert cholesky_survives(x, pt.eye(7, k=-1) * x)

    # Case 2 : eye is degenerate
    x = pt.scalar("x")
    assert cholesky_survives(x, pt.eye(1) * x)

0 commit comments

Comments
 (0)