diff --git a/R/model-xgboost.R b/R/model-xgboost.R index 583c4ec..38a8c7b 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -141,6 +141,12 @@ build_fit_formula_xgb <- function(parsedmodel) { } else if (objective %in% c("binary:logistic", "reg:logistic")) { assigned <- 1 f <- expr(1 - 1 / (1 + exp(!!f + binomial()$linkfun(!!base_score)))) + } else if (objective %in% c("count:poisson")) { + assigned <- 1 + f <- expr(exp(!!f)) + } else if (objective %in% c("reg:tweedie")) { + assigned <- 1 + f <- expr(0.5 * exp(!!f)) ## I'm not sure why one has to multiply by 0.5, but it works. } if (assigned == 0) { stop("Only objectives 'binary:logistic', 'reg:squarederror', 'reg:logistic', 'binary:logitraw' are supported yet.")