diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index a1167433c93..5a596dffe3c 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -38,6 +38,7 @@ supported_dplyr_methods <- list( select = NULL, filter = NULL, + filter_out = NULL, collect = NULL, summarise = c( "window functions not currently supported;", diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R index 18f5c929aff..0ccb5fb8944 100644 --- a/r/R/dplyr-filter.R +++ b/r/R/dplyr-filter.R @@ -17,49 +17,108 @@ # The following S3 methods are registered on load if dplyr is present -filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { - try_arrow_dplyr({ - # TODO something with the .preserve argument - out <- as_adq(.data) +apply_filter_impl <- function( + .data, + ..., + .by = NULL, + .preserve = FALSE, + exclude = FALSE, + verb = c("filter", "filter_out") +) { + verb <- match.arg(verb) - by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") + # TODO something with the .preserve argument + out <- as_adq(.data) - if (by$from_by) { - out$group_by_vars <- by$names - } + by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data") - expanded_filters <- expand_across(out, quos(...)) - if (length(expanded_filters) == 0) { - # Nothing to do - return(as_adq(.data)) - } + if (by$from_by) { + out$group_by_vars <- by$names + } + + expanded_filters <- expand_across(out, quos(...)) + if (length(expanded_filters) == 0) { + # Nothing to do + return(as_adq(.data)) + } + + # tidy-eval the filter expressions inside an Arrow data_mask + mask <- arrow_mask(out) + + if (isTRUE(exclude)) { + # filter_out(): combine all predicates with &, then exclude + combined <- NULL - # tidy-eval the filter expressions inside an Arrow data_mask - mask <- arrow_mask(out) for (expr in expanded_filters) { filt <- arrow_eval(expr, mask) + if (length(mask$.aggregations)) { - # dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it. - # But we could, the same way it works in mutate() via join, if someone asks. - # Until then, just error. arrow_not_supported( - .actual_msg = "Expression not supported in filter() in Arrow", + .actual_msg = sprintf("Expression not supported in %s() in Arrow", verb), call = expr ) } - out <- set_filters(out, filt) + + if (is_list_of(filt, "Expression")) { + filt <- Reduce("&", filt) + } + + combined <- if (is.null(combined)) filt else (combined & filt) } - if (by$from_by) { - out$group_by_vars <- character() + out <- set_filters(out, combined, exclude = TRUE) + } else { + # filter(): apply each predicate sequentially + for (expr in expanded_filters) { + filt <- arrow_eval(expr, mask) + + if (length(mask$.aggregations)) { + arrow_not_supported( + .actual_msg = sprintf("Expression not supported in %s() in Arrow", verb), + call = expr + ) + } + + out <- set_filters(out, filt, exclude = FALSE) } + } + + if (by$from_by) { + out$group_by_vars <- character() + } + + out +} - out +filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { + try_arrow_dplyr({ + apply_filter_impl( + .data, + ..., + .by = {{ .by }}, + .preserve = .preserve, + exclude = FALSE, + verb = "filter" + ) }) } filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query -set_filters <- function(.data, expressions) { +filter_out.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) { + try_arrow_dplyr({ + apply_filter_impl( + .data, + ..., + .by = {{ .by }}, + .preserve = .preserve, + exclude = TRUE, + verb = "filter_out" + ) + }) +} +filter_out.Dataset <- filter_out.ArrowTabular <- filter_out.RecordBatchReader <- filter_out.arrow_dplyr_query + +set_filters <- function(.data, expressions, exclude = FALSE) { if (length(expressions)) { if (is_list_of(expressions, "Expression")) { # expressions is a list of Expressions. AND them together and set them on .data @@ -70,6 +129,12 @@ set_filters <- function(.data, expressions) { stop("filter expressions must be either an expression or a list of expressions", call. = FALSE) } + if (isTRUE(exclude)) { + # dplyr::filter_out() semantics: drop rows where predicate is TRUE; + # keep rows where predicate is FALSE or NA. + new_filter <- (!new_filter) | is.na(new_filter) + } + if (isTRUE(.data$filtered_rows)) { # TRUE is default (i.e. no filter yet), so we don't need to & with it .data$filtered_rows <- new_filter diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index d56e25fca32..9bf81b9a4f0 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -498,3 +498,44 @@ test_that("filter() with aggregation expressions errors", { "not supported in filter" ) }) + +test_that("filter_out() basic", { + compare_dplyr_binding( + .input |> + filter_out(chr == "b") |> + select(chr, int, lgl) |> + collect(), + tbl + ) +}) + +test_that("filter_out() keeps NA values in predicate result", { + compare_dplyr_binding( + .input |> + filter_out(lgl) |> + select(chr, int, lgl) |> + collect(), + tbl + ) +}) + +test_that("filter_out() with multiple conditions", { + compare_dplyr_binding( + .input |> + filter_out(dbl > 2, chr %in% c("d", "f")) |> + collect(), + tbl + ) +}) + +test_that("More complex select/filter_out", { + compare_dplyr_binding( + .input |> + filter_out(dbl > 2, chr == "d" | chr == "f") |> + select(chr, int, lgl) |> + filter(int < 5) |> + select(int, chr) |> + collect(), + tbl + ) +})