Skip to content

Commit 98b4671

Browse files
Add trout distance
- Add trout distance function - Add trout distance as option to weight schemes
1 parent 549d91d commit 98b4671

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

R/weights.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ dense_rbf_kernel_weights <- function(phi = "auto",
181181

182182
tryCatch(dist.method <- match.arg(dist.method),
183183
error = function(e){
184-
crv_error("Unsupported choice of ", sQuote("weight.dist;"),
184+
crv_error("Unsupported choice of ", sQuote("dist.method;"),
185185
" see the ", sQuote("method"), " argument of ",
186186
sQuote("stats::dist"), " for supported distances.")
187187
})
@@ -192,6 +192,12 @@ dense_rbf_kernel_weights <- function(phi = "auto",
192192
" argument of ", sQuote("stats::dist"), " for details.")
193193
}
194194

195+
if(dist.method == "trout"){
196+
dist_f <- trout_dist
197+
} else {
198+
dist_f <- function(X) dist(X, method = dist.method, p = p)
199+
}
200+
195201
function(X){
196202
user_phi <- (phi != "auto")
197203

@@ -200,7 +206,7 @@ dense_rbf_kernel_weights <- function(phi = "auto",
200206
## necessary...
201207
phi_range <- 10^(seq(-10, 10, length.out = 21))
202208
weight_vars <- vapply(phi_range,
203-
function(phi) var(exp((-1) * phi * (dist(X, method = dist.method, p = p)[TRUE])^2)),
209+
function(phi) var(exp((-1) * phi * (dist_f(X)[TRUE])^2)),
204210
numeric(1))
205211

206212
phi <- phi_range[which.max(weight_vars)]
@@ -214,7 +220,7 @@ dense_rbf_kernel_weights <- function(phi = "auto",
214220
crv_error(sQuote("phi"), " must be positive.")
215221
}
216222

217-
dist_mat <- as.matrix(dist(X, method = dist.method, p = p))
223+
dist_mat <- as.matrix(dist_f(X))
218224
dist_mat <- exp(-1 * phi * dist_mat^2)
219225

220226
check_weight_matrix(dist_mat)

src/utils.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,26 @@ Eigen::MatrixXcd align_phase(const Eigen::MatrixXcd& U,
299299

300300
return V;
301301
}
302+
303+
// [[Rcpp::export(rng = false)]]
304+
Rcpp::NumericVector trout_dist(const Eigen::MatrixXcd& X){
305+
Eigen::Index n = X.rows();
306+
Rcpp::NumericVector distances(n * (n - 1) / 2);
307+
308+
Eigen::Index ix = 0;
309+
for(Eigen::Index i = 0; i < n; i++){
310+
for(Eigen::Index j = 0; j < i; j++){
311+
312+
distances(ix) = std::sqrt((X.row(i) - align_phase_v(X.row(j), X.row(i))).squaredNorm());
313+
ix++; // FIXME - Would be better to calculate this as a function of (i, j) explicitly
314+
}
315+
}
316+
317+
distances.attr("Size") = n;
318+
distances.attr("Diag") = false;
319+
distances.attr("Upper") = false;
320+
distances.attr("method") = "trout";
321+
distances.attr("call") = R_NilValue;
322+
distances.attr("class") = "dist";
323+
324+
return distances;
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
context("Test trout_distance")
2+
3+
test_that("Trout distance doesn't depend on signs", {
4+
trout_dist <- clustRviz:::trout_dist
5+
6+
X <- matrix(c(1, 1, -1, -1), ncol = 2)
7+
expect_equal(as.vector(trout_dist(X)), 0)
8+
9+
X <- matrix(c(1, -1, 1, -1), ncol = 2)
10+
expect_equal(as.vector(trout_dist(X)), 0)
11+
12+
X <- matrix(c(1 + 1i, -1 -1i, 1 + 1i, -1 -1i), ncol = 2)
13+
expect_equal(as.vector(trout_dist(X)), 0)
14+
})
15+
16+
test_that("Trout distance minimizes distance", {
17+
trout_dist <- clustRviz:::trout_dist
18+
set.seed(125)
19+
20+
X1 <- rnorm(25) + (0 + 1i) * rnorm(25)
21+
X2 <- rnorm(25) + (0 + 1i) * rnorm(25)
22+
X <- rbind(X1, X2)
23+
24+
theta_grid <- seq(0, 2 * pi, length.out = 501)
25+
d <- Vectorize(function(theta) sum(Mod(X1 - exp((0 + 1i) * theta) * X2)^2))
26+
min_d <- sqrt(min(d(theta_grid)))
27+
28+
td <- as.vector(trout_dist(X))
29+
expect_lte(td, min_d)
30+
})

0 commit comments

Comments
 (0)