Skip to content

Commit 918b135

Browse files
Move proximal operators to templates
- Move proximal operators to template functions to reduce code duplication - Write Rcpp glue functions to dispatch on type of X - Update tests
1 parent 98b4671 commit 918b135

File tree

4 files changed

+132
-144
lines changed

4 files changed

+132
-144
lines changed

src/clustRviz.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,36 @@ Rcpp::List TroutClusteringCPP(const Eigen::MatrixXcd& X,
254254

255255
return solver.build_return_object();
256256
}
257+
258+
// [[Rcpp::export(rng = false)]]
259+
SEXP matrix_row_prox(SEXP Xsexp,
260+
double lambda,
261+
const Eigen::VectorXd& weights,
262+
bool l1 = true){
263+
264+
switch(TYPEOF(Xsexp)){
265+
case REALSXP: return Rcpp::wrap(MatrixRowProx<double>(Rcpp::as<Eigen::MatrixXd>(Xsexp), lambda, weights, l1));
266+
case CPLXSXP: return Rcpp::wrap(MatrixRowProx<std::complex<double> >(Rcpp::as<Eigen::MatrixXcd>(Xsexp), lambda, weights, l1));
267+
default: Rcpp::stop("Unsupported type of X.");
268+
}
269+
270+
// Should not trigger but appease compiler...
271+
return R_NilValue;
272+
};
273+
274+
// [[Rcpp::export(rng = false)]]
275+
SEXP matrix_col_prox(SEXP Xsexp,
276+
double lambda,
277+
const Eigen::VectorXd& weights,
278+
bool l1 = true){
279+
280+
switch(TYPEOF(Xsexp)){
281+
case REALSXP: return Rcpp::wrap(MatrixColProx<double>(Rcpp::as<Eigen::MatrixXd>(Xsexp), lambda, weights, l1));
282+
case CPLXSXP: return Rcpp::wrap(MatrixColProx<std::complex<double> >(Rcpp::as<Eigen::MatrixXcd>(Xsexp), lambda, weights, l1));
283+
default: Rcpp::stop("Unsupported type of X.");
284+
}
285+
286+
// Should not trigger but appease compiler...
287+
return R_NilValue;
288+
};
289+

src/clustRviz_base.h

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#define CLUSTRVIZ_STATUS_WIDTH_CHECK 20 // Every 20 status updates * 0.1s => every 2s
1010
#define CLUSTRVIZ_DEFAULT_STOP_PRECISION 1e-10 //Stop when cellwise diff between iters < val
1111

12+
// Prototypes - utils.cpp
13+
double soft_thresh(double, double);
14+
std::complex<double> soft_thresh(const std::complex<double>, double);
15+
1216
// Helper to determine if STL set contains an element
1317
//
1418
// In general, this is not efficient because one wants to do something
@@ -24,20 +28,68 @@ double scaled_squared_norm(const Eigen::MatrixBase<DataType>& X){
2428
return X.squaredNorm() / X.size();
2529
}
2630

31+
template <typename DataType>
32+
Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> MatrixRowProx(const Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic>& X,
33+
double lambda,
34+
const Eigen::VectorXd& weights,
35+
bool l1 = true){
36+
Eigen::Index n = X.rows();
37+
Eigen::Index p = X.cols();
38+
39+
Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> V(n, p);
40+
41+
if(l1){
42+
for(Eigen::Index i = 0; i < n; i++){
43+
for(Eigen::Index j = 0; j < p; j++){
44+
V(i, j) = soft_thresh(X(i, j), lambda * weights(i));
45+
}
46+
}
47+
} else {
48+
for(Eigen::Index i = 0; i < n; i++){
49+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> X_i = X.row(i);
50+
double scale_factor = 1 - lambda * weights(i) / X_i.norm();
51+
52+
if(scale_factor > 0){
53+
V.row(i) = X_i * scale_factor;
54+
} else {
55+
V.row(i).setZero();
56+
}
57+
}
58+
}
59+
60+
return V;
61+
}
62+
63+
template <typename DataType>
64+
Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> MatrixColProx(const Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic>& X,
65+
double lambda,
66+
const Eigen::VectorXd& weights,
67+
bool l1 = true){
68+
Eigen::Index n = X.rows();
69+
Eigen::Index p = X.cols();
70+
71+
Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic> V(n, p);
72+
73+
if(l1){
74+
for(Eigen::Index i = 0; i < n; i++){
75+
for(Eigen::Index j = 0; j < p; j++){
76+
V(i, j) = soft_thresh(X(i, j), lambda * weights(j));
77+
}
78+
}
79+
} else {
80+
for(Eigen::Index j = 0; j < p; j++){
81+
Eigen::Matrix<DataType, Eigen::Dynamic, 1> X_j = X.col(j);
82+
double scale_factor = 1 - lambda * weights(j) / X_j.norm();
83+
84+
if(scale_factor > 0){
85+
V.col(j) = X_j * scale_factor;
86+
} else {
87+
V.col(j).setZero();
88+
}
89+
}
90+
}
91+
92+
return V;
93+
}
2794

