Skip to content

Adapt EigenDecomposition implementation #2038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.CholeskyDecomposition;
Expand All @@ -40,11 +41,13 @@
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.FastMath;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.DenseBlockFactory;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
Expand Down Expand Up @@ -74,6 +77,9 @@
private static final Log LOG = LogFactory.getLog(LibCommonsMath.class.getName());
private static final double RELATIVE_SYMMETRY_THRESHOLD = 1e-6;
private static final double EIGEN_LAMBDA = 1e-8;
// Machine epsilon 2^-53
private static final double EIGEN_MACHINE_EPS = 0x1.0p-53;
private static final int EIGEN_MAX_ITER = 30;

private LibCommonsMath() {
//prevent instantiation via private constructor
Expand Down Expand Up @@ -141,11 +147,13 @@
else if (opcode.equals("lu"))
return computeLU(in);
else if (opcode.equals("eigen"))
return computeEigen(in);
return computeEigen(in);
else if (opcode.equals("eigen_lanczos"))
return computeEigenLanczos(in, threads, seed);
else if (opcode.equals("eigen_qr"))
return computeEigenQR(in, threads);
else if (opcode.equals("eigen_symm"))
return computeEigenDecompositionSymm(in, threads);
else if (opcode.equals("svd"))
return computeSvd(in);
else if (opcode.equals("fft"))
Expand Down Expand Up @@ -209,6 +217,350 @@

return DataConverter.convertToMatrixBlock(solutionMatrix);
}

/**
* Computes the eigen decomposition of a symmetric matrix using the Implicit QL Algorithm (Dubrulle et al., 1971).
*
* @param in The input matrix to compute the eigen decomposition on.
* @param threads The number of threads to use for computation.
* @return An array of MatrixBlock objects containing the real eigen values and eigen vectors.
*/
public static MatrixBlock[] computeEigenDecompositionSymm(MatrixBlock in, int threads) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would like to know where the time is spend inside this function call.

is it in transformToTridiagonal, getQ or findEigenVectors?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

findEigenVectors seems to take ~60-70% of the time.
That performEigenDecomposition is a wrapper for the do-while loop and most of its execution is taken by the matrix updates at the end of each iteration

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great this is something that is easy to fix, simply do not use the get and set method for the dense block, but directly on the underlying linearized array assign the cells.

This should immediately reduce the execution time by at least ~50% of that part.

if ( in.getNumRows() != in.getNumColumns() ) {
throw new DMLRuntimeException("Eigen Decomposition can only be done on a square and symmetric matrix. "
+ "Input matrix is rectangular (rows=" + in.getNumRows() + ", cols="+ in.getNumColumns() +")");

Check warning on line 231 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L230-L231

Added lines #L230 - L231 were not covered by tests
}

final double[] mainDiag = new double[in.rlen];
final double[] secDiag = new double[in.rlen - 1];

final MatrixBlock householderVectors = transformToTridiagonal(in, mainDiag, secDiag, threads);

// TODO: Consider using sparse blocks
final double[] hv = householderVectors.getDenseBlockValues();
MatrixBlock houseHolderMatrix = getQ(hv, mainDiag, secDiag);

MatrixBlock[] evResult = findEigenVectors(mainDiag, secDiag, houseHolderMatrix, EIGEN_MAX_ITER, threads);

MatrixBlock realEigenValues = evResult[0];
MatrixBlock eigenVectors = evResult[1];

eigenVectors = LibMatrixReorg.transposeInPlace(eigenVectors, threads);

return new MatrixBlock[] {realEigenValues, eigenVectors};
}

