Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Imports:
foreach,
doParallel,
optimr,
CVXR,
sgd,
nnls,
ggplot2,
Expand Down Expand Up @@ -86,8 +87,9 @@ Collate:
'OnlineSuperLearner.S3.R'
'OnlineSuperLearner.SampleIteratively.R'
'WCC.SGD.Simplex.R'
'WCC.NMBFGS.R'
'WeightedCombinationComputer.R'
'WCC.CVXR.R'
'WCC.NMBFGS.R'
'SMG.Mock.R'
'SummaryMeasureGenerator.R'
'zzz.R'
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ import(optimr)
import(parallel)
import(reshape2)
import(sgd)
importFrom(CVXR,Minimize)
importFrom(CVXR,Problem)
importFrom(CVXR,Variable)
importFrom(CVXR,sum_squares)
importFrom(R.methodsS3,throw)
importFrom(R.oo,equals)
importFrom(R.utils,Arguments)
Expand Down
3 changes: 2 additions & 1 deletion R/OnlineSuperLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#' @include WeightedCombinationComputer.R
#' @include DataCache.R
#' @include WCC.NMBFGS.R
#' @include WCC.CVXR.R
#' @include WCC.SGD.Simplex.R
#' @include CrossValidationRiskCalculator.R
#' @include InterventionParser.R
Expand Down Expand Up @@ -734,7 +735,7 @@ OnlineSuperLearner <- R6Class ("OnlineSuperLearner",
## Variables
## =========
## The R.cv score of the current fit
default_wcc = WCC.NMBFGS,
default_wcc = WCC.CVXR,

# The random_variables to use throughout the osl object
random_variables = NULL,
Expand Down
58 changes: 58 additions & 0 deletions R/WCC.CVXR.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#' WCC.CVXR
#'
#' Weighted-combination computer that finds the convex combination of the
#' candidate learners' predictions minimizing the squared error against the
#' observed outcome, using the CVXR convex-optimization package. The weights
#' are constrained to the probability simplex (non-negative, summing to 1),
#' and each solve is blended into the running weights with an online,
#' observation-count-weighted average.
#'
#' @docType class
#' @importFrom CVXR Variable Problem Minimize sum_squares
#' @include WeightedCombinationComputer.R
#'
#' @section Methods:
#' \describe{
#'   \item{\code{new(weights.initial)}}{
#'     Creates a new computer; \code{weights.initial} is the starting weight
#'     vector (validated by the superclass: entries in [0, 1], summing to 1).
#'   }
#'   \item{\code{get_weights}}{
#'     Active binding (inherited) returning the current weight vector.
#'   }
#'   \item{\code{compute(Z, Y, libraryNames, ...)}}{
#'     Solves the simplex-constrained least-squares problem for the given
#'     prediction matrix \code{Z} and outcome \code{Y}, then updates and
#'     returns the online-averaged weights.
#'   }
#' }
WCC.CVXR <- R6Class("WCC.CVXR",
  inherit = WeightedCombinationComputer,
  private =
    list(),
  active =
    list(),
  public =
    list(
      initialize = function(weights.initial) {
        super$initialize(weights.initial)
      },

      ## Accept `...` for forward compatibility: the superclass invokes
      ## self$compute(Z, Y, libraryNames, ...) from its process() method.
      compute = function(Z, Y, libraryNames, ...) {
        k <- length(self$get_weights)
        alpha <- Variable(k)
        objective <- Minimize(sum_squares(Y - Z %*% alpha))

        ## Constrain alpha to the probability simplex:
        ## \Sigma_k = \{ \alpha \in \RR_+^k : \sum_{k=1}^K \alpha_k = 1 \}
        ## Note: CVXR does not support strict inequalities (it treats `>` as
        ## `>=`), so we state the non-negativity constraint as `>=` explicitly.
        constraints <- list(alpha >= 0, sum(alpha) == 1)

        optimization_problem <- Problem(objective, constraints)
        optimization_solution <- solve(optimization_problem)

        if (optimization_solution$status != 'optimal') {
          warning('The optimization problem was not solved optimally!')
        }

        ## Retrieve the actual alphas from the solved system
        weights <- optimization_solution$getValue(alpha)

        ## Online update of the weights. Note that this is different from
        ## Benkeser2017, but I believe this is better. (We also multiply the
        ## original weights).
        ## We might want to move this to the super class.
        weights <- self$get_current_nobs / (self$get_current_nobs + 1) * self$get_weights +
          1 / (self$get_current_nobs + 1) * as.vector(weights)

        self$set_weights(weights)
        self$get_weights
      }
    )
)
33 changes: 27 additions & 6 deletions R/WeightedCombinationComputer.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ WeightedCombinationComputer <- R6Class("WeightedCombinationComputer",
list(
initialize = function(weights.initial) {
private$weights <- Arguments$getNumerics(weights.initial, c(0, 1))
private$nobs <- 0
sum_of_weights <- sum(private$weights)
if (sum_of_weights != 1) {
throw("The sum of the initial weights, ", sum_of_weights, ", does not equal 1")
Expand All @@ -46,24 +47,44 @@ WeightedCombinationComputer <- R6Class("WeightedCombinationComputer",
}

if (length(private$weights) != length(libraryNames)) {
throw('Not each estimator has exactly one weight: estimators: ', length(libraryNames) ,' weights: ',length(private$weights),'.')
throw(
'Not each estimator has exactly one weight: estimators: ', length(libraryNames) ,
' weights: ',length(private$weights),'.'
)
}

libraryNames <- Arguments$getCharacters(libraryNames)

# Call the subclass
self$compute(Z, Y, libraryNames, ...)
self$increment_nobs
return(self$get_weights)
},

set_weights = function(weights) {
private$weights <- weights
}
),
private =
list(
weights = NULL

),
active =
list(
get_weights = function() {
return(private$weights)
},

get_current_nobs = function() {
private$nobs
},

increment_nobs = function() {
# TODO: should this be +1? or + tau (the size of the testset?)
private$nobs <- private$nobs + 1
}
)

),
private =
list(
nobs = NULL,
weights = NULL
),
)
1 change: 1 addition & 0 deletions inst/bash/install-package-dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ packages <- c(
"magrittr",
"assertthat",
"optimr",
"CVXR",
"nloptr",
"purrr",
"doParallel",
Expand Down
3 changes: 2 additions & 1 deletion man/Data.Static.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions man/WCC.CVXR.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions test.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ RUN apt-get update && apt-get -f install -t unstable --no-install-recommends -y
openssl \
libcurl4-openssl-dev \
curl \
libmpfr-dev \
libgmp-dev \
git \
libxml2-dev \
libssl-dev \
Expand Down
137 changes: 137 additions & 0 deletions tests/testthat/test-WCC.CVXR.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
## Test
##install.packages('CVXR')
##library('CVXR')

### Simple minimal constrained optimization function, for the CVXR package
### Variables minimized over
##k = 10
##n = 100

##true_beta <- rep(1/k, k)
##X <- matrix(rnorm(k * n), n, k)
##y <- X %*% true_beta + rnorm(n, 0, 0.1)

##beta <- Variable(k)
##objective <- Minimize(sum_squares(y - X %*% beta))
##constraints <- list(beta > 0, sum(beta) == 1)
##prob3.1 <- Problem(objective, constraints)
##CVXR_solution3.1 <- solve(prob3.1)
##CVXR_solution3.1$status
##value <- CVXR_solution3.1$getValue(beta)
##value
##sum(value)

context("WCC.CVXR")
# Shared fixtures for every test below. These are module-level globals, so
# the names must stay stable across the whole file.
described.class <- WCC.CVXR
set.seed(12345)
# Number of simulated observations
n <- 1e2
Z <- rnorm(n)

# Candidate "learner" predictions: polynomial features of Z, one per column.
pred <- cbind(Z, Z^2, Z^3, Z^4, Z^5)
# Ground-truth convex combination the optimizer should recover
true_params <- c(.2, .8, 0, 0, 0)

# Outcome = true combination plus small Gaussian noise
Y <- pred %*% true_params + rnorm(n, 0, 0.1)
num_params <- length(true_params)
# Uniform starting weights (sum to 1, as the superclass requires)
initial_weights <- rep(1/num_params, num_params)
libraryNames <- c('a', 'b', 'c', 'd', 'e')

context(" initialize")
#=====================================================================
test_that("it should initialize and set the correct default initial weights", {
  # A freshly constructed computer must report exactly the weights it was given.
  wcc <- described.class$new(weights.initial = initial_weights)
  expect_equal(wcc$get_weights, initial_weights)
})

context(" process")
#=====================================================================
test_that("it should compute the correct convex combination", {
  wcc <- described.class$new(weights.initial = initial_weights)
  wcc$process(pred, Y, libraryNames)

  estimated <- wcc$get_weights
  expect_length(estimated, length(true_params))

  # The fitted weights should approximate the true generating parameters
  expect_true(all(abs(estimated - true_params) < 1e-2))
})

test_that("it should result in a list of weights summing to 1", {
  wcc <- described.class$new(weights.initial = initial_weights)
  wcc$process(pred, Y, libraryNames)

  # The simplex constraint means the returned weights must sum to 1
  weights <- wcc$get_weights
  expect_length(weights, length(true_params))
  expect_equal(sum(weights), 1)
})

test_that("it should result in a list of weights all greater or equal to zero", {
  subject <- described.class$new(weights.initial = initial_weights)
  subject$process(pred, Y, libraryNames)
  expect_length(subject$get_weights, length(true_params))

  res <- subject$get_weights
  # Use the vectorized `&` rather than the scalar `&&`: `&&` on length > 1
  # vectors is an error in R >= 4.3 and previously only inspected the first
  # element, so the original check silently ignored most of the weights.
  # Bounds are inclusive, since 0 and 1 are valid simplex coordinates.
  expect_true(all(res >= 0 & res <= 1))
})

test_that("it should compute the correct convex combination with random initial weights", {
  set.seed(12345)
  # Draw arbitrary starting weights and normalize them onto the simplex
  random_weights <- runif(num_params, 0, 1)
  random_weights <- random_weights / sum(random_weights)

  wcc <- described.class$new(weights.initial = random_weights)
  wcc$process(pred, Y, libraryNames)

  expect_length(wcc$get_weights, length(true_params))

  # Regardless of the starting point, we should recover the true parameters
  expect_true(all(abs(wcc$get_weights - true_params) < 1e-2))
})

# Verifies that a non-'optimal' solver status triggers the warning emitted by
# WCC.CVXR$compute. `solve` is replaced by a stub so no real optimization runs.
test_that("it should give a warning whenever the optimization_solution is not optimal", {
subject <- described.class$new(weights.initial = initial_weights)
# Stub standing in for CVXR's solve(): reports an infeasible status and
# returns the known true parameters from getValue(), mimicking the shape of
# a CVXR solution object as far as compute() consumes it.
mock_solve <- function(...) {
list(
status = 'infeasible',
getValue = function(...) { true_params }
)
}

# NOTE(review): with_mock is deprecated in testthat 3e and mocking a
# non-package function like solve() may stop working there — consider
# migrating to local_mocked_bindings(); verify against the testthat version
# pinned by this package.
with_mock(solve = mock_solve,
expect_warning(
subject$process(pred, Y, libraryNames),
"The optimization problem was not solved optimally!",
fixed=TRUE
)
)

})

test_that("it should take into account the earlier predicted weights", {
  # Local, noise-free fixture (shadows the file-level globals on purpose)
  set.seed(12345)
  n <- 10
  Z <- rnorm(n)
  pred <- cbind(Z, Z^2, Z^3, Z^4, Z^5)
  true_params <- c(.2, .8, 0, 0, 0)
  Y <- pred %*% true_params
  num_params <- length(true_params)
  initial_weights <- rep(1/num_params, num_params)
  libraryNames <- c('a', 'b', 'c', 'd', 'e')

  subject <- described.class$new(weights.initial = initial_weights)
  ## Simulate a long history: bump the observation counter 10000 times.
  ## (seq_len is the safe idiom; seq(10000) happens to work but seq() with a
  ## single scalar is a well-known footgun.)
  for (i in seq_len(10000)) {
    subject$increment_nobs
  }
  fake_but_very_important_weights <- c(0, 0, .3, .3, .4)
  subject$set_weights(fake_but_very_important_weights)

  ## This is a very simple test, and we should probably elaborate it. What
  ## we're currently doing is simulating that after 10000 iterations the alphas
  ## are fake_but_very_important_weights. Then we actually run the
  ## optimization, but in essence it should not matter at all (as the fake
  ## weights are 10000 times more important). This is a silly test, but does at
  ## least some checking.
  subject$process(pred, Y, libraryNames)
  difference <- abs(subject$get_weights - fake_but_very_important_weights)
  expect_true(all(difference < 1e-4))
})