Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/R/arrow-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
supported_dplyr_methods <- list(
select = NULL,
filter = NULL,
filter_out = NULL,
collect = NULL,
summarise = c(
"window functions not currently supported;",
Expand Down
113 changes: 89 additions & 24 deletions r/R/dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,108 @@

# The following S3 methods are registered on load if dplyr is present

filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) {
try_arrow_dplyr({
# TODO something with the .preserve argument
out <- as_adq(.data)
apply_filter_impl <- function(
.data,
...,
.by = NULL,
.preserve = FALSE,
exclude = FALSE,
verb = c("filter", "filter_out")
) {
verb <- match.arg(verb)

by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data")
# TODO something with the .preserve argument
out <- as_adq(.data)

if (by$from_by) {
out$group_by_vars <- by$names
}
by <- compute_by({{ .by }}, out, by_arg = ".by", data_arg = ".data")

expanded_filters <- expand_across(out, quos(...))
if (length(expanded_filters) == 0) {
# Nothing to do
return(as_adq(.data))
}
if (by$from_by) {
out$group_by_vars <- by$names
}

expanded_filters <- expand_across(out, quos(...))
if (length(expanded_filters) == 0) {
# Nothing to do
return(as_adq(.data))
}

# tidy-eval the filter expressions inside an Arrow data_mask
mask <- arrow_mask(out)

if (isTRUE(exclude)) {
# filter_out(): combine all predicates with &, then exclude
combined <- NULL

# tidy-eval the filter expressions inside an Arrow data_mask
mask <- arrow_mask(out)
for (expr in expanded_filters) {
filt <- arrow_eval(expr, mask)

if (length(mask$.aggregations)) {
# dplyr lets you filter on e.g. x < mean(x), but we haven't implemented it.
# But we could, the same way it works in mutate() via join, if someone asks.
# Until then, just error.
arrow_not_supported(
.actual_msg = "Expression not supported in filter() in Arrow",
.actual_msg = sprintf("Expression not supported in %s() in Arrow", verb),
call = expr
)
}
out <- set_filters(out, filt)

if (is_list_of(filt, "Expression")) {
filt <- Reduce("&", filt)
}

combined <- if (is.null(combined)) filt else (combined & filt)
}

if (by$from_by) {
out$group_by_vars <- character()
out <- set_filters(out, combined, exclude = TRUE)
} else {
# filter(): apply each predicate sequentially
for (expr in expanded_filters) {
filt <- arrow_eval(expr, mask)

if (length(mask$.aggregations)) {
arrow_not_supported(
.actual_msg = sprintf("Expression not supported in %s() in Arrow", verb),
call = expr
)
}

out <- set_filters(out, filt, exclude = FALSE)
}
}

if (by$from_by) {
out$group_by_vars <- character()
}

out
}

out
filter.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) {
try_arrow_dplyr({
apply_filter_impl(
.data,
...,
.by = {{ .by }},
.preserve = .preserve,
exclude = FALSE,
verb = "filter"
)
})
}
filter.Dataset <- filter.ArrowTabular <- filter.RecordBatchReader <- filter.arrow_dplyr_query

set_filters <- function(.data, expressions) {
filter_out.arrow_dplyr_query <- function(.data, ..., .by = NULL, .preserve = FALSE) {
try_arrow_dplyr({
apply_filter_impl(
.data,
...,
.by = {{ .by }},
.preserve = .preserve,
exclude = TRUE,
verb = "filter_out"
)
})
}
filter_out.Dataset <- filter_out.ArrowTabular <- filter_out.RecordBatchReader <- filter_out.arrow_dplyr_query

set_filters <- function(.data, expressions, exclude = FALSE) {
if (length(expressions)) {
if (is_list_of(expressions, "Expression")) {
# expressions is a list of Expressions. AND them together and set them on .data
Expand All @@ -70,6 +129,12 @@ set_filters <- function(.data, expressions) {
stop("filter expressions must be either an expression or a list of expressions", call. = FALSE)
}

if (isTRUE(exclude)) {
# dplyr::filter_out() semantics: drop rows where predicate is TRUE;
# keep rows where predicate is FALSE or NA.
new_filter <- (!new_filter) | is.na(new_filter)
}

if (isTRUE(.data$filtered_rows)) {
# TRUE is default (i.e. no filter yet), so we don't need to & with it
.data$filtered_rows <- new_filter
Expand Down
41 changes: 41 additions & 0 deletions r/tests/testthat/test-dplyr-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,44 @@ test_that("filter() with aggregation expressions errors", {
"not supported in filter"
)
})

test_that("filter_out() basic", {
compare_dplyr_binding(
.input |>
filter_out(chr == "b") |>
select(chr, int, lgl) |>
collect(),
tbl
)
})

test_that("filter_out() keeps NA values in predicate result", {
compare_dplyr_binding(
.input |>
filter_out(lgl) |>
select(chr, int, lgl) |>
collect(),
tbl
)
})

test_that("filter_out() with multiple conditions", {
compare_dplyr_binding(
.input |>
filter_out(dbl > 2, chr %in% c("d", "f")) |>
collect(),
tbl
)
})

test_that("More complex select/filter_out", {
compare_dplyr_binding(
.input |>
filter_out(dbl > 2, chr == "d" | chr == "f") |>
select(chr, int, lgl) |>
filter(int < 5) |>
select(int, chr) |>
collect(),
tbl
)
})