private static MatrixBlock transformToTridiagonal(MatrixBlock matrix, double[] main, double[] secondary,
int threads) {
final int m = matrix.rlen;

MatrixBlock householderVectors = matrix.extractTriangular(new MatrixBlock(m, m, false), false, true, true);
if(householderVectors.isInSparseFormat()) {
householderVectors.sparseToDense(threads);

Check warning on line 259 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L259

Added line #L259 was not covered by tests
}

final double[] hv = householderVectors.getDenseBlockValues();

for(int k = 0; k < m - 1; k++) {
final int rowKp1 = k * m + k + 1;
final int rowK = k * m;

// zero-out a row and a column simultaneously
main[k] = hv[rowK + k];

double xNormSqr = 0;
for(int j = k + 1; j < m; ++j) {
final double c = hv[k * m + j];
xNormSqr += c * c;
}

final double a = (hv[rowKp1] > 0) ? -FastMath.sqrt(xNormSqr) : FastMath.sqrt(xNormSqr);

secondary[k] = a;

if(a != 0.0) {
// apply Householder transform from left and right simultaneously
hv[rowKp1] -= a;

final double beta = -1 / (a * hv[rowKp1]);
double gamma = 0;

// compute a = beta A v, where v is the Householder vector
// this loop is written in such a way
// 1) only the upper triangular part of the matrix is accessed
// 2) access is cache-friendly for a matrix stored in rows
final double[] z = new double[m];
for(int i = k + 1; i < m; ++i) {
final double hKI = hv[rowK + i];
double zI = hv[i * m + i] * hKI;
for(int j = i + 1; j < m; ++j) {
final double hIJ = hv[i * m + j];
zI += hIJ * hv[k * m + j];
z[j] += hIJ * hKI;
}
z[i] = beta * (z[i] + zI);

gamma += z[i] * hv[rowK + i];
}

gamma *= beta / 2;

// compute z = z - gamma v
for(int i = k + 1; i < m; ++i) {
z[i] -= gamma * hv[rowK + i];
}

// update matrix: A = A - v zT - z vT
// only the upper triangular part of the matrix is updated
for(int i = k + 1; i < m; ++i) {
final double hki = hv[rowK + i];
for(int j = i; j < m; ++j) {
final double hkj = hv[rowK + j];
hv[i * m + j] -= (hki * z[j] + z[i] * hkj);
}
}
}
}

main[m - 1] = hv[(m - 1) * m + m - 1];

return householderVectors;
}

/**
* Computes the orthogonal matrix Q using Householder transforms. The matrix Q is built by applying Householder
* transforms to the input vectors.
*
* @param hv The input vector containing the Householder vectors.
* @param main The main diagonal of the matrix.
* @param secondary The secondary diagonal of the matrix.
* @return The orthogonal matrix Q.
*/
public static MatrixBlock getQ(final double[] hv, double[] main, double[] secondary) {
final int m = main.length;

DenseBlock qaB = DenseBlockFactory.createDenseBlock(m, m);

double[] qaV = qaB.valuesAt(0);

// build up first part of the matrix by applying Householder transforms
for(int k = m - 1; k >= 1; --k) {
final int rowK = k * m;
final int rowKm1 = (k - 1) * m;

qaV[rowK + k] = 1.0;
if(hv[rowKm1 + k] != 0.0) {
final double inv = 1.0 / (secondary[k - 1] * hv[rowKm1 + k]);

double beta = 1.0 / secondary[k - 1];

qaV[rowK + k] = 1 + beta * hv[rowKm1 + k];

for(int i = k + 1; i < m; ++i) {
qaV[rowK + i] = beta * hv[rowKm1 + i];
}

for(int j = k + 1; j < m; ++j) {

beta = 0;
for(int i = k + 1; i < m; ++i) {
beta += qaV[m * j + i] * hv[rowKm1 + i];
}
beta *= inv;

qaV[m * j + k] = hv[rowKm1 + k] * beta;

for(int i = k + 1; i < m; ++i) {
qaV[m * j + i] += beta * hv[rowKm1 + i];
}
}
}
}

qaV[0] = 1.0;
MatrixBlock res = new MatrixBlock(m, m, qaB);
// Arbitrarily set non zero count, will set actual count at the end of the algorithm
res.setNonZeros(m * m);

return res;
}

