Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ jobs:
- {os: macos-latest, r: 'release'}

- {os: windows-latest, r: 'release'}
# use 4.0 or 4.1 to check with rtools40's older compiler
- {os: windows-latest, r: 'oldrel-4'}

- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}
- {os: ubuntu-latest, r: 'oldrel-2'}
- {os: ubuntu-latest, r: 'oldrel-3'}
- {os: ubuntu-latest, r: 'oldrel-4'}
Comment on lines -29 to -37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if these are related to the GHAs failing on older versions, I would use this trick instead: https://github.com/tidymodels/parsnip/pull/1317/files#diff-9c940e8ad2b7bc4c26ec3da57b94bc00e73e2166cfed689da51a4c59bcc0a310L59-L61


env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ Config/usethis/last-upkeep: 2025-04-24
Encoding: UTF-8
Language: en-US
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
RoxygenNote: 7.3.3
3 changes: 2 additions & 1 deletion R/rule_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ xrf_fit <-
counts = TRUE,
event_level = c("first", "second"),
lambda = 0.1,
objective = NULL,
...
) {
converted <-
Expand All @@ -40,7 +41,7 @@ xrf_fit <-
subsample = subsample,
validation = validation,
early_stop = early_stop,
objective = NULL,
objective = objective,
counts = counts,
event_level = event_level
)
Expand Down
1 change: 1 addition & 0 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ PBC
PSOCK
Popescu
Quinlan
ROR
RStudio
RuleFit
doi
Expand Down
1 change: 1 addition & 0 deletions man/rules-internal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/rules-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 28 additions & 25 deletions man/tidy.cubist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions tests/testthat/_snaps/rule-fit-regression.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# early stopping works in xrf_fit

Code
suppressMessages(rf_fit_3 <- fit(rf_mod_3, mpg ~ ., data = mtcars))
Condition
Warning:
`early_stop` was reduced to 4.
suppressMessages(rf_fit_3 <- fit(rf_mod_3, outcome ~ ., data = reg_data))

# xrf_fit guards xgb_control

Expand All @@ -16,7 +13,7 @@
Output
parsnip model object

An eXtreme RuleFit model of 7 rules.
An eXtreme RuleFit model of 17 rules.

