diff --git a/tests/testthat/test_inputchecks_newcustomer.R b/tests/testthat/test_inputchecks_newcustomer.R index d36159db..f3a733b4 100644 --- a/tests/testthat/test_inputchecks_newcustomer.R +++ b/tests/testthat/test_inputchecks_newcustomer.R @@ -183,16 +183,21 @@ test_that("newcustomer fits the type of fitted model", { test_that("predict(): Error if other parameters are passed (spending & transactions)", { - # transactions - expect_error(predict(p.cdnow, newdata=newcustomer(12), prediction.end=12), regexp = "No other parameters") - expect_error(predict(p.cdnow, newdata=newcustomer(12), continuous.discount.factor=0.1), regexp = "No other parameters") - expect_error(predict(p.cdnow, newdata=newcustomer(12), predict.spending=TRUE), regexp = "No other parameters") - - # spending - expect_error(predict(gg.cdnow, newdata=newcustomer.spending(), uncertainty="none"), regexp = "No other parameters") - expect_error(predict(gg.cdnow, newdata=newcustomer.spending(), num.boots=12), regexp = "No other parameters") - expect_error(predict(gg.cdnow, newdata=newcustomer.spending(), level=0.8), regexp = "No other parameters") - + for(m in list(p.cdnow, gg.cdnow)){ + if(is(m, "clv.pnbd")){ + nc <- newcustomer(12) + + expect_error(predict(m, newdata=nc, prediction.end=12), regexp = "No other parameters") + expect_error(predict(m, newdata=nc, continuous.discount.factor=0.1), regexp = "No other parameters") + expect_error(predict(m, newdata=nc, predict.spending=TRUE), regexp = "No other parameters") + }else{ + nc <- newcustomer.spending() + } + + expect_error(predict(m, newdata=nc, uncertainty="boots"), regexp = "No other parameters") + expect_error(predict(m, newdata=nc, num.boots=12), regexp = "No other parameters") + expect_error(predict(m, newdata=nc, level=0.8), regexp = "No other parameters") + } }) test_that("predict vs newcustomer: dyn/static cov data names are not the same as parameters", { diff --git a/tests/testthat/test_runability_bootstrapping.R b/tests/testthat/test_runability_bootstrapping.R index 481fa049..0644498c 100644 --- a/tests/testthat/test_runability_bootstrapping.R +++ b/tests/testthat/test_runability_bootstrapping.R @@ -262,14 +262,14 @@ for(clv.fitted in list( } -# predict(boots) works on all model specifications ----------------------------- -# This also includes testing clv.bootstrapped.apply because it is used under the hood +# predict(uncertainty=boots) works on all model specifications ----------------------------- +# This also includes testing `clv.bootstrapped.apply` because it is used under the hood # - fit with correlation # - constrained params # - regularization # - combinations -test_that("predict(boots) works on all model specifications", { +test_that("predict(uncertainty=boots) works on all model specifications", { fn.predict.boots <- function(clv.fitted){ expect_warning(predict(clv.fitted, uncertainty='boots', num.boots=2, predict.spending=TRUE, verbose=FALSE), regexp = 'recommended to run') } @@ -307,3 +307,53 @@ test_that("predict(boots) works on all model specifications", { }) + + +# predict(uncertainty=boots) works with various inputs ------------------------------------ + +test_that("predict(uncertainty=boots) works with predict.spending, newdata, prediction.end", { + + p.cdnow <- fit.cdnow(optimx.args = optimx.args.NM) + + fn.predict.boots <- function(predict.spending=TRUE, newdata=NULL, prediction.end=NULL){ + expect_warning(dt.pred <- predict( + p.cdnow, + verbose=FALSE, + uncertainty='boots', + num.boots=2, + newdata=newdata, + prediction.end=prediction.end, + predict.spending=predict.spending + ), regexp = "recommended to run") + return(dt.pred) + } + + # predict.spending + fn.predict.boots(predict.spending = TRUE) + fn.predict.boots(predict.spending = FALSE) + fn.predict.boots(predict.spending = gg) + fn.predict.boots(predict.spending = fit.cdnow(model = gg)) + + # newdata + clv.apparel.nocov <- fct.helper.create.clvdata.apparel.nocov() + dt.pred <- fn.predict.boots(newdata=clv.apparel.nocov) + # really did predict for the apparel dataset and not the cdnow + expect_true(dt.pred[, .N] == nobs(clv.apparel.nocov)) + + # prediction.end + clv.cdnow.noholdout <- fct.helper.create.clvdata.cdnow(estimation.split = NULL) + + # with holdout, no prediction.end is required + fn.predict.boots(prediction.end=NULL) + # with holdout, can also with prediction.end + fn.predict.boots(prediction.end=10) + + # without holdout, prediction.end is required + expect_error( + predict(p.cdnow, uncertainty='boots', newdata=clv.cdnow.noholdout), + regexp = "Cannot predict without prediction.end" + ) + # without holdout, works if prediction.end is given + fn.predict.boots(newdata=clv.cdnow.noholdout, prediction.end=10) + +})