28-
// Prototypes - utils.cpp
29-
Eigen::MatrixXd MatrixRowProx(const Eigen::MatrixXd&,
30-
double,
31-
const Eigen::VectorXd&,
32-
bool);
33-
34-
Eigen::MatrixXcd MatrixRowProx(const Eigen::MatrixXcd&,
35-
double,
36-
const Eigen::VectorXd&,
37-
bool);
38-
39-
Eigen::MatrixXd MatrixColProx(const Eigen::MatrixXd&,
40-
double,
41-
const Eigen::VectorXd&,
42-
bool);
4395
#endif

src/utils.cpp

Lines changed: 5 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -33,106 +33,6 @@ std::complex<double> soft_thresh(const std::complex<double> x, double lambda){
3333
}
3434
}
3535

36-
// Apply a row-wise prox operator (with weights) to a matrix
37-
// [[Rcpp::export(rng = false)]]
38-
Eigen::MatrixXd MatrixRowProx(const Eigen::MatrixXd& X,
39-
double lambda,
40-
const Eigen::VectorXd& weights,
41-
bool l1 = true){
42-
Eigen::Index n = X.rows();
43-
Eigen::Index p = X.cols();
44-
45-
Eigen::MatrixXd V(n, p);
46-
47-
if(l1){
48-
for(Eigen::Index i = 0; i < n; i++){
49-
for(Eigen::Index j = 0; j < p; j++){
50-
V(i, j) = soft_thresh(X(i, j), lambda * weights(i));
51-
}
52-
}
53-
} else {
54-
for(Eigen::Index i = 0; i < n; i++){
55-
Eigen::VectorXd X_i = X.row(i);
56-
double scale_factor = 1 - lambda * weights(i) / X_i.norm();
57-
58-
if(scale_factor > 0){
59-
V.row(i) = X_i * scale_factor;
60-
} else {
61-
V.row(i).setZero();
62-
}
63-
}
64-
}
65-
66-
return V;
67-
}
68-
69-
// This is the same as the real case - is there a way to do this with overloading / templates
70-
// while also allowing Rcpp::export?
71-
Eigen::MatrixXcd MatrixRowProx(const Eigen::MatrixXcd& X,
72-
double lambda,
73-
const Eigen::VectorXd& weights,
74-
bool l1 = true){
75-
76-
Eigen::Index n = X.rows();
77-
Eigen::Index p = X.cols();
78-
79-
Eigen::MatrixXcd V(n, p);
80-
81-
if(l1){
82-
for(Eigen::Index i = 0; i < n; i++){
83-
for(Eigen::Index j = 0; j < p; j++){
84-
V(i, j) = soft_thresh(X(i, j), lambda * weights(i));
85-
}
86-
}
87-
} else {
88-
for(Eigen::Index i = 0; i < n; i++){
89-
Eigen::VectorXcd X_i = X.row(i);
90-
double scale_factor = 1 - lambda * weights(i) / X_i.norm();
91-
92-
if(scale_factor > 0){
93-
V.row(i) = X_i * scale_factor;
94-
} else {
95-
V.row(i).setZero();
96-
}
97-
}
98-
}
99-
100-
return V;
101-
}
102-
103-
// Apply a col-wise prox operator (with weights) to a matrix
104-
// [[Rcpp::export(rng = false)]]
105-
Eigen::MatrixXd MatrixColProx(const Eigen::MatrixXd& X,
106-
double lambda,
107-
const Eigen::VectorXd& weights,
108-
bool l1 = true){
109-
Eigen::Index n = X.rows();
110-
Eigen::Index p = X.cols();
111-
112-
Eigen::MatrixXd V(n, p);
113-
114-
if(l1){
115-
for(Eigen::Index i = 0; i < n; i++){
116-
for(Eigen::Index j = 0; j < p; j++){
117-
V(i, j) = soft_thresh(X(i, j), lambda * weights(j));
118-
}
119-
}
120-
} else {
121-
for(Eigen::Index j = 0; j < p; j++){
122-
Eigen::VectorXd X_j = X.col(j);
123-
double scale_factor = 1 - lambda * weights(j) / X_j.norm();
124-
125-
if(scale_factor > 0){
126-
V.col(j) = X_j * scale_factor;
127-
} else {
128-
V.col(j).setZero();
129-
}
130-
}
131-
}
132-
133-
return V;
134-
}
135-
13636
// Some basic cheap checks that a weight
13737
// matrix can lead to a connected graph
13838
//
@@ -308,8 +208,10 @@ Rcpp::NumericVector trout_dist(const Eigen::MatrixXcd& X){
308208
Eigen::Index ix = 0;
309209
for(Eigen::Index i = 0; i < n; i++){
310210
for(Eigen::Index j = 0; j < i; j++){
311-
312-
distances(ix) = std::sqrt((X.row(i) - align_phase_v(X.row(j), X.row(i))).squaredNorm());
211+
// IMPORTANT: We have to transpose the result of align_phase_v to a _row vector_
212+
// or else it doesn't align with X.row() and silently discards all but the first
213+
// element
214+
distances(ix) = std::sqrt((X.row(i) - align_phase_v(X.row(j), X.row(i)).transpose()).squaredNorm());
313215
ix++; // FIXME - Would be better to calculate this as a function of (i, j) explicitly
314216
}
315217
}
@@ -322,3 +224,4 @@ Rcpp::NumericVector trout_dist(const Eigen::MatrixXcd& X){
322224
distances.attr("class") = "dist";
323225

324226
return distances;
227+
}

tests/testthat/test_matrix_prox.R

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,66 @@ test_that("L1 matrix prox works", {
77

88
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
99

10-
MatrixRowProx <- clustRviz:::MatrixRowProx
10+
matrix_row_prox <- clustRviz:::matrix_row_prox
1111
weights <- rep(1, n)
1212

13-
expect_equal(X, MatrixRowProx(X, lambda = 0, weights = weights, l1 = TRUE))
14-
expect_equal(abs(X) + 4, MatrixRowProx(abs(X) + 5, lambda = 1, weights = weights, l1 = TRUE))
15-
expect_equal(-abs(X) - 4, MatrixRowProx(-abs(X) - 5, lambda = 1, weights = weights, l1 = TRUE))
13+
expect_equal(X, matrix_row_prox(X, lambda = 0, weights = weights, l1 = TRUE))
14+
expect_equal(abs(X) + 4, matrix_row_prox(abs(X) + 5, lambda = 1, weights = weights, l1 = TRUE))
15+
expect_equal(-abs(X) - 4, matrix_row_prox(-abs(X) - 5, lambda = 1, weights = weights, l1 = TRUE))
1616

1717
## Now we check that weights work
1818
X <- matrix(1:25, nrow = 25, ncol = 1)
1919
weights <- 1:25
2020
expect_equal(matrix(0, nrow = 25, ncol = 1),
21-
MatrixRowProx(X, lambda = 1, weights = weights, l1 = TRUE))
21+
matrix_row_prox(X, lambda = 1, weights = weights, l1 = TRUE))
2222

2323
X <- matrix(5, nrow = 6, ncol = 1)
2424
weights <- seq(0, 5)
2525
expect_equal(matrix(5 - weights, nrow = 6, ncol = 1),
26-
MatrixRowProx(X, lambda = 1, weights = weights, l1 = TRUE))
27-
26+
matrix_row_prox(X, lambda = 1, weights = weights, l1 = TRUE))
27+
2828
#Now check matrix_col_prox against row prox
29-
MatrixColProx <- clustRviz:::MatrixColProx
30-
expect_equal(t(MatrixColProx(t(X), lambda = 1, weights = weights, l1 = TRUE)),
31-
MatrixRowProx(X, lambda = 1, weights = weights, l1 = TRUE))
29+
matrix_col_prox <- clustRviz:::matrix_col_prox
30+
expect_equal(t(matrix_col_prox(t(X), lambda = 1, weights = weights, l1 = TRUE)),
31+
matrix_row_prox(X, lambda = 1, weights = weights, l1 = TRUE))
3232
})
3333