Original Formula:

Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
penalties <- 10^(-5:-1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears not to be used anywhere


did_stop_early <- function(x) {
if (inherits(x, "model_fit")) {
x <- x$fit$xgb
} else if (inherits(x, "model_fit")) {
x <- x$xgb
}
attr <- attributes(x)
if (any(names(attr) == "early_stop")) {
res <- attr$early_stop$stopped_by_max_rounds
} else {
res <- FALSE
}
res
}

make_chi_data <- function() {
Chicago <- modeldata::Chicago

Expand Down
52 changes: 24 additions & 28 deletions tests/testthat/test-rule-fit-binomial.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ test_that("formula method", {
type = "response"
)[, 1]

expect_no_error(
expect_no_error({
set.seed(4526)
rf_mod <-
rule_fit(trees = 3, min_n = 3, penalty = 1) |>
set_engine("xrf") |>
set_mode("classification")
)
})

set.seed(4526)
expect_no_error(
Expand All @@ -41,11 +42,6 @@ test_that("formula method", {
rf_pred <- predict(rf_fit, ad_data$ad_pred)
rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob")

expect_equal(
unname(rf_fit_exp$xgb$evaluation_log),
unname(rf_fit$fit$xgb$evaluation_log)
)

expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
expect_equal(rf_pred$.pred_class, unname(rf_pred_exp))
Expand All @@ -68,9 +64,9 @@ test_that("formula method", {

rf_m_pred <-
rf_m_pred |>
mutate(.row_number = 1:nrow(rf_m_pred)) |>
dplyr::mutate(.row_number = 1:nrow(rf_m_pred)) |>
tidyr::unnest(cols = c(.pred)) |>
arrange(penalty, .row_number)
dplyr::arrange(penalty, .row_number)

for (i in ad_data$vals) {
exp_pred <- predict(rf_fit_exp, ad_data$ad_pred, lambda = i)[, 1]
Expand All @@ -79,15 +75,17 @@ test_that("formula method", {
levels = ad_data$lvls
)
exp_pred <- unname(exp_pred)
obs_pred <- rf_m_pred |> dplyr::filter(penalty == i) |> pull(.pred_class)
obs_pred <- rf_m_pred |>
dplyr::filter(penalty == i) |>
dplyr::pull(.pred_class)
expect_equal(unname(exp_pred), obs_pred)
}

rf_m_prob <-
rf_m_prob |>
mutate(.row_number = 1:nrow(rf_m_prob)) |>
dplyr::mutate(.row_number = 1:nrow(rf_m_prob)) |>
tidyr::unnest(cols = c(.pred)) |>
arrange(penalty, .row_number)
dplyr::arrange(penalty, .row_number)

for (i in ad_data$vals) {
exp_pred <- predict(
Expand All @@ -98,8 +96,8 @@ test_that("formula method", {
)[, 1]
obs_pred <- rf_m_prob |>
dplyr::filter(penalty == i) |>
pull(.pred_Control)
expect_equal(unname(exp_pred), obs_pred)
dplyr::pull(.pred_Control)
expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
}
})

Expand Down Expand Up @@ -134,12 +132,13 @@ test_that("non-formula method", {
type = "response"
)[, 1]

expect_no_error(
expect_no_error({
set.seed(4526)
rf_mod <-
rule_fit(trees = 3, min_n = 3, penalty = 1) |>
set_engine("xrf") |>
set_mode("classification")
)
})

expect_no_error(
rf_fit <- fit_xy(
Expand All @@ -151,11 +150,6 @@ test_that("non-formula method", {
rf_pred <- predict(rf_fit, ad_data$ad_pred)
rf_prob <- predict(rf_fit, ad_data$ad_pred, type = "prob")

expect_equal(
unname(rf_fit_exp$xgb$evaluation_log),
unname(rf_fit$fit$xgb$evaluation_log)
)

expect_equal(names(rf_pred), ".pred_class")
expect_true(tibble::is_tibble(rf_pred))
expect_equal(rf_pred$.pred_class, unname(rf_pred_exp))
Expand All @@ -178,9 +172,9 @@ test_that("non-formula method", {

rf_m_pred <-
rf_m_pred |>
mutate(.row_number = 1:nrow(rf_m_pred)) |>
dplyr::mutate(.row_number = 1:nrow(rf_m_pred)) |>
tidyr::unnest(cols = c(.pred)) |>
arrange(penalty, .row_number)
dplyr::arrange(penalty, .row_number)

for (i in ad_data$vals) {
exp_pred <- predict(rf_fit_exp, ad_data$ad_pred, lambda = i)[, 1]
Expand All @@ -189,15 +183,17 @@ test_that("non-formula method", {
levels = ad_data$lvls
)
exp_pred <- unname(exp_pred)
obs_pred <- rf_m_pred |> dplyr::filter(penalty == i) |> pull(.pred_class)
obs_pred <- rf_m_pred |>
dplyr::filter(penalty == i) |>
dplyr::pull(.pred_class)
expect_equal(unname(exp_pred), obs_pred)
}

rf_m_prob <-
rf_m_prob |>
mutate(.row_number = 1:nrow(rf_m_prob)) |>
dplyr::mutate(.row_number = 1:nrow(rf_m_prob)) |>
tidyr::unnest(cols = c(.pred)) |>
arrange(penalty, .row_number)
dplyr::arrange(penalty, .row_number)

for (i in ad_data$vals) {
exp_pred <- predict(
Expand All @@ -208,8 +204,8 @@ test_that("non-formula method", {
)[, 1]
obs_pred <- rf_m_prob |>
dplyr::filter(penalty == i) |>
pull(.pred_Control)
expect_equal(unname(exp_pred), obs_pred)
dplyr::pull(.pred_Control)
expect_equal(unname(exp_pred), obs_pred, tolerance = 0.1)
}
})

Expand Down
Loading