diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 15d3261ca513..29bfd18adb92 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -49,6 +49,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e6b24dec87fd..31aae1c99036 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,6 +65,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_aggregate_with_constant; pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 877a84fe4dc1..8cb7d5291e7a 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_aggregate_with_constant::RewriteAggregateWithConstant; use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; @@ -259,6 +260,7 @@ impl Optimizer { // The previous optimizations added expressions and projections, // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), + Arc::new(RewriteAggregateWithConstant::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs new file mode 100644 index 000000000000..d550d8577937 --- /dev/null +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -0,0 +1,607 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use std::collections::HashMap; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::expr::AggregateFunctionParams; +use datafusion_expr::{ + Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, + col, lit, +}; +use datafusion_functions_aggregate::expr_fn::{count, sum}; +use indexmap::IndexMap; + +/// Optimizer rule that rewrites `SUM(column ± constant)` expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions +/// exist for the same base column. +/// +/// This reduces computation by calculating SUM once and deriving other values. +/// +/// # Example +/// ```sql +/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; +/// ``` +/// is rewritten into a Projection on top of an Aggregate: +/// ```sql +/// -- New Projection Node +/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// -- New Aggregate Node +/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); +/// ``` +#[derive(Default, Debug)] +pub struct RewriteAggregateWithConstant {} + +impl RewriteAggregateWithConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for RewriteAggregateWithConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + // This rule specifically targets Aggregate nodes + LogicalPlan::Aggregate(aggregate) => { + // Step 1: Identify which expressions can be rewritten and group them by base column + let rewrite_info = analyze_aggregate(&aggregate)?; + + if rewrite_info.is_empty() { + // No groups found with 2+ matching SUM expressions, return original plan + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + // Step 2: Perform the actual transformation into Aggregate + Projection + transform_aggregate(aggregate, &rewrite_info) + } + // Non-aggregate plans are passed through unchanged + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "rewrite_aggregate_with_constant" + } + + fn apply_order(&self) -> Option { + // Bottom-up ensures we optimize subqueries before the outer query + Some(ApplyOrder::BottomUp) + } +} + +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite. +#[derive(Debug, Clone)] +struct SumWithConstant { + /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`) + base_expr: Expr, + /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`) + constant: ScalarValue, + /// The operator (`+` or `-`) + operator: Operator, + /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order + original_index: usize, + // Note: ORDER BY inside SUM is irrelevant because SUM is commutative — + // the order of addition doesn't change the result. If this rule is ever + // extended to non-commutative aggregates, ORDER BY handling would need + // to be added back. +} + +/// Maps a base expression's schema name to all its SUM(base ± const) variants. +/// We use IndexMap to preserve insertion order, ensuring deterministic output +/// in the rewritten plan (important for stable EXPLAIN output in tests). +type RewriteGroups = IndexMap>; + +/// Scans the aggregate expressions to find candidates for the rewrite. +fn analyze_aggregate(aggregate: &Aggregate) -> Result { + let mut groups: RewriteGroups = IndexMap::new(); + + for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + // Try to match the pattern SUM(col ± lit) + if let Some(sum_info) = extract_sum_with_constant(expr, idx)? { + let key = sum_info.base_expr.schema_name().to_string(); + groups.entry(key).or_default().push(sum_info); + } + } + + // Optimization: Only rewrite if we have at least 2 expressions for the same column. + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // actually increases the work (1 agg -> 2 aggs). + groups.retain(|_, v| v.len() >= 2); + + Ok(groups) +} + +/// Extract SUM(base_expr ± constant) pattern from an expression. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so the rule works regardless of whether aggregate expressions carry aliases +/// (e.g., when plans are built via the LogicalPlanBuilder API). +fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result> { + // Unwrap Expr::Alias if present — the SQL planner puts aliases in a + // Projection above the Aggregate, but the builder API allows aliases + // directly inside aggr_expr. + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + match inner { + Expr::AggregateFunction(agg_fn) => { + // Rule only applies to SUM + if agg_fn.func.name().to_lowercase() != "sum" { + return Ok(None); + } + + let AggregateFunctionParams { + args, + distinct, + filter, + order_by: _, + null_treatment: _, + } = &agg_fn.params; + + // We cannot easily rewrite SUM(DISTINCT a + 1) or SUM(a + 1) FILTER (...) + // as the math SUM(a) + k*COUNT(a) wouldn't hold correctly with these modifiers. + if *distinct || filter.is_some() { + return Ok(None); + } + + // SUM must have exactly one argument (e.g. SUM(a + 1)). + // This rejects invalid calls like SUM() or non-standard multi-argument variations. + if args.len() != 1 { + return Ok(None); + } + + let arg = &args[0]; + + // Try to match: base_expr +/- constant + // Note: If the base_expr is complex (e.g., SUM(a + b + 1)), base_expr will be "a + b". + // The rule will still work if multiple SUMs have the exact same complex base_expr, + // as they will be grouped by the string representation of that expression. + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg + && matches!(op, Operator::Plus | Operator::Minus) + { + // Check if right side is a literal constant + // Check if right side is a literal constant (e.g., SUM(a + 1)) + if let Expr::Literal(constant, _) = right.as_ref() + && is_numeric_constant(constant) + { + return Ok(Some(SumWithConstant { + base_expr: (**left).clone(), + constant: constant.clone(), + operator: *op, + original_index: idx, + })); + } + + // Also check left side for commutative addition (e.g., SUM(1 + a)) + // Does NOT apply to subtraction: SUM(5 - a) ≠ SUM(a - 5) + if let Expr::Literal(constant, _) = left.as_ref() + && is_numeric_constant(constant) + && *op == Operator::Plus + { + return Ok(Some(SumWithConstant { + base_expr: (**right).clone(), + constant: constant.clone(), + operator: Operator::Plus, + original_index: idx, + })); + } + } + + Ok(None) + } + _ => Ok(None), + } +} + +/// Check if a scalar value is a numeric constant +/// (guards against non-arithmetic types like strings, booleans, dates, etc.) +fn is_numeric_constant(value: &ScalarValue) -> bool { + matches!( + value, + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + ) +} + +/// Check if an expression is a plain `SUM(base_expr)` whose base expression matches +/// one of our rewrite groups. Returns `Some(base_key)` if matched, `None` otherwise. +/// +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so that aliased plain SUMs (e.g., `SUM(a) AS total`) are correctly detected and reused +/// instead of being duplicated in the new Aggregate node. +/// +/// Returning the `base_key` directly eliminates the need for the caller to re-extract it +/// from the expression (avoiding a potential panic on impossible-to-reach fallback paths). +fn check_plain_sum_in_group( + expr: &Expr, + known_base_keys: &IndexMap, +) -> Option { + // Unwrap alias if present + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + if let Expr::AggregateFunction(agg_fn) = inner + && agg_fn.func.name().to_lowercase() == "sum" + && agg_fn.params.args.len() == 1 + && !agg_fn.params.distinct + && agg_fn.params.filter.is_none() + { + let arg = &agg_fn.params.args[0]; + let base_key = arg.schema_name().to_string(); + if known_base_keys.contains_key(&base_key) { + return Some(base_key); + } + } + None +} + +/// Alias `expr` to match the output column name of `orig_expr`. +/// +/// If `orig_expr` is `Expr::Alias(name)`, the result is aliased to that name. +/// Otherwise, the result is aliased to `orig_expr.schema_name()` to preserve +/// the auto-generated column name (e.g., `"sum(t.a + Int64(1))"`). +fn alias_like(expr: Expr, orig_expr: &Expr) -> Expr { + match orig_expr { + Expr::Alias(alias) => expr.alias(alias.name.clone()), + _ => expr.alias(orig_expr.schema_name().to_string()), + } +} + +/// Transform the aggregate plan by rewriting SUM(col ± constant) expressions. +/// +/// Replaces the original `Aggregate` node with: +/// 1. A new `Aggregate` containing one `SUM(base)` + one `COUNT(base)` per group, +/// plus any unrelated expressions passed through. +/// 2. A `Projection` on top that derives each original output column using +/// the formula: `SUM(base) ± constant * COUNT(base)`. +/// +/// Delegates to [`build_rewrite_expressions`] for the expression logic and +/// [`assemble_rewritten_plan`] for the plan construction. +fn transform_aggregate( + aggregate: Aggregate, + rewrite_groups: &RewriteGroups, +) -> Result> { + let (new_aggr_exprs, projection_exprs) = + build_rewrite_expressions(&aggregate, rewrite_groups); + assemble_rewritten_plan(aggregate, new_aggr_exprs, projection_exprs) +} + +/// Build the new aggregate expressions and corresponding projection expressions. +/// +/// For each rewrite group (e.g., all SUMs involving column `a`), adds one +/// `SUM(a)` and one `COUNT(a)` to the new aggregate. Then maps every original +/// aggregate expression to a projection expression using one of three cases: +/// +/// - **Case 1** (rewritable): `SUM(a + k)` → `SUM(a) + k * COUNT(a)` +/// - **Case 2** (plain SUM in group): `SUM(a)` → reference the already-added `SUM(a)` +/// - **Case 3** (unrelated): pass through unchanged (e.g., `AVG(b)`) +/// +/// Returns `(new_aggr_exprs, projection_exprs)`. +fn build_rewrite_expressions( + aggregate: &Aggregate, + rewrite_groups: &RewriteGroups, +) -> (Vec, Vec) { + let mut new_aggr_exprs: Vec = Vec::new(); + let mut projection_exprs: Vec = Vec::new(); + + // O(1) lookup of rewritable expressions by their original index. + let rewrite_lookup: HashMap = rewrite_groups + .values() + .flatten() + .map(|s| (s.original_index, s)) + .collect(); + + // Maps base expression key → (sum_index, count_index) in `new_aggr_exprs`. + // Column names are derived on-demand from `new_aggr_exprs[index].schema_name()` + // rather than maintaining separate name maps. + let mut base_expr_indices: IndexMap = IndexMap::new(); + + // For every rewrite group, add one SUM(base) and one COUNT(base) + // to the new Aggregate. + for (base_key, sums) in rewrite_groups { + let base_expr = &sums[0].base_expr; + + let sum_index = new_aggr_exprs.len(); + new_aggr_exprs.push(sum(base_expr.clone())); + + // COUNT(col) rather than COUNT(*): if col is NULL, SUM(col+1) should + // be NULL, and COUNT(col) correctly returns 0 for NULLs. + let count_index = new_aggr_exprs.len(); + new_aggr_exprs.push(count(base_expr.clone())); + + base_expr_indices.insert(base_key.clone(), (sum_index, count_index)); + } + + // Map each original aggregate expression to a projection expression. + for (idx, orig_expr) in aggregate.aggr_expr.iter().enumerate() { + let projection_expr = if let Some(sum_info) = rewrite_lookup.get(&idx) { + // ── Case 1: Rewritable SUM(col ± constant) ── + // Derive: SUM(col) ± (constant * COUNT(col)) + build_derived_projection( + sum_info, + &base_expr_indices, + &new_aggr_exprs, + orig_expr, + ) + } else if let Some(base_key) = + check_plain_sum_in_group(orig_expr, &base_expr_indices) + { + // ── Case 2: Plain SUM(a) matching a rewrite group ── + // Reuse the SUM(a) we already added instead of creating a duplicate. + let (sum_idx, _) = base_expr_indices[&base_key]; + let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); + alias_like(sum_ref, orig_expr) + } else { + // ── Case 3: Unrelated expression (e.g., AVG(b), MAX(c)) ── + // Pass it through to the new Aggregate node unchanged. + new_aggr_exprs.push(orig_expr.clone()); + match orig_expr { + Expr::Alias(alias) => col(alias.name.clone()), + _ => col(orig_expr.schema_name().to_string()), + } + }; + + projection_exprs.push(projection_expr); + } + + (new_aggr_exprs, projection_exprs) +} + +/// Build the projection expression for a single rewritable SUM(col ± constant): +/// +/// `SUM(col) ± (constant * COUNT(col))`, aliased to match the original expression. +fn build_derived_projection( + sum_info: &SumWithConstant, + base_expr_indices: &IndexMap, + new_aggr_exprs: &[Expr], + orig_expr: &Expr, +) -> Expr { + let base_key = sum_info.base_expr.schema_name().to_string(); + let (sum_idx, count_idx) = base_expr_indices[&base_key]; + + let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); + let count_ref = col(new_aggr_exprs[count_idx].schema_name().to_string()); + + let multiplied = binary_expr( + lit(sum_info.constant.clone()), + Operator::Multiply, + count_ref, + ); + let result = binary_expr(sum_ref, sum_info.operator, multiplied); + alias_like(result, orig_expr) +} + +/// Assemble the final rewritten plan: `Projection → Aggregate → (original input)`. +/// +/// GROUP BY columns are passed through both nodes so the output schema is preserved. +fn assemble_rewritten_plan( + aggregate: Aggregate, + new_aggr_exprs: Vec, + projection_exprs: Vec, +) -> Result> { + // GROUP BY columns: reference them in the Projection. + let group_refs: Vec = aggregate + .group_expr + .iter() + .map(|e| match e { + Expr::Alias(alias) => col(alias.name.clone()), + Expr::Column(c) => Expr::Column(c.clone()), + _ => col(e.schema_name().to_string()), + }) + .collect(); + + // Final projection: [GROUP BY columns] ++ [aggregate/rewritten columns] + let mut final_projection = group_refs; + final_projection.extend(projection_exprs); + + // Build: Projection → Aggregate → (original input) + let new_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input, + aggregate.group_expr, + new_aggr_exprs, + )?); + + let projection = LogicalPlanBuilder::from(new_aggregate) + .project(final_projection)? + .build()?; + + Ok(Transformed::yes(projection)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::OptimizerContext; + use crate::test::*; + use datafusion_common::Result; + use datafusion_expr::{LogicalPlanBuilder, col, lit}; + use datafusion_functions_aggregate::expr_fn::sum; + + #[test] + fn test_sum_with_constant_basic() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + sum(col("a")), + sum(col("a") + lit(1)), + sum(col("a") + lit(2)), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed + assert!(result.transformed); + Ok(()) + } + + #[test] + fn test_no_transform_single_sum() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(col("a") + lit(1))])? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should NOT be transformed (only one SUM) + assert!(!result.transformed); + Ok(()) + } + + #[test] + fn test_no_transform_no_constant() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(col("a")), sum(col("b"))])? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should NOT be transformed (no constants) + assert!(!result.transformed); + Ok(()) + } + + /// Test that aliased SUM(col ± constant) expressions are correctly detected + /// and rewritten. This exercises the Expr::Alias unwrapping in + /// `extract_sum_with_constant`. + /// + /// Note: The SQL planner puts aliases in a Projection above the Aggregate, + /// so this case only arises when building plans via the LogicalPlanBuilder API. + #[test] + fn test_aliased_sum_with_constant() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + sum(col("a") + lit(1)).alias("sum_a_plus_1"), + sum(col("a") + lit(2)).alias("sum_a_plus_2"), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed: both aliased SUMs share base column "a" + assert!(result.transformed); + + // The rewritten plan should be a Projection on top of an Aggregate + let plan_str = format!("{}", result.data.display_indent()); + assert!( + plan_str.contains("Projection:"), + "Expected Projection node in rewritten plan, got:\n{plan_str}" + ); + // The new aggregate should have SUM(a) and COUNT(a), not two separate SUMs + assert!( + plan_str.contains("sum(test.a)") && plan_str.contains("count(test.a)"), + "Expected SUM(a) and COUNT(a) in rewritten plan, got:\n{plan_str}" + ); + Ok(()) + } + + /// Test that an aliased plain SUM(a) alongside SUM(a+1) and SUM(a+2) is + /// correctly detected as a "plain SUM in group" and reused, rather than + /// being duplicated in the new Aggregate node. + /// + /// This exercises the Expr::Alias unwrapping in `check_plain_sum_in_group`. + #[test] + fn test_aliased_plain_sum_in_group() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + // This plain SUM(a) is aliased — it should be detected + // as matching the rewrite group for "a" and reused. + sum(col("a")).alias("total_a"), + sum(col("a") + lit(1)), + sum(col("a") + lit(2)), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed + assert!(result.transformed); + + let plan_str = format!("{}", result.data.display_indent()); + + // The rewritten plan should contain "total_a" alias in the Projection + assert!( + plan_str.contains("total_a"), + "Expected alias 'total_a' in rewritten plan, got:\n{plan_str}" + ); + + // The Aggregate node should contain exactly one SUM(a) and one COUNT(a), + // NOT a duplicate SUM(a). Count occurrences of "sum(test.a)" in the + // Aggregate line (not the Projection line). + let aggr_line = plan_str + .lines() + .find(|l| l.contains("Aggregate:")) + .expect("should have Aggregate node"); + let sum_count = aggr_line.matches("sum(test.a)").count(); + assert_eq!( + sum_count, 1, + "Expected exactly 1 SUM(test.a) in Aggregate (no duplicate), got {sum_count} in:\n{aggr_line}" + ); + + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt new file mode 100644 index 000000000000..a9390ce537d1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -0,0 +1,692 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## Aggregate Rewrite With Constant Optimizer Tests +## Tests for the optimizer rule that rewrites SUM(col ± constant) to SUM(col) ± constant * COUNT(*) +## Rule only applies when there are 2+ SUM expressions on the SAME base column +########## + +# ==== Test 1: Basic addition with multiple sum expressions ==== + +statement ok +CREATE TABLE test_table ( + a INT, + b INT, + c INT +) AS VALUES + (1, 10, 100), + (2, 20, 200), + (3, 30, 300), + (4, 40, 400), + (5, 50, 500); + +# Test: Multiple SUM expressions with constants should be rewritten +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + count(test_table.a) AS sum_a_plus_1, sum(test_table.a) + Int64(2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(3) * count(test_table.a) AS sum_a_plus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + count(test_table.a)@1 as sum_a_plus_1, sum(test_table.a)@0 + 2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + 3 * count(test_table.a)@1 as sum_a_plus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +15 20 25 30 + +# ==== Test 2: Subtraction operations ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) - count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(2) * count(test_table.a) AS sum_a_minus_2, sum(test_table.a) - Int64(3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 - count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - 2 * count(test_table.a)@1 as sum_a_minus_2, sum(test_table.a)@0 - 3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +15 10 5 0 + +# ==== Test 3: With GROUP BY ==== + +statement ok +CREATE TABLE group_test ( + category VARCHAR, + value INT +) AS VALUES + ('A', 1), + ('A', 2), + ('A', 3), + ('B', 4), + ('B', 5), + ('B', 6); + +query TT +EXPLAIN SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +logical_plan +01)Projection: group_test.category, sum(group_test.value) AS sum_val, sum(group_test.value) + count(group_test.value) AS sum_val_plus_1, sum(group_test.value) - Int64(2) * count(group_test.value) AS sum_val_minus_2 +02)--Aggregate: groupBy=[[group_test.category]], aggr=[[sum(__common_expr_1 AS group_test.value), count(__common_expr_1 AS group_test.value)]] +03)----Projection: CAST(group_test.value AS Int64) AS __common_expr_1, group_test.category +04)------TableScan: group_test projection=[category, value] +physical_plan +01)ProjectionExec: expr=[category@0 as category, sum(group_test.value)@1 as sum_val, sum(group_test.value)@1 + count(group_test.value)@2 as sum_val_plus_1, sum(group_test.value)@1 - 2 * count(group_test.value)@2 as sum_val_minus_2] +02)--AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[sum(group_test.value), count(group_test.value)] +03)----RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +04)------AggregateExec: mode=Partial, gby=[category@1 as category], aggr=[sum(group_test.value), count(group_test.value)] +05)--------ProjectionExec: expr=[CAST(value@1 AS Int64) as __common_expr_1, category@0 as category] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query TIII rowsort +SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +A 6 9 0 +B 15 18 9 + +# ==== Test 4: With nullable columns (SHOULD NOT rewrite - only 1 SUM per column) ==== + +statement ok +CREATE TABLE nullable_test ( + id INT, + a INT, + b INT +) AS VALUES + (1, 10, NULL), + (2, 20, 200), + (3, NULL, 300), + (4, 40, 400), + (5, 50, NULL); + +# This should NOT be rewritten because each column has only 1 SUM with a constant +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, sum(nullable_test.a + Int64(5)) AS sum_a_plus_5, sum(nullable_test.b) AS sum_b, sum(nullable_test.b - Int64(10)) AS sum_b_minus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), sum(__common_expr_1 AS nullable_test.a + Int64(5)), sum(__common_expr_2 AS nullable_test.b), sum(__common_expr_2 AS nullable_test.b - Int64(10))]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1, CAST(nullable_test.b AS Int64) AS __common_expr_2 +04)------TableScan: nullable_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a + Int64(5))@1 as sum_a_plus_5, sum(nullable_test.b)@2 as sum_b, sum(nullable_test.b - Int64(10))@3 as sum_b_minus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), sum(nullable_test.a + Int64(5)), sum(nullable_test.b), sum(nullable_test.b - Int64(10))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +120 140 900 870 + +# Test with multiple SUMs on nullable column +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, sum(nullable_test.a) + Int64(5) * count(nullable_test.a) AS sum_a_plus_5, sum(nullable_test.a) + Int64(10) * count(nullable_test.a) AS sum_a_plus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), count(__common_expr_1 AS nullable_test.a)]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1 +04)------TableScan: nullable_test projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a)@0 + 5 * count(nullable_test.a)@1 as sum_a_plus_5, sum(nullable_test.a)@0 + 10 * count(nullable_test.a)@1 as sum_a_plus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), count(nullable_test.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +120 140 160 + +# ==== Test 5: Negative constants ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + Int64(-1) * count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(-2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(-3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + -1 * count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - -2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + -3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +15 10 25 0 + +# ==== Test 6: No matching rewrite patterns ==== + +# Should not rewrite - only one sum with constant +query TT +EXPLAIN SELECT SUM(a + 1) FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1))]] +02)--TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Should not rewrite - different base columns +query TT +EXPLAIN SELECT SUM(a + 1), SUM(b + 2) FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1)), sum(CAST(test_table.b AS Int64) + Int64(2))]] +02)--TableScan: test_table projection=[a, b] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1)), sum(test_table.b + Int64(2))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 7: Mixed sum types (rewrites a and b, not c) ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(b) as sum_b, + SUM(b + 2) as sum_b_plus_2, + SUM(c + 3) as sum_c_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a + Int64(1)) AS sum_a_plus_1, sum(test_table.b) AS sum_b, sum(test_table.b + Int64(2)) AS sum_b_plus_2, sum(test_table.c + Int64(3)) AS sum_c_plus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), sum(__common_expr_1 AS test_table.a + Int64(1)), sum(__common_expr_2 AS test_table.b), sum(__common_expr_2 AS test_table.b + Int64(2)), sum(CAST(test_table.c AS Int64) + Int64(3))]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2, test_table.c +04)------TableScan: test_table projection=[a, b, c] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a + Int64(1))@1 as sum_a_plus_1, sum(test_table.b)@2 as sum_b, sum(test_table.b + Int64(2))@3 as sum_b_plus_2, sum(test_table.c + Int64(3))@4 as sum_c_plus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), sum(test_table.a + Int64(1)), sum(test_table.b), sum(test_table.b + Int64(2)), sum(test_table.c + Int64(3))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2, c@2 as c] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIIII +SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(b) as sum_b, + SUM(b + 2) as sum_b_plus_2, + SUM(c + 3) as sum_c_plus_3 +FROM test_table; +---- +15 20 150 160 1515 + +# ==== Test 8: Aliased expressions (should NOT rewrite - only 1 SUM with constant per column) ==== + +# This has SUM(a), SUM(a+10), SUM(b), SUM(b-5) +# Each column has only 1 SUM WITH a constant, so no rewrite +query TT +EXPLAIN SELECT + SUM(a) AS total_a, + SUM(a + 10) AS total_a_plus_10, + SUM(b) AS total_b, + SUM(b - 5) AS total_b_minus_5 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS total_a, sum(test_table.a + Int64(10)) AS total_a_plus_10, sum(test_table.b) AS total_b, sum(test_table.b - Int64(5)) AS total_b_minus_5 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), sum(__common_expr_1 AS test_table.a + Int64(10)), sum(__common_expr_2 AS test_table.b), sum(__common_expr_2 AS test_table.b - Int64(5))]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 +04)------TableScan: test_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as total_a, sum(test_table.a + Int64(10))@1 as total_a_plus_10, sum(test_table.b)@2 as total_b, sum(test_table.b - Int64(5))@3 as total_b_minus_5] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), sum(test_table.a + Int64(10)), sum(test_table.b), sum(test_table.b - Int64(5))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) AS total_a, + SUM(a + 10) AS total_a_plus_10, + SUM(b) AS total_b, + SUM(b - 5) AS total_b_minus_5 +FROM test_table; +---- +15 65 150 125 + +# Test with 2+ SUMs with constants on same columns (triggers rewrite) +# With IndexMap, the order is deterministic based on insertion order: +# - Group "a" first (from SUM(a + 5)) +# - Group "b" second (from SUM(b - 3)) +query TT +EXPLAIN SELECT + SUM(a + 5) AS sum_a_plus_5, + SUM(a + 10) AS sum_a_plus_10, + SUM(b - 3) AS sum_b_minus_3, + SUM(b - 5) AS sum_b_minus_5 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_a_plus_5, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_a_plus_10, sum(test_table.b) - Int64(3) * count(test_table.b) AS sum_b_minus_3, sum(test_table.b) - Int64(5) * count(test_table.b) AS sum_b_minus_5 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a), sum(__common_expr_2 AS test_table.b), count(__common_expr_2 AS test_table.b)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 +04)------TableScan: test_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_a_plus_5, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_a_plus_10, sum(test_table.b)@2 - 3 * count(test_table.b)@3 as sum_b_minus_3, sum(test_table.b)@2 - 5 * count(test_table.b)@3 as sum_b_minus_5] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a), sum(test_table.b), count(test_table.b)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a + 5) AS sum_a_plus_5, + SUM(a + 10) AS sum_a_plus_10, + SUM(b - 3) AS sum_b_minus_3, + SUM(b - 5) AS sum_b_minus_5 +FROM test_table; +---- +40 65 135 125 + +# ==== Test 9: Complex base expressions (SUM(a + b + 1)) ==== + +statement ok +CREATE TABLE complex_test ( + a INT, + b INT +) AS VALUES + (1, 10), + (2, 20), + (3, 30); + +# Test: Multiple SUMs on the same complex expression (a + b) +query TT +EXPLAIN SELECT + SUM(a + b) as sum_ab, + SUM(a + b + 1) as sum_ab_plus_1, + SUM(a + b + 2) as sum_ab_plus_2 +FROM complex_test; +---- +logical_plan +01)Projection: sum(complex_test.a + complex_test.b) AS sum_ab, sum(complex_test.a + complex_test.b) + count(complex_test.a + complex_test.b) AS sum_ab_plus_1, sum(complex_test.a + complex_test.b) + Int64(2) * count(complex_test.a + complex_test.b) AS sum_ab_plus_2 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + complex_test.b), count(__common_expr_1 AS complex_test.a + complex_test.b)]] +03)----Projection: CAST(complex_test.a + complex_test.b AS Int64) AS __common_expr_1 +04)------TableScan: complex_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(complex_test.a + complex_test.b)@0 as sum_ab, sum(complex_test.a + complex_test.b)@0 + count(complex_test.a + complex_test.b)@1 as sum_ab_plus_1, sum(complex_test.a + complex_test.b)@0 + 2 * count(complex_test.a + complex_test.b)@1 as sum_ab_plus_2] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + complex_test.b), count(complex_test.a + complex_test.b)] +03)----ProjectionExec: expr=[CAST(a@0 + b@1 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT + SUM(a + b) as sum_ab, + SUM(a + b + 1) as sum_ab_plus_1, + SUM(a + b + 2) as sum_ab_plus_2 +FROM complex_test; +---- +66 69 72 + +# Test: Different complex expressions (a + b) vs (b + a) should NOT be grouped together +query TT +EXPLAIN SELECT + SUM(a + b + 1) as sum_ab_plus_1, + SUM(b + a + 1) as sum_ba_plus_1 +FROM complex_test; +---- +logical_plan +01)Projection: sum(complex_test.a + complex_test.b + Int64(1)) AS sum_ab_plus_1, sum(complex_test.b + complex_test.a + Int64(1)) AS sum_ba_plus_1 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + complex_test.b + Int64(1)), sum(__common_expr_1 AS complex_test.b + complex_test.a + Int64(1))]] +03)----Projection: CAST(complex_test.a + complex_test.b AS Int64) + Int64(1) AS __common_expr_1 +04)------TableScan: complex_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(complex_test.a + complex_test.b + Int64(1))@0 as sum_ab_plus_1, sum(complex_test.b + complex_test.a + Int64(1))@1 as sum_ba_plus_1] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + complex_test.b + Int64(1)), sum(complex_test.b + complex_test.a + Int64(1))] +03)----ProjectionExec: expr=[CAST(a@0 + b@1 AS Int64) + 1 as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 10: Nested constants - Limitation of SimplifyExpressions ==== +# +# This test demonstrates a limitation in constant folding that affects this rule. +# +# Expression parsing (left-to-right associativity): +# - "a + 1 + 2" is parsed as "(a + 1) + 2" +# - "a + 5 + 7 * 8" is parsed as "(a + 5) + (7 * 8)" +# +# What SimplifyExpressions does: +# - "7 * 8" → "56" (two literals in one binary expr → folded) +# - "(a + 5) + 56" is NOT simplified to "a + 61" because: +# - The simplifier sees: (Expr) + 56 +# - It doesn't "look inside" the left child to find the 5 +# - Constants spread across nested additions are not gathered +# +# Expression tree for "a + 5 + 7 * 8" after partial simplification: +# +# + +# / \ +# + 56 <-- 7*8 was folded to 56 +# / \ +# a 5 <-- 5 is NOT combined with 56 +# +# Impact on RewriteAggregateWithConstant: +# - The rule sees base_expr = "(a + 5)" with constant = 56 +# - It does NOT see base_expr = "a" with constant = 61 +# - Since SUM(a) has a different base_expr than SUM((a+5) + 56), they are NOT grouped +# - Therefore, this query does NOT trigger the rewrite optimization +# +# The plan below shows all three SUMs computed separately (no Projection + Aggregate rewrite): +query TT +EXPLAIN SELECT + SUM(a), + SUM(a + 1 + 2), + SUM(a + 5 + 7 * 8) +FROM complex_test; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a), sum(__common_expr_1 AS complex_test.a + Int64(1) + Int64(2)), sum(__common_expr_1 + Int64(5) + Int64(56)) AS sum(complex_test.a + Int64(5) + Int64(7) * Int64(8))]] +02)--Projection: CAST(complex_test.a AS Int64) AS __common_expr_1 +03)----TableScan: complex_test projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a), sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5) + Int64(7) * Int64(8))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Verify correctness: +# - SUM(a) = 1+2+3 = 6 +# - SUM(a + 1 + 2) = SUM(a + 3) = 6 + 3*3 = 15 +# - SUM(a + 5 + 7*8) = SUM(a + 61) = 6 + 61*3 = 189 +query III +SELECT + SUM(a), + SUM(a + 1 + 2), + SUM(a + 5 + 7 * 8) +FROM complex_test; +---- +6 15 189 + +# ==== Test 11: Multiple groups with different base expressions ==== +# +# This test demonstrates how the rule handles multiple independent groups +# of SUM expressions, each with a different base expression. +# +# Expression parsing (all parsed left-to-right): +# - sum(a) → base: a (but no constant, so not matched by rule) +# - sum(a+1) → base: a, constant: 1 +# - sum(a+1+1) → parsed as (a+1)+1 → base: (a+1), constant: 1 +# - sum(a+1+2) → parsed as (a+1)+2 → base: (a+1), constant: 2 +# - sum(a+5+1) → parsed as (a+5)+1 → base: (a+5), constant: 1 +# - sum(a+5+2) → parsed as (a+5)+2 → base: (a+5), constant: 2 +# - sum(a+5+3) → parsed as (a+5)+3 → base: (a+5), constant: 3 +# +# Groupings by base expression: +# - Group "a": [sum(a+1)] → only 1 SUM with constant, NOT rewritten +# - Group "a + 1": [sum(a+1+1), sum(a+1+2)] → 2 SUMs, REWRITTEN +# - Group "a + 5": [sum(a+5+1), sum(a+5+2), sum(a+5+3)] → 3 SUMs, REWRITTEN +# +# Note: sum(a) is a plain SUM without a binary expression argument, +# so it doesn't match the SUM(base ± const) pattern at all. +# +# Expected rewrite: +# - sum(a), sum(a+1): computed as-is (no optimization) +# - sum(a+1+1), sum(a+1+2): use SUM(a+1) + k*COUNT(a+1) +# - sum(a+5+1), sum(a+5+2), sum(a+5+3): use SUM(a+5) + k*COUNT(a+5) +# +# With IndexMap, the groups are processed in insertion order: +# - Group "a + Int64(1)" first (from SUM(a+1+1)) +# - Group "a + Int64(5)" second (from SUM(a+5+1)) +# Non-matching expressions (SUM(a), SUM(a+1)) are passed through. +query TT +EXPLAIN SELECT + SUM(a), + SUM(a + 1), + SUM(a + 1 + 1), + SUM(a + 1 + 2), + SUM(a + 5 + 1), + SUM(a + 5 + 2), + SUM(a + 5 + 3) +FROM complex_test; +---- +logical_plan +01)Projection: sum(complex_test.a), sum(complex_test.a + Int64(1)), sum(complex_test.a + Int64(1)) + count(complex_test.a + Int64(1)) AS sum(complex_test.a + Int64(1) + Int64(1)), sum(complex_test.a + Int64(1)) + Int64(2) * count(complex_test.a + Int64(1)) AS sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5)) + count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(1)), sum(complex_test.a + Int64(5)) + Int64(2) * count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(2)), sum(complex_test.a + Int64(5)) + Int64(3) * count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(3)) +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + Int64(1)), count(__common_expr_1 AS complex_test.a + Int64(1)), sum(__common_expr_2 AS complex_test.a + Int64(5)), count(__common_expr_2 AS complex_test.a + Int64(5)), sum(__common_expr_3 AS complex_test.a)]] +03)----Projection: __common_expr_4 + Int64(1) AS __common_expr_1, __common_expr_4 + Int64(5) AS __common_expr_2, __common_expr_4 AS __common_expr_3 +04)------Projection: CAST(complex_test.a AS Int64) AS __common_expr_4 +05)--------TableScan: complex_test projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(complex_test.a)@4 as sum(complex_test.a), sum(complex_test.a + Int64(1))@0 as sum(complex_test.a + Int64(1)), sum(complex_test.a + Int64(1))@0 + count(complex_test.a + Int64(1))@1 as sum(complex_test.a + Int64(1) + Int64(1)), sum(complex_test.a + Int64(1))@0 + 2 * count(complex_test.a + Int64(1))@1 as sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5))@2 + count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(1)), sum(complex_test.a + Int64(5))@2 + 2 * count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(2)), sum(complex_test.a + Int64(5))@2 + 3 * count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(3))] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + Int64(1)), count(complex_test.a + Int64(1)), sum(complex_test.a + Int64(5)), count(complex_test.a + Int64(5)), sum(complex_test.a)] +03)----ProjectionExec: expr=[__common_expr_4@0 + 1 as __common_expr_1, __common_expr_4@0 + 5 as __common_expr_2, __common_expr_4@0 as __common_expr_3] +04)------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_4] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] + +# Verify correctness (using complex_test: a = 1, 2, 3): +# - sum(a) = 1+2+3 = 6 +# - sum(a+1) = 2+3+4 = 9 +# - sum(a+1+1) = sum(a+2) = 3+4+5 = 12 +# - sum(a+1+2) = sum(a+3) = 4+5+6 = 15 +# - sum(a+5+1) = sum(a+6) = 7+8+9 = 24 +# - sum(a+5+2) = sum(a+7) = 8+9+10 = 27 +# - sum(a+5+3) = sum(a+8) = 9+10+11 = 30 +query IIIIIII +SELECT + SUM(a), + SUM(a + 1), + SUM(a + 1 + 1), + SUM(a + 1 + 2), + SUM(a + 5 + 1), + SUM(a + 5 + 2), + SUM(a + 5 + 3) +FROM complex_test; +---- +6 9 12 15 24 27 30 + + +# ==== Test 12: Expressions with DISTINCT or FILTER (SHOULD NOT rewrite) ==== + +# SUM(DISTINCT a + 1) should NOT be rewritten because the math SUM(a) + COUNT(a) +# doesn't hold when counting distinct values of the modified expression. +query TT +EXPLAIN SELECT + SUM(DISTINCT a + 1), + SUM(DISTINCT a + 2) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(DISTINCT __common_expr_1 AS test_table.a + Int64(1)), sum(DISTINCT __common_expr_1 AS test_table.a + Int64(2))]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(DISTINCT test_table.a + Int64(1)), sum(DISTINCT test_table.a + Int64(2))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# SUM(a + 1) FILTER (WHERE a > 1) should NOT be rewritten because the filter +# applies to the entire expression, and the rule doesn't handle filters. +query TT +EXPLAIN SELECT + SUM(a + 1) FILTER (WHERE a > 1), + SUM(a + 2) FILTER (WHERE a > 1) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 + Int64(1)) FILTER (WHERE __common_expr_2) AS sum(test_table.a + Int64(1)) FILTER (WHERE test_table.a > Int64(1)), sum(__common_expr_1 + Int64(2)) FILTER (WHERE __common_expr_2) AS sum(test_table.a + Int64(2)) FILTER (WHERE test_table.a > Int64(1))]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1, test_table.a > Int32(1) AS __common_expr_2 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1)) FILTER (WHERE test_table.a > Int64(1)), sum(test_table.a + Int64(2)) FILTER (WHERE test_table.a > Int64(1))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 > 1 as __common_expr_2] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT + SUM(a + 1), + SUM(a + 2) +FROM test_table WHERE a > 3; +---- +logical_plan +01)Projection: sum(test_table.a) + count(test_table.a) AS sum(test_table.a + Int64(1)), sum(test_table.a) + Int64(2) * count(test_table.a) AS sum(test_table.a + Int64(2)) +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------Filter: test_table.a > Int32(3) +05)--------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + count(test_table.a)@1 as sum(test_table.a + Int64(1)), sum(test_table.a)@0 + 2 * count(test_table.a)@1 as sum(test_table.a + Int64(2))] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +05)--------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------FilterExec: a@0 > 3 +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 13: Constant on the left side (Commutative property) ==== + +# Test: SUM(constant + a) should be rewritten to SUM(a) + constant * COUNT(a) +query TT +EXPLAIN SELECT + SUM(5 + a) as sum_5_plus_a, + SUM(10 + a) as sum_10_plus_a +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_5_plus_a, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_10_plus_a +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_5_plus_a, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_10_plus_a] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT + SUM(5 + a) as sum_5_plus_a, + SUM(10 + a) as sum_10_plus_a +FROM test_table; +---- +40 65 + +# ==== Test 14: Constant on the left side with subtraction (SHOULD NOT rewrite) ==== + +# Test: SUM(constant - a) should NOT be rewritten because 5 - a is NOT a - 5. +# The rule intentionally only handles constant + col and col - constant. +query TT +EXPLAIN SELECT + SUM(5 - a), + SUM(10 - a) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(Int64(5) - __common_expr_1 AS test_table.a), sum(Int64(10) - __common_expr_1 AS test_table.a)]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(5) - test_table.a), sum(Int64(10) - test_table.a)] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT + SUM(5 - a), + SUM(10 - a) +FROM test_table; +---- +10 35 + +# ==== Cleanup ==== + + +statement ok +DROP TABLE test_table; + +statement ok +DROP TABLE group_test; + +statement ok +DROP TABLE nullable_test; + +statement ok +DROP TABLE complex_test; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 6f615ec391c9..617259d2e917 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -196,6 +196,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE @@ -218,6 +219,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] @@ -557,6 +559,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE @@ -579,6 +582,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c]