3434
test_that("L2 prox works", {
3535
set.seed(125)
36-
MatrixRowProx <- clustRviz:::MatrixRowProx
36+
matrix_row_prox <- clustRviz:::matrix_row_prox
3737
num_unique_cols <- clustRviz:::num_unique_cols
3838
n <- 25
3939

4040
## If X has a single column, same as L1 prox
4141
X <- matrix(rnorm(n, sd = 3), ncol = 1)
4242
weights <- rexp(n)
4343

44-
expect_equal(MatrixRowProx(X, lambda = 1, weights = weights, l1 = TRUE),
45-
MatrixRowProx(X, lambda = 1, weights = weights, l1 = FALSE))
44+
expect_equal(matrix_row_prox(X, lambda = 1, weights = weights, l1 = TRUE),
45+
matrix_row_prox(X, lambda = 1, weights = weights, l1 = FALSE))
4646

4747
p <- 5
4848
X <- matrix(1, nrow = n, ncol = p)
4949
weights <- seq(0, 5, length.out = 25)
5050

51-
expect_equal(1, num_unique_cols(MatrixRowProx(X, lambda = 1, weights = weights, l1 = FALSE)))
51+
expect_equal(1, num_unique_cols(matrix_row_prox(X, lambda = 1, weights = weights, l1 = FALSE)))
5252

5353
y <- matrix(c(3, 4), nrow = 1)
5454

55-
expect_equal(MatrixRowProx(y, 1, 1, l1 = FALSE), y * (1 - 1/5))
56-
expect_equal(MatrixRowProx(y, 1, 3, l1 = FALSE), y * (1 - 3/5))
57-
expect_equal(MatrixRowProx(y, 2, 1, l1 = FALSE), y * (1 - 2/5))
58-
expect_equal(MatrixRowProx(y, 2, 3, l1 = FALSE), y * 0)
55+
expect_equal(matrix_row_prox(y, 1, 1, l1 = FALSE), y * (1 - 1/5))
56+
expect_equal(matrix_row_prox(y, 1, 3, l1 = FALSE), y * (1 - 3/5))
57+
expect_equal(matrix_row_prox(y, 2, 1, l1 = FALSE), y * (1 - 2/5))
58+
expect_equal(matrix_row_prox(y, 2, 3, l1 = FALSE), y * 0)
5959

6060
y <- -1 * y
61-
expect_equal(MatrixRowProx(y, 1, 1, l1 = FALSE), y * (1 - 1/5))
62-
expect_equal(MatrixRowProx(y, 1, 3, l1 = FALSE), y * (1 - 3/5))
63-
expect_equal(MatrixRowProx(y, 2, 1, l1 = FALSE), y * (1 - 2/5))
64-
expect_equal(MatrixRowProx(y, 2, 3, l1 = FALSE), y * 0)
61+
expect_equal(matrix_row_prox(y, 1, 1, l1 = FALSE), y * (1 - 1/5))
62+
expect_equal(matrix_row_prox(y, 1, 3, l1 = FALSE), y * (1 - 3/5))
63+
expect_equal(matrix_row_prox(y, 2, 1, l1 = FALSE), y * (1 - 2/5))
64+
expect_equal(matrix_row_prox(y, 2, 3, l1 = FALSE), y * 0)
6565

6666
#Now check matrix_col_prox against row prox
67-
MatrixColProx <- clustRviz:::MatrixColProx
68-
expect_equal(t(MatrixColProx(t(X), lambda = 1, weights = weights, l1 = FALSE)),
69-
MatrixRowProx(X, lambda = 1, weights = weights, l1 = FALSE))
70-
expect_equal(t(MatrixColProx(t(y), lambda = 1, weights = weights, l1 = TRUE)),
71-
MatrixRowProx(y, lambda = 1, weights = weights, l1 = TRUE))
67+
matrix_col_prox <- clustRviz:::matrix_col_prox
68+
expect_equal(t(matrix_col_prox(t(X), lambda = 1, weights = weights, l1 = FALSE)),
69+
matrix_row_prox(X, lambda = 1, weights = weights, l1 = FALSE))
70+
expect_equal(t(matrix_col_prox(t(y), lambda = 1, weights = weights, l1 = TRUE)),
71+
matrix_row_prox(y, lambda = 1, weights = weights, l1 = TRUE))
7272
})

0 commit comments

Comments
 (0)