Merge pull request #273 from tidymodels/use-cli
use cli functions
EmilHvitfeldt authored Oct 30, 2024
2 parents 8997aef + 7430b22 commit 8c88b3f
Showing 24 changed files with 85 additions and 104 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -22,6 +22,7 @@ Depends:
     R (>= 3.6),
     recipes (>= 1.1.0.9000)
 Imports:
+    cli,
     lifecycle,
     dplyr,
     generics (>= 0.1.0),
@@ -64,5 +65,5 @@ Config/testthat/edition: 3
 Encoding: UTF-8
 LazyData: true
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
 SystemRequirements: "GNU make"
11 changes: 1 addition & 10 deletions R/aaa.R
@@ -12,16 +12,7 @@ factor_to_text <- function(data, names) {
 
 check_possible_tokenizers <- function(x, dict, call = caller_env(2)) {
   if (!(x %in% dict)) {
-    possible_tokenizers <- glue::glue_collapse(
-      dict,
-      sep = ", ", last = ", or "
-    )
-    rlang::abort(
-      glue(
-        "token should be one of the supported: {possible_tokenizers}"
-      ),
-      call = call
-    )
+    cli::cli_abort("Token should be one of {dict}.", call = call)
   }
 }
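Aside: the new call leans on cli's native vector interpolation, which is what makes the hand-rolled `glue_collapse()` unnecessary. A minimal sketch of the behavior (illustrative output in comments; the values in `dict` are hypothetical, and note that cli's default collapse says "and" where the removed message said "or" — the `{.or}` inline class in recent cli versions would keep the original wording):

    library(cli)

    dict <- c("words", "characters", "ngrams")  # hypothetical tokenizer names

    # Vectors interpolate and collapse automatically, with "and" by default:
    cli_abort("Token should be one of {dict}.")
    #> Error: Token should be one of words, characters, and ngrams.

    # {.or} collapses with "or", matching the wording that was removed:
    cli_abort("Token should be one of {.or {dict}}.")
    #> Error: Token should be one of words, characters, or ngrams.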
9 changes: 4 additions & 5 deletions R/lda.R
@@ -267,12 +267,11 @@ check_lda_character <- function(dat) {
   all_good <- character_ind | factor_ind
 
   if (any(all_good)) {
-    rlang::abort(
-      glue(
+    cli::cli_abort(
+      c(
         "All columns selected for this step should be tokenlists.",
-        "\n",
-        "See https://github.com/tidymodels/textrecipes#breaking-changes",
-        " for more information."
+        "i" = "See {.url https://github.com/tidymodels/textrecipes#breaking-changes}
+        for more information."
       )
     )
   }
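This hunk introduces the bullet pattern used throughout the PR: `cli_abort()` takes a character vector whose unnamed first element is the headline (rendered with `!`) and whose names pick the bullet type. A standalone sketch, with illustrative output:

    library(cli)

    # "i" renders an info bullet, "x" a failure cross, "*" a plain bullet.
    cli_abort(c(
      "All columns selected for this step should be tokenlists.",
      "i" = "See {.url https://github.com/tidymodels/textrecipes#breaking-changes}
             for more information."
    ))
    #> Error:
    #> ! All columns selected for this step should be tokenlists.
    #> i See <https://github.com/tidymodels/textrecipes#breaking-changes>
    #>   for more information.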
10 changes: 4 additions & 6 deletions R/lemma.R
@@ -115,12 +115,10 @@ bake.step_lemma <- function(object, new_data, ...) {
     variable <- new_data[[col_name]]
 
     if (is.null(maybe_get_lemma(variable))) {
-      rlang::abort(
-        glue(
-          "`{col_name}` doesn't have a lemma attribute. ",
-          "Make sure the tokenization step includes lemmatization."
-        )
-      )
+      cli::cli_abort(c(
+        "{.code {col_name}} doesn't have a lemma attribute.",
+        "i" = "Make sure the tokenization step includes lemmatization."
+      ))
     } else {
       lemma_variable <- tokenlist_lemma(variable)
     }
9 changes: 4 additions & 5 deletions R/pos_filter.R
@@ -121,11 +121,10 @@ bake.step_pos_filter <- function(object, new_data, ...) {
     variable <- new_data[[col_name]]
 
     if (is.null(maybe_get_pos(variable))) {
-      rlang::abort(
-        glue(
-          "`{col_name}` doesn't have a pos attribute. ",
-          "Make sure the tokenization step includes ",
-          "part of speech tagging."
+      cli::cli_abort(
+        c(
+          "{.arg {col_name}} doesn't have a pos attribute.",
+          "i" = "Make sure the tokenization step includes part of speech tagging."
         )
       )
     }
4 changes: 2 additions & 2 deletions R/sequence_onehot.R
@@ -85,11 +85,11 @@ step_sequence_onehot <-
            skip = FALSE,
            id = rand_id("sequence_onehot")) {
     if (length(padding) != 1 || !(padding %in% c("pre", "post"))) {
-      rlang::abort("`padding` should be one of: 'pre', 'post'")
+      cli::cli_abort("{.arg padding} should be one of: {.val pre}, {.val post}")
     }
 
     if (length(truncating) != 1 || !(truncating %in% c("pre", "post"))) {
-      rlang::abort("`truncating` should be one of: 'pre', 'post'")
+      cli::cli_abort("{.code truncating} should be {.val pre} or {.val post}.")
     }
 
     add_step(
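The two branches above use slightly different inline classes; for reference, `{.arg}` marks an argument name, `{.code}` literal code, `{.val}` a quoted value, and `{.fn}` a function name. A small sketch of the markup, assuming the default theme (the `padding` value is a made-up bad input):

    library(cli)

    padding <- "center"  # hypothetical invalid input

    # {.arg}/{.val} style their contents; {.val {padding}} interpolates
    # the inner braces first, then styles the result:
    cli_text("{.arg padding} should be one of: {.val pre}, {.val post}")
    cli_text("got {.val {padding}} for {.fn step_sequence_onehot}")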
12 changes: 5 additions & 7 deletions R/text_normalization.R
@@ -127,15 +127,13 @@ bake.step_text_normalization <- function(object, new_data, ...) {
     nfkd = stringi::stri_trans_nfkd,
     nfkc = stringi::stri_trans_nfkc,
     nfkc_casefold = stringi::stri_trans_nfkc_casefold,
-    rlang::abort(
-      glue(
-        "'normalization_form' must be one of",
-        "'nfc', 'nfd', 'nfkd', 'nfkc', or 'nfkc_casefold'",
-        "but was {object$normalization_form}."
-      )
-    )
+    cli::cli_abort(
+      "{.arg normalization_form} must be one of {.val nfc}, {.val nfd},
+      {.val nfkd}, {.val nfkc}, or {.val nfkc_casefold} but was
+      {.val {object$normalization_form}}."
+    )
   )
 
   for (col_name in col_names) {
     new_data[[col_name]] <- normalization_fun(new_data[[col_name]])
     new_data[[col_name]] <- factor(new_data[[col_name]])
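A detail the multi-line message above relies on: cli treats whitespace like HTML, so the newlines and indentation inside a long format string collapse to single spaces and the message is re-wrapped to the console width. A sketch with a stand-in variable:

    library(cli)

    normalization_form <- "nfx"  # stand-in for object$normalization_form

    # The source-level line breaks below do not appear in the output;
    # cli re-wraps the rendered text to the terminal width.
    cli_abort(
      "{.arg normalization_form} must be one of {.val nfc}, {.val nfd},
      {.val nfkd}, {.val nfkc}, or {.val nfkc_casefold} but was
      {.val {normalization_form}}."
    )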
10 changes: 5 additions & 5 deletions R/textfeature.R
@@ -208,14 +208,14 @@ validate_string2num <- function(fun) {
 
   out <- fun(string)
   if (!(is.numeric(out) | is.logical(out))) {
-    rlang::abort(paste0(deparse(substitute(fun)), " must return a numeric."))
+    cli::cli_abort("Function {.fn {fun}} must return a numeric.")
   }
 
   if (length(string) != length(out)) {
-    rlang::abort(paste0(
-      deparse(substitute(fun)),
-      " must return the same length output as its input."
-    ))
+    cli::cli_abort(
+      "{.fn {deparse(substitute(fun))}} must return the same length output as
+      its input."
+    )
   }
 }
 
4 changes: 2 additions & 2 deletions R/tfidf.R
@@ -271,10 +271,10 @@ dtm_to_tfidf <- function(dtm, idf_weights, smooth_idf, norm, sublinear_tf) {
     dtm@x <- 1 + log(dtm@x)
   }
   if (is.character(idf_weights)) {
-    rlang::warn(
+    cli::cli_warn(
       c(
         "Please retrain this recipe with version 0.5.1 or higher.",
-        "A data leakage bug has been fixed for `step_tfidf()`."
+        "i" = "A data leakage bug has been fixed for {.fn step_tfidf}."
       )
     )
     idf_weights <- log(smooth_idf + nrow(dtm) / Matrix::colSums(dtm > 0))
14 changes: 6 additions & 8 deletions R/tokenfilter.R
@@ -102,11 +102,11 @@ step_tokenfilter <-
            id = rand_id("tokenfilter")) {
     if (percentage && (max_times > 1 | max_times < 0 |
       min_times > 1 | min_times < 0)) {
-      rlang::abort(
-        "`max_times` and `min_times` should be in the interval [0, 1]."
+      cli::cli_abort(
+        "{.arg max_times} and {.arg min_times} should be in the interval [0, 1]."
       )
     }
 
     add_step(
       recipe,
       step_tokenfilter_new(
@@ -258,11 +258,9 @@ tokenfilter_fun <- function(data, max_times, min_times, max_tokens,
       names(sort(tf[ids], decreasing = TRUE))
     } else {
       if (max_tokens > sum(ids)) {
-        rlang::warn(
-          glue(
-            "max_tokens was set to '{max_tokens}', ",
-            "but only {sum(ids)} was available and selected."
-          )
-        )
+        cli::cli_warn(
+          "max_tokens was set to {.val {max_tokens}}, but only {sum(ids)} was
+          available and selected."
+        )
         max_tokens <- sum(ids)
       }
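One cli feature the rewritten warning does not use is pluralization: "only {sum(ids)} was available" still reads oddly for counts other than one. A hedged sketch of the `{?}` syntax, which agrees with the closest preceding substitution (the variables are stand-ins):

    library(cli)

    max_tokens <- 100
    n_available <- 3  # stand-in for sum(ids)

    # {?was/were} picks the singular form at 1 and the plural otherwise:
    cli_warn(
      "max_tokens was set to {.val {max_tokens}}, but only {n_available}
      {?was/were} available and selected."
    )
    #> Warning: max_tokens was set to 100, but only 3 were available and selected.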
2 changes: 1 addition & 1 deletion R/tokenize.R
@@ -454,7 +454,7 @@ tokenizer_switch <- function(name, object, data, call = caller_env()) {
     return(res)
   }
 
-  rlang::abort("`engine` argument is not valid.", call = call)
+  cli::cli_abort("The {.arg engine} argument is not valid.", call = call)
 }
 
 #' @rdname required_pkgs.step
13 changes: 6 additions & 7 deletions R/tokenize_bpe.R
@@ -121,8 +121,9 @@ prep.step_tokenize_bpe <- function(x, training, info = NULL, ...) {
 
   bpe_options <- x$options
   if (!is.null(bpe_options$vocab_size)) {
-    rlang::abort(
-      "Please supply the vocabulary size using the `vocabulary_size` argument."
+    cli::cli_abort(
+      "Please supply the vocabulary size using the {.arg vocabulary_size}
+      argument."
     )
   }
   bpe_options$vocab_size <- x$vocabulary_size
@@ -158,11 +159,9 @@ check_bpe_vocab_size <- function(text,
   text_count <- length(text_count)
 
   if (vocabulary_size < text_count) {
-    rlang::abort(
-      glue(
-        "`vocabulary_size` of {vocabulary_size} is too small for column ",
-        "`{column}` which has a unique character count of {text_count}"
-      ),
+    cli::cli_abort(
+      "{.arg vocabulary_size} of {vocabulary_size} is too small for column
+      {.arg {column}} which has a unique character count of {text_count}",
      call = call
     )
   }
15 changes: 7 additions & 8 deletions R/tokenize_sentencepiece.R
@@ -120,8 +120,9 @@ prep.step_tokenize_sentencepiece <- function(x, training, info = NULL, ...) {
 
   sentencepiece_options <- x$options
   if (!is.null(sentencepiece_options$vocab_size)) {
-    rlang::abort(
-      "Please supply the vocabulary size using the `vocabulary_size` argument."
+    cli::cli_abort(
+      "Please supply the vocabulary size using the {.arg vocabulary_size}
+      argument."
     )
   }
   sentencepiece_options$vocab_size <- x$vocabulary_size
@@ -160,12 +161,10 @@ check_sentencepiece_vocab_size <- function(text,
   text_count <- length(text_count)
 
   if (vocabulary_size < text_count) {
-    rlang::abort(
-      glue(
-        "`vocabulary_size` of {vocabulary_size} is too small for column ",
-        "`{column}` which has a unique character count of {text_count}."
-      ),
-      call = call
+    cli::cli_abort(
+      "The {.arg vocabulary_size} of {vocabulary_size} is too small for column {.arg {column}}
+      which has a unique character count of {text_count}.",
+      call = call
     )
   }
 }
22 changes: 11 additions & 11 deletions R/tokenlist.R
@@ -44,10 +44,10 @@ new_tokenlist <- function(tokens = list(), lemma = NULL, pos = NULL,
                           unique_tokens = character()) {
   vec_assert(tokens, list())
   if (!(is.null(lemma) | is.list(lemma))) {
-    rlang::abort("`lemma` must be NULL or a list.")
+    cli::cli_abort("{.arg lemma} must be NULL or a list.")
   }
   if (!(is.null(pos) | is.list(pos))) {
-    rlang::abort("`pos` must be NULL or a list.")
+    cli::cli_abort("{.arg pos} must be NULL or a list.")
   }
   vec_assert(unique_tokens, character())
 
@@ -141,7 +141,7 @@ obj_print_footer.textrecipes_tokenlist <- function(x, ...) {
 # or removes (for keep = FALSE) the words
 tokenlist_filter <- function(x, dict, keep = FALSE) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a tokenlist.")
   }
 
   if (!keep) {
@@ -180,7 +180,7 @@ tokenlist_filter <- function(x, dict, keep = FALSE) {
 
 tokenlist_filter_function <- function(x, fn) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a {.cls tokenlist}.")
   }
 
   tokens <- get_tokens(x)
@@ -210,7 +210,7 @@ tokenlist_filter_function <- function(x, fn) {
 
 tokenlist_apply <- function(x, fun, arguments = NULL) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be {.cls tokenlist} object.")
   }
 
   tokens <- get_tokens(x)
@@ -226,7 +226,7 @@ tokenlist_apply <- function(x, fun, arguments = NULL) {
 # Takes a [tokenlist] and calculate the token count matrix
 tokenlist_to_dtm <- function(x, dict) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a tokenlist.")
   }
 
   tokens <- get_tokens(x)
@@ -246,23 +246,23 @@ tokenlist_to_dtm <- function(x, dict) {
 
 tokenlist_lemma <- function(x) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a tokenlist.")
   }
 
   if (is.null(maybe_get_lemma(x))) {
-    rlang::abort("`lemma` attribute not avaliable.")
+    cli::cli_abort("The {.code lemma} attribute is not available.")
   }
 
   tokenlist(maybe_get_lemma(x), pos = maybe_get_pos(x))
 }
 
 tokenlist_pos_filter <- function(x, pos_tags) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a tokenlist.")
   }
 
   if (is.null(maybe_get_pos(x))) {
-    rlang::abort("pos attribute not avaliable.")
+    cli::cli_abort("{.arg pos} attribute not available.")
   }
 
   tokens <- get_tokens(x)
@@ -292,7 +292,7 @@ tokenlist_pos_filter <- function(x, pos_tags) {
 
 tokenlist_ngram <- function(x, n, n_min, delim) {
   if (!is_tokenlist(x)) {
-    rlang::abort("Input must be a tokenlist.")
+    cli::cli_abort("Input must be a tokenlist.")
   }
 
   tokenlist(cpp11_ngram(get_tokens(x), n, n_min, delim))
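The same guard appears eight times above with three different spellings of the message ("a tokenlist.", "a {.cls tokenlist}.", "{.cls tokenlist} object."). A hypothetical consolidation — not part of this commit — showing `{.cls}` for class names and `call` forwarding so the error is reported from the user-facing function; `is_tokenlist()` and the `textrecipes_tokenlist` class come from the package, the helper itself is invented:

    library(cli)

    # Sketch of a shared guard; reports the class the caller actually passed.
    check_tokenlist <- function(x, call = rlang::caller_env()) {
      if (!inherits(x, "textrecipes_tokenlist")) {
        cli_abort(
          "Input must be a {.cls tokenlist}, not {.cls {class(x)}}.",
          call = call
        )
      }
      invisible(x)
    }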
11 changes: 4 additions & 7 deletions R/word_embeddings.R
@@ -110,16 +110,13 @@ step_word_embeddings <- function(recipe,
       ncol(embeddings) == 1 ||
       !all(map_lgl(embeddings[, 2:ncol(embeddings)], is.numeric))
   ) {
-    embeddings_message <- glue(
-      "embeddings should be a tibble with 1 character or factor column and ",
-      "additional numeric columns."
-    )
-    rlang::abort(
-      embeddings_message,
+    cli::cli_abort(
+      "embeddings should be a tibble with {.code 1} character or factor column
+      and additional numeric columns.",
       class = "bad_embeddings"
     )
   }
 
   aggregation <- match.arg(aggregation)
 
   add_step(
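Note that `cli_abort()` passes `class` through to the condition just as `rlang::abort()` did, so downstream tests keyed on the class are unaffected by the message rewrite. A hedged sketch (the `recipe` and `bad_embeddings` objects are hypothetical):

    # Tests can match the condition class instead of the reworded message:
    testthat::expect_error(
      step_word_embeddings(recipe, text, embeddings = bad_embeddings),
      class = "bad_embeddings"
    )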
2 changes: 1 addition & 1 deletion src/ngram.cpp
@@ -101,4 +101,4 @@ cpp11_ngram(cpp11::list_of<cpp11::strings> x,
   }
 
   return(out);
-}
\ No newline at end of file
+}
3 changes: 2 additions & 1 deletion tests/testthat/_snaps/lemma.md
@@ -5,7 +5,8 @@
     Condition
       Error in `step_lemma()`:
       Caused by error in `bake()`:
-      ! `text` doesn't have a lemma attribute. Make sure the tokenization step includes lemmatization.
+      ! `text` doesn't have a lemma attribute.
+      i Make sure the tokenization step includes lemmatization.
 
 # bake method errors when needed non-standard role columns are missing
 
3 changes: 2 additions & 1 deletion tests/testthat/_snaps/pos_filter.md
@@ -5,7 +5,8 @@
     Condition
       Error in `step_pos_filter()`:
       Caused by error in `bake()`:
-      ! `text` doesn't have a pos attribute. Make sure the tokenization step includes part of speech tagging.
+      ! `text` doesn't have a pos attribute.
+      i Make sure the tokenization step includes part of speech tagging.
 
 # bake method errors when needed non-standard role columns are missing
 
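The snapshot churn follows mechanically from the bullet change: `expect_snapshot()` records the rendered condition, so the former one-line message becomes a `!` headline plus an `i` info bullet. A sketch, with `rec` a hypothetical recipe whose tokenization step lacks lemmatization:

    # Re-running the test regenerates the two-line snapshot shown above:
    testthat::expect_snapshot(
      error = TRUE,
      bake(prep(rec), new_data = NULL)
    )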