Skip to content

Commit

Permalink
Fixed not-equal comparison for complex matrices.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikiobraun committed Nov 19, 2020
1 parent 7637950 commit 18ae6e3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/main/java/org/jblas/ComplexDoubleMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -1659,7 +1659,7 @@ public void load(String filename) throws IOException {
gen_overloads('mul', 'rows', 'columns'),
gen_overloads('mmul', 'rows', 'other.columns'),
gen_compare('eq', 'eq'),
gen_compare('ne', 'eq'),
gen_compare('ne', 'ne'),
gen_logical('and', '&'),
gen_logical('or', '|'),
gen_logical('xor', '^'))
Expand Down Expand Up @@ -1903,7 +1903,7 @@ public ComplexDoubleMatrix nei(ComplexDoubleMatrix other, ComplexDoubleMatrix re
ComplexDouble c2 = new ComplexDouble(0.0);

for (int i = 0; i < length; i++)
result.put(i, get(i, c1).eq(other.get(i, c2)) ? 1.0 : 0.0);
result.put(i, get(i, c1).ne(other.get(i, c2)) ? 1.0 : 0.0);
return result;
}

Expand All @@ -1919,7 +1919,7 @@ public ComplexDoubleMatrix nei(ComplexDouble value, ComplexDoubleMatrix result)
ensureResultLength(null, result);
ComplexDouble c = new ComplexDouble(0.0);
for (int i = 0; i < length; i++)
result.put(i, get(i, c).eq(value) ? 1.0 : 0.0);
result.put(i, get(i, c).ne(value) ? 1.0 : 0.0);
return result;
}

Expand Down
6 changes: 3 additions & 3 deletions src/main/java/org/jblas/ComplexFloatMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -1659,7 +1659,7 @@ public void load(String filename) throws IOException {
gen_overloads('mul', 'rows', 'columns'),
gen_overloads('mmul', 'rows', 'other.columns'),
gen_compare('eq', 'eq'),
gen_compare('ne', 'eq'),
gen_compare('ne', 'ne'),
gen_logical('and', '&'),
gen_logical('or', '|'),
gen_logical('xor', '^'))
Expand Down Expand Up @@ -1903,7 +1903,7 @@ public ComplexFloatMatrix nei(ComplexFloatMatrix other, ComplexFloatMatrix resul
ComplexFloat c2 = new ComplexFloat(0.0f);

for (int i = 0; i < length; i++)
result.put(i, get(i, c1).eq(other.get(i, c2)) ? 1.0f : 0.0f);
result.put(i, get(i, c1).ne(other.get(i, c2)) ? 1.0f : 0.0f);
return result;
}

Expand All @@ -1919,7 +1919,7 @@ public ComplexFloatMatrix nei(ComplexFloat value, ComplexFloatMatrix result) {
ensureResultLength(null, result);
ComplexFloat c = new ComplexFloat(0.0f);
for (int i = 0; i < length; i++)
result.put(i, get(i, c).eq(value) ? 1.0f : 0.0f);
result.put(i, get(i, c).ne(value) ? 1.0f : 0.0f);
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
* @author mikio
*/

public class ComplexDoubleMatrixTest {
public class TestComplexDoubleMatrix {

@Test
public void testConstruction() {
Expand All @@ -58,15 +58,25 @@ public void testConstruction() {
for (int i = 0; i < A.rows; i++)
for (int j = 0; j < A.columns; j++)
A.put(i, j, new ComplexDouble(i, j));
//System.out.printf("A = %s\n", A.toString());

//System.out.println(A.mmul(A));


DoubleMatrix R = new DoubleMatrix(3, 3, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0);
A = new ComplexDoubleMatrix(R, R.transpose());
//System.out.println(A);


assertEquals(A.real(), R);
assertEquals(A.imag(), R.transpose());
}

@Test
public void testComparison() {
ComplexDoubleMatrix A = new ComplexDoubleMatrix(2, 1, 1.0, 2.0, 3.0, 4.0);
ComplexDoubleMatrix B = new ComplexDoubleMatrix(2, 1, 1.0, 2.0, 2.0, 4.0);

assertEquals(A.eq(B), new ComplexDoubleMatrix(2, 1, 1.0, 0.0, 0.0, 0.0));
assertEquals(A.ne(B), new ComplexDoubleMatrix(2, 1, 0.0, 0.0, 1.0, 0.0));

assertEquals(A.eq(new ComplexDouble(1.0, 2.0)),
new ComplexDoubleMatrix(2, 1, 1.0, 0.0, 0.0, 0.0));
assertEquals(A.ne(new ComplexDouble(1.0, 2.0)),
new ComplexDoubleMatrix(2, 1, 0.0, 0.0, 1.0, 0.0));
}
}

0 comments on commit 18ae6e3

Please sign in to comment.