/**
* Finds the eigen vectors corresponding to the given eigen values using the Householder transformation.
*
* @param main The main diagonal of the tridiagonal matrix.
* @param secondary The secondary diagonal of the tridiagonal matrix.
* @param hhMatrix The Householder matrix (Q).
* @param maxIter The maximum number of iterations for convergence.
* @param threads The number of threads to use for computation.
* @return An array of two MatrixBlock objects: eigen values and eigen vectors.
* @throws MaxCountExceededException If the maximum number of iterations is exceeded and convergence fails.
*/
private static MatrixBlock[] findEigenVectors(double[] main, double[] secondary, MatrixBlock hhMatrix,
final int maxIter, int threads) {

double[] hhvalues = hhMatrix.getDenseBlock().valuesAt(0);

final int n = hhMatrix.rlen;

MatrixBlock eigenValues = new MatrixBlock(n, 1, main);

double[] ev = eigenValues.denseBlock.valuesAt(0);
final double[] e = new double[n];

System.arraycopy(secondary, 0, e, 0, n - 1);
e[n - 1] = 0;

// Determine the largest main and secondary value in absolute term.
double maxAbsoluteValue = 0;
for(int i = 0; i < n; i++) {
maxAbsoluteValue = FastMath.max(maxAbsoluteValue, FastMath.abs(ev[i]));
maxAbsoluteValue = FastMath.max(maxAbsoluteValue, FastMath.abs(e[i]));
}

// Make null any main and secondary value too small to be significant
if(maxAbsoluteValue != 0) {
for(int i = 0; i < n; i++) {
if(FastMath.abs(ev[i]) <= EIGEN_MACHINE_EPS * maxAbsoluteValue && ev[i] != 0.0) {
ev[i] = 0;

Check warning on line 425 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L425

Added line #L425 was not covered by tests
}
if(FastMath.abs(e[i]) <= EIGEN_MACHINE_EPS * maxAbsoluteValue && e[i] != 0.0) {
e[i] = 0;

Check warning on line 428 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L428

Added line #L428 was not covered by tests
}
}
}

for(int j = 0; j < n; j++) {
int its = 0;
int m = -1;
while(m != j) {
for(m = j; m < n - 1; m++) {
final double delta = FastMath.abs(ev[m]) + FastMath.abs(ev[m + 1]);
if(FastMath.abs(e[m]) + delta == delta) {
break;
}
}
if(m == j)
break;
if(its == maxIter)
throw new MaxCountExceededException(LocalizedFormats.CONVERGENCE_FAILED, maxIter);

Check warning on line 446 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L446

Added line #L446 was not covered by tests

its++;
formMatrixShift(e, ev, hhvalues, j, m);
}

}

// Sort eigen values and vectors in decreasing order
for(int i = 0; i < n; i++) {
int k = i;
double p = ev[i];
for(int j = i + 1; j < n; j++) {
if(ev[j] < p) {
k = j;
p = ev[j];
}
}
if(k != i) {
ev[k] = ev[i];
ev[i] = p;
for(int j = 0; j < n; j++) {
p = hhvalues[i * n + j];
hhvalues[i * n + j] = hhvalues[k * n + j];
hhvalues[k * n + j] = p;
}
}
}

// Determine the largest eigen value in absolute term.
maxAbsoluteValue = 0;
for(int i = 0; i < n; i++) {
maxAbsoluteValue = FastMath.max(maxAbsoluteValue, FastMath.abs(ev[i]));
}
// Make null any eigen value too small to be significant
int zeros = 0;
if(maxAbsoluteValue != 0.0) {
for(int i = 0; i < n; i++) {
if(FastMath.abs(ev[i]) < EIGEN_MACHINE_EPS * maxAbsoluteValue) {
ev[i] = 0;
zeros++;

Check warning on line 486 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L485-L486

Added lines #L485 - L486 were not covered by tests
}
}
}

eigenValues.setNonZeros(n - zeros);
hhMatrix.setNonZeros(hhMatrix.denseBlock.countNonZeros());

return new MatrixBlock[] {eigenValues, hhMatrix};
}

/**
* Performs a matrix shift operation on the given arrays. Implements 'imtqll' procedure (Dubrulle et al., 1971).
*
* @param e The array to store eigenvalues.
* @param ev The array (dense block) to store eigenvectors.
* @param hhValues The array of Householder vectors.
* @param j The starting index for the matrix shift operation.
* @param m The ending index for the matrix shift operation.
*/
private static void formMatrixShift(double[] e, double[] ev, double[] hhValues, int j, int m) {
final int n = e.length;

double q = (ev[j + 1] - ev[j]) / (2 * e[j]);
double t = FastMath.sqrt(1 + q * q);
if(q < 0.0) {
q = ev[m] - ev[j] + e[j] / (q - t);
}
else {
q = ev[m] - ev[j] + e[j] / (q + t);
}

double u = 0.0;
double s = 1.0;
double c = 1.0;
int i;
for(i = m - 1; i >= j; i--) {
double p = s * e[i];
double h = c * e[i];
if(FastMath.abs(p) >= FastMath.abs(q)) {
c = q / p;
t = FastMath.sqrt(c * c + 1.0);
e[i + 1] = p * t;
s = 1.0 / t;
c *= s;
}
else {
s = p / q;
t = FastMath.sqrt(s * s + 1.0);
e[i + 1] = q * t;
c = 1.0 / t;
s *= c;
}
if(e[i + 1] == 0.0) {
ev[i + 1] -= u;
e[m] = 0.0;
break;

Check warning on line 542 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L540-L542

Added lines #L540 - L542 were not covered by tests
}
q = ev[i + 1] - u;
t = (ev[i] - q) * s + 2.0 * c * h;
u = s * t;
ev[i + 1] = q + u;
q = c * t - h;

for(int ia = 0; ia < n; ++ia) {
p = hhValues[(i + 1) * n + ia];
hhValues[(i + 1) * n + ia] = s * hhValues[i * n + ia] + c * p;
hhValues[i * n + ia] = c * hhValues[i * n + ia] - s * p;
}
}

if(t == 0.0 && i >= j)
return;

Check warning on line 558 in src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java#L558

Added line #L558 was not covered by tests

ev[j] -= u;
e[j] = q;
e[m] = 0.0;
}

/**
* Function to perform QR decomposition on a given matrix.
Expand Down
Loading
Loading