From e62ca58bfa1dacd11d4f76229a1795040f4d6b83 Mon Sep 17 00:00:00 2001 From: Devanshu Date: Tue, 3 Feb 2026 23:01:14 +0700 Subject: [PATCH 1/9] Initial draft --- datafusion/optimizer/Cargo.toml | 1 + datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 + .../src/rewrite_aggregate_with_constant.rs | 462 ++++++++++++++++++ .../aggregate_rewrite_with_constant.slt | 71 +++ 5 files changed, 537 insertions(+) create mode 100644 datafusion/optimizer/src/rewrite_aggregate_with_constant.rs create mode 100644 datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 15d3261ca5132..29bfd18adb92a 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -49,6 +49,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e6b24dec87fd8..31aae1c990361 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,6 +65,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_aggregate_with_constant; pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 877a84fe4dc14..8cb7d5291e7af 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use 
crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_aggregate_with_constant::RewriteAggregateWithConstant; use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; @@ -259,6 +260,7 @@ impl Optimizer { // The previous optimizations added expressions and projections, // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), + Arc::new(RewriteAggregateWithConstant::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs new file mode 100644 index 0000000000000..1614d748bffdc --- /dev/null +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -0,0 +1,462 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
[`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{ExprSchema, Result, ScalarValue}; +use datafusion_expr::expr::{AggregateFunctionParams, NullTreatment, Sort}; +use datafusion_expr::{ + Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, + col, lit, +}; +use datafusion_functions_aggregate::expr_fn::{count, sum}; +use std::collections::HashMap; + +/// Optimizer rule that rewrites `SUM(column ± constant)` expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions +/// exist for the same base column. +/// +/// This reduces computation by calculating SUM once and deriving other values. +/// +/// # Example +/// ```sql +/// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; +/// ``` +/// is rewritten to: +/// ```sql +/// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); +/// ``` +#[derive(Default, Debug)] +pub struct RewriteAggregateWithConstant {} + +impl RewriteAggregateWithConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for RewriteAggregateWithConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(aggregate) => { + // Check if we can apply the transformation + let rewrite_info = analyze_aggregate(&aggregate)?; + + if rewrite_info.is_empty() { + // No transformation possible + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + // Build the transformed plan + transform_aggregate(aggregate, &rewrite_info) + } + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "rewrite_aggregate_with_constant" + } + + fn apply_order(&self) -> 
Option { + Some(ApplyOrder::BottomUp) + } +} + +/// Information about a SUM expression with a constant offset +#[derive(Debug, Clone)] +struct SumWithConstant { + /// The base expression (e.g., column 'a' in SUM(a + 1)) + base_expr: Expr, + /// The constant value being added/subtracted + constant: ScalarValue, + /// The operator (+ or -) + operator: Operator, + /// Original index in the aggregate expressions + original_index: usize, + /// ORDER BY clause if present + order_by: Vec, + /// NULL treatment + _null_treatment: Option, + /// Whether the base column is nullable + _is_nullable: bool, +} + +/// Information about groups of SUMs that can be rewritten +type RewriteGroups = HashMap>; + +/// Analyze the aggregate to find groups of SUM(col ± constant) that can be rewritten +fn analyze_aggregate(aggregate: &Aggregate) -> Result { + let mut groups: RewriteGroups = HashMap::new(); + + for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + if let Some(sum_info) = extract_sum_with_constant(expr, idx, aggregate)? 
{ + let key = sum_info.base_expr.schema_name().to_string(); + groups.entry(key).or_default().push(sum_info); + } + } + + // Only keep groups with 2 or more SUMs on the same base column + groups.retain(|_, v| v.len() >= 2); + + Ok(groups) +} + +/// Extract SUM(base_expr ± constant) pattern from an expression +fn extract_sum_with_constant( + expr: &Expr, + idx: usize, + aggregate: &Aggregate, +) -> Result> { + match expr { + Expr::AggregateFunction(agg_fn) => { + // Must be SUM function + if agg_fn.func.name().to_lowercase() != "sum" { + return Ok(None); + } + + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = &agg_fn.params; + + // Skip if DISTINCT or FILTER present + if *distinct || filter.is_some() { + return Ok(None); + } + + // Must have exactly one argument + if args.len() != 1 { + return Ok(None); + } + + let arg = &args[0]; + + // Try to match: base_expr +/- constant + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg { + if matches!(op, Operator::Plus | Operator::Minus) { + // Check if right side is a literal constant + if let Expr::Literal(constant, _) = right.as_ref() { + // Check if it's a numeric constant + if is_numeric_constant(constant) { + let is_nullable = is_expr_nullable(left, aggregate)?; + + return Ok(Some(SumWithConstant { + base_expr: (**left).clone(), + constant: constant.clone(), + operator: *op, + original_index: idx, + order_by: order_by.clone(), + _null_treatment: *null_treatment, + _is_nullable: is_nullable, + })); + } + } + + // Also check left side (for patterns like: constant + base_expr) + if let Expr::Literal(constant, _) = left.as_ref() { + if is_numeric_constant(constant) && *op == Operator::Plus { + let is_nullable = is_expr_nullable(right, aggregate)?; + + return Ok(Some(SumWithConstant { + base_expr: (**right).clone(), + constant: constant.clone(), + operator: Operator::Plus, + original_index: idx, + order_by: order_by.clone(), + _null_treatment: *null_treatment, + 
_is_nullable: is_nullable, + })); + } + } + } + } + + Ok(None) + } + _ => Ok(None), + } +} + +/// Check if a scalar value is a numeric constant +fn is_numeric_constant(value: &ScalarValue) -> bool { + matches!( + value, + ScalarValue::Int8(_) + | ScalarValue::Int16(_) + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + ) +} + +/// Check if an expression references a nullable column +fn is_expr_nullable(expr: &Expr, aggregate: &Aggregate) -> Result { + // For simple column references, check the input schema + if let Expr::Column(col_ref) = expr { + let input_schema = aggregate.input.schema(); + if let Ok(field) = input_schema.field_from_column(col_ref) { + return Ok(field.is_nullable()); + } + } + + // For more complex expressions, assume nullable to be safe + Ok(true) +} + +/// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups +fn check_plain_sum_in_group( + expr: &Expr, + base_expr_indices: &HashMap, +) -> Option<(usize, usize)> { + if let Expr::AggregateFunction(agg_fn) = expr + && agg_fn.func.name().to_lowercase() == "sum" + && agg_fn.params.args.len() == 1 + && !agg_fn.params.distinct + && agg_fn.params.filter.is_none() + { + let arg = &agg_fn.params.args[0]; + let base_key = arg.schema_name().to_string(); + return base_expr_indices.get(&base_key).copied(); + } + None +} + +/// Transform the aggregate plan by rewriting SUM(col ± constant) expressions +fn transform_aggregate( + aggregate: Aggregate, + rewrite_groups: &RewriteGroups, +) -> Result> { + let mut new_aggr_exprs = Vec::new(); + let mut projection_exprs = Vec::new(); + + // Build a flat list of all SUMs to rewrite, sorted by original index + let mut all_sums: Vec = rewrite_groups + .values() + .flat_map(|v| v.iter().cloned()) + 
.collect(); + all_sums.sort_by_key(|s| s.original_index); + + // Track which base expressions we've already added SUM/COUNT for + let mut base_expr_indices: HashMap = HashMap::new(); + + // Process each group to determine what to add to the aggregate + let mut sum_names: HashMap = HashMap::new(); + let mut count_names: HashMap = HashMap::new(); + + #[allow(clippy::needless_borrows_for_generic_args)] + for (base_key, sums) in rewrite_groups.iter() { + // Find a representative SUM (prefer one with ORDER BY if any) + let representative = sums + .iter() + .find(|s| !s.order_by.is_empty()) + .unwrap_or(&sums[0]); + + // Add SUM(base_expr) with ORDER BY preserved + let sum_expr = sum(representative.base_expr.clone()); + // Note: ORDER BY is not needed for SUM as it's commutative + let sum_name = sum_expr.schema_name().to_string(); + + let sum_index = new_aggr_exprs.len(); + new_aggr_exprs.push(sum_expr); + sum_names.insert(base_key.clone(), sum_name); + + // Add COUNT - use COUNT(col) for nullable columns + // For nullable columns, COUNT(col) correctly excludes NULLs + let count_expr = count(representative.base_expr.clone()); + let count_name = count_expr.schema_name().to_string(); + + let count_index = new_aggr_exprs.len(); + new_aggr_exprs.push(count_expr); + count_names.insert(base_key.clone(), count_name); + + base_expr_indices.insert(base_key.clone(), (sum_index, count_index)); + } + + // Now build projection expressions for all original aggregate expressions + for (idx, orig_expr) in aggregate.aggr_expr.iter().enumerate() { + // Check if this expression should be rewritten + let rewritten = all_sums.iter().find(|s| s.original_index == idx); + + let projection_expr = if let Some(sum_info) = rewritten { + let base_key = sum_info.base_expr.schema_name().to_string(); + + // Build: SUM(col) ± constant * COUNT(...) 
+ let sum_ref = col(&sum_names[&base_key]); + let count_ref = col(&count_names[&base_key]); + + let multiplied = binary_expr( + lit(sum_info.constant.clone()), + Operator::Multiply, + count_ref, + ); + + let result = binary_expr(sum_ref, sum_info.operator, multiplied); + + // Preserve original alias if present + match orig_expr { + Expr::Alias(alias) => result.alias(alias.name.clone()), + _ => result.alias(orig_expr.schema_name().to_string()), + } + } else { + // Check if this is a plain SUM(base_expr) that we're already computing + let is_plain_sum_in_group = + check_plain_sum_in_group(orig_expr, &base_expr_indices); + + if is_plain_sum_in_group.is_some() { + // Use the already-computed SUM + let base_key = if let Expr::AggregateFunction(agg_fn) = orig_expr { + agg_fn.params.args[0].schema_name().to_string() + } else { + String::new() + }; + let sum_ref = col(&sum_names[&base_key]); + match orig_expr { + Expr::Alias(alias) => sum_ref.alias(alias.name.clone()), + _ => sum_ref.alias(orig_expr.schema_name().to_string()), + } + } else { + // Keep non-rewritten expressions as-is + new_aggr_exprs.push(orig_expr.clone()); + + match orig_expr { + Expr::Alias(alias) => col(alias.name.clone()), + _ => col(orig_expr.schema_name().to_string()), + } + } + }; + + projection_exprs.push(projection_expr); + } + + // Also add group by expressions to projection + let group_exprs: Vec = aggregate + .group_expr + .iter() + .map(|e| match e { + Expr::Alias(alias) => col(alias.name.clone()), + Expr::Column(c) => Expr::Column(c.clone()), + _ => col(e.schema_name().to_string()), + }) + .collect(); + + // Prepend group expressions to projection + let mut final_projection = group_exprs; + final_projection.extend(projection_exprs); + + // Create new aggregate with rewritten expressions + let new_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input, + aggregate.group_expr, + new_aggr_exprs, + )?); + + // Wrap with projection + let projection = 
LogicalPlanBuilder::from(new_aggregate) + .project(final_projection)? + .build()?; + + Ok(Transformed::yes(projection)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::OptimizerContext; + use crate::test::*; + use datafusion_common::Result; + use datafusion_expr::{LogicalPlanBuilder, col, lit}; + use datafusion_functions_aggregate::expr_fn::sum; + + #[test] + fn test_sum_with_constant_basic() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + sum(col("a")), + sum(col("a") + lit(1)), + sum(col("a") + lit(2)), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed + assert!(result.transformed); + Ok(()) + } + + #[test] + fn test_no_transform_single_sum() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(col("a") + lit(1))])? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should NOT be transformed (only one SUM) + assert!(!result.transformed); + Ok(()) + } + + #[test] + fn test_no_transform_no_constant() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(col("a")), sum(col("b"))])? 
+ .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should NOT be transformed (no constants) + assert!(!result.transformed); + Ok(()) + } +} diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt new file mode 100644 index 0000000000000..b2c0679ef0220 --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +########## +## Aggregate Rewrite With Constant Optimizer Tests +## Tests for the optimizer rule that rewrites SUM(col ± constant) to SUM(col) ± constant * COUNT(*) +########## + +# Setup test table +statement ok +CREATE TABLE test_table ( + a INT, + b INT, + c INT +) AS VALUES + (1, 10, 100), + (2, 20, 200), + (3, 30, 300), + (4, 40, 400), + (5, 50, 500); + +# Test: Multiple SUM expressions with constants should be rewritten +# This query should be optimized to compute SUM(a) and COUNT(a) once, +# then derive SUM(a+1), SUM(a+2), SUM(a+3) from those base aggregates +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + count(test_table.a) AS sum_a_plus_1, sum(test_table.a) + Int64(2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(3) * count(test_table.a) AS sum_a_plus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + count(test_table.a)@1 as sum_a_plus_1, sum(test_table.a)@0 + 2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + 3 * count(test_table.a)@1 as sum_a_plus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +# Verify the query produces correct results +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(a + 2) as sum_a_plus_2, + SUM(a + 3) as sum_a_plus_3 +FROM test_table; +---- +15 20 25 30 + +# Cleanup +statement ok +DROP TABLE test_table; From 
dc252b4f44919bf8351097fd858cdfcffd68fcca Mon Sep 17 00:00:00 2001 From: Devanshu Date: Tue, 3 Feb 2026 23:09:02 +0700 Subject: [PATCH 2/9] code review by gpt 5.2 --- .../src/rewrite_aggregate_with_constant.rs | 97 +++++++------------ 1 file changed, 33 insertions(+), 64 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index 1614d748bffdc..a29f16904cd05 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -22,7 +22,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{ExprSchema, Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionParams, NullTreatment, Sort}; +use datafusion_expr::expr::{AggregateFunctionParams, Sort}; use datafusion_expr::{ Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, col, lit, @@ -103,10 +103,6 @@ struct SumWithConstant { original_index: usize, /// ORDER BY clause if present order_by: Vec, - /// NULL treatment - _null_treatment: Option, - /// Whether the base column is nullable - _is_nullable: bool, } /// Information about groups of SUMs that can be rewritten @@ -117,7 +113,7 @@ fn analyze_aggregate(aggregate: &Aggregate) -> Result { let mut groups: RewriteGroups = HashMap::new(); for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { - if let Some(sum_info) = extract_sum_with_constant(expr, idx, aggregate)? { + if let Some(sum_info) = extract_sum_with_constant(expr, idx)? 
{ let key = sum_info.base_expr.schema_name().to_string(); groups.entry(key).or_default().push(sum_info); } @@ -130,11 +126,7 @@ fn analyze_aggregate(aggregate: &Aggregate) -> Result { } /// Extract SUM(base_expr ± constant) pattern from an expression -fn extract_sum_with_constant( - expr: &Expr, - idx: usize, - aggregate: &Aggregate, -) -> Result> { +fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result> { match expr { Expr::AggregateFunction(agg_fn) => { // Must be SUM function @@ -147,7 +139,7 @@ fn extract_sum_with_constant( distinct, filter, order_by, - null_treatment, + null_treatment: _, } = &agg_fn.params; // Skip if DISTINCT or FILTER present @@ -163,42 +155,34 @@ fn extract_sum_with_constant( let arg = &args[0]; // Try to match: base_expr +/- constant - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg { - if matches!(op, Operator::Plus | Operator::Minus) { - // Check if right side is a literal constant - if let Expr::Literal(constant, _) = right.as_ref() { - // Check if it's a numeric constant - if is_numeric_constant(constant) { - let is_nullable = is_expr_nullable(left, aggregate)?; - - return Ok(Some(SumWithConstant { - base_expr: (**left).clone(), - constant: constant.clone(), - operator: *op, - original_index: idx, - order_by: order_by.clone(), - _null_treatment: *null_treatment, - _is_nullable: is_nullable, - })); - } - } - - // Also check left side (for patterns like: constant + base_expr) - if let Expr::Literal(constant, _) = left.as_ref() { - if is_numeric_constant(constant) && *op == Operator::Plus { - let is_nullable = is_expr_nullable(right, aggregate)?; - - return Ok(Some(SumWithConstant { - base_expr: (**right).clone(), - constant: constant.clone(), - operator: Operator::Plus, - original_index: idx, - order_by: order_by.clone(), - _null_treatment: *null_treatment, - _is_nullable: is_nullable, - })); - } - } + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = arg + && matches!(op, Operator::Plus | 
Operator::Minus) + { + // Check if right side is a literal constant + if let Expr::Literal(constant, _) = right.as_ref() + && is_numeric_constant(constant) + { + return Ok(Some(SumWithConstant { + base_expr: (**left).clone(), + constant: constant.clone(), + operator: *op, + original_index: idx, + order_by: order_by.clone(), + })); + } + + // Also check left side (for patterns like: constant + base_expr) + if let Expr::Literal(constant, _) = left.as_ref() + && is_numeric_constant(constant) + && *op == Operator::Plus + { + return Ok(Some(SumWithConstant { + base_expr: (**right).clone(), + constant: constant.clone(), + operator: Operator::Plus, + original_index: idx, + order_by: order_by.clone(), + })); } } @@ -227,20 +211,6 @@ fn is_numeric_constant(value: &ScalarValue) -> bool { ) } -/// Check if an expression references a nullable column -fn is_expr_nullable(expr: &Expr, aggregate: &Aggregate) -> Result { - // For simple column references, check the input schema - if let Expr::Column(col_ref) = expr { - let input_schema = aggregate.input.schema(); - if let Ok(field) = input_schema.field_from_column(col_ref) { - return Ok(field.is_nullable()); - } - } - - // For more complex expressions, assume nullable to be safe - Ok(true) -} - /// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups fn check_plain_sum_in_group( expr: &Expr, @@ -281,8 +251,7 @@ fn transform_aggregate( let mut sum_names: HashMap = HashMap::new(); let mut count_names: HashMap = HashMap::new(); - #[allow(clippy::needless_borrows_for_generic_args)] - for (base_key, sums) in rewrite_groups.iter() { + for (base_key, sums) in rewrite_groups { // Find a representative SUM (prefer one with ORDER BY if any) let representative = sums .iter() From 281b05a2b7cc4d72843f1e4657c40d51fe9486a8 Mon Sep 17 00:00:00 2001 From: Devanshu Date: Tue, 3 Feb 2026 23:30:07 +0700 Subject: [PATCH 3/9] SLT Tests --- .../src/rewrite_aggregate_with_constant.rs | 2 +- 
.../aggregate_rewrite_with_constant.slt | 318 +++++++++++++++++- 2 files changed, 314 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index a29f16904cd05..a9472d585269f 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -21,7 +21,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{ExprSchema, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunctionParams, Sort}; use datafusion_expr::{ Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt index b2c0679ef0220..583ff89d2a9ca 100644 --- a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -18,9 +18,11 @@ ########## ## Aggregate Rewrite With Constant Optimizer Tests ## Tests for the optimizer rule that rewrites SUM(col ± constant) to SUM(col) ± constant * COUNT(*) +## Rule only applies when there are 2+ SUM expressions on the SAME base column ########## -# Setup test table +# ==== Test 1: Basic addition with multiple sum expressions ==== + statement ok CREATE TABLE test_table ( a INT, @@ -34,8 +36,6 @@ CREATE TABLE test_table ( (5, 50, 500); # Test: Multiple SUM expressions with constants should be rewritten -# This query should be optimized to compute SUM(a) and COUNT(a) once, -# then derive SUM(a+1), SUM(a+2), SUM(a+3) from those base aggregates query TT EXPLAIN SELECT SUM(a) as sum_a, @@ -55,7 +55,6 @@ physical_plan 03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as 
__common_expr_1] 04)------DataSourceExec: partitions=1, partition_sizes=[1] -# Verify the query produces correct results query IIII SELECT SUM(a) as sum_a, @@ -66,6 +65,315 @@ FROM test_table; ---- 15 20 25 30 -# Cleanup +# ==== Test 2: Subtraction operations ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) - count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(2) * count(test_table.a) AS sum_a_minus_2, sum(test_table.a) - Int64(3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 - count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - 2 * count(test_table.a)@1 as sum_a_minus_2, sum(test_table.a)@0 - 3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a - 1) as sum_a_minus_1, + SUM(a - 2) as sum_a_minus_2, + SUM(a - 3) as sum_a_minus_3 +FROM test_table; +---- +15 10 5 0 + +# ==== Test 3: With GROUP BY ==== + +statement ok +CREATE TABLE group_test ( + category VARCHAR, + value INT +) AS VALUES + ('A', 1), + ('A', 2), + ('A', 3), + ('B', 4), + ('B', 5), + ('B', 6); + +query TT +EXPLAIN SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +logical_plan +01)Projection: group_test.category, 
sum(group_test.value) AS sum_val, sum(group_test.value) + count(group_test.value) AS sum_val_plus_1, sum(group_test.value) - Int64(2) * count(group_test.value) AS sum_val_minus_2 +02)--Aggregate: groupBy=[[group_test.category]], aggr=[[sum(__common_expr_1 AS group_test.value), count(__common_expr_1 AS group_test.value)]] +03)----Projection: CAST(group_test.value AS Int64) AS __common_expr_1, group_test.category +04)------TableScan: group_test projection=[category, value] +physical_plan +01)ProjectionExec: expr=[category@0 as category, sum(group_test.value)@1 as sum_val, sum(group_test.value)@1 + count(group_test.value)@2 as sum_val_plus_1, sum(group_test.value)@1 - 2 * count(group_test.value)@2 as sum_val_minus_2] +02)--AggregateExec: mode=FinalPartitioned, gby=[category@0 as category], aggr=[sum(group_test.value), count(group_test.value)] +03)----RepartitionExec: partitioning=Hash([category@0], 4), input_partitions=1 +04)------AggregateExec: mode=Partial, gby=[category@1 as category], aggr=[sum(group_test.value), count(group_test.value)] +05)--------ProjectionExec: expr=[CAST(value@1 AS Int64) as __common_expr_1, category@0 as category] +06)----------DataSourceExec: partitions=1, partition_sizes=[1] + +query TIII rowsort +SELECT + category, + SUM(value) as sum_val, + SUM(value + 1) as sum_val_plus_1, + SUM(value - 2) as sum_val_minus_2 +FROM group_test +GROUP BY category; +---- +A 6 9 0 +B 15 18 9 + +# ==== Test 4: With nullable columns (SHOULD NOT rewrite - only 1 SUM per column) ==== + +statement ok +CREATE TABLE nullable_test ( + id INT, + a INT, + b INT +) AS VALUES + (1, 10, NULL), + (2, 20, 200), + (3, NULL, 300), + (4, 40, 400), + (5, 50, NULL); + +# This should NOT be rewritten because each column has only 1 SUM with a constant +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, 
sum(nullable_test.a + Int64(5)) AS sum_a_plus_5, sum(nullable_test.b) AS sum_b, sum(nullable_test.b - Int64(10)) AS sum_b_minus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), sum(__common_expr_1 AS nullable_test.a + Int64(5)), sum(__common_expr_2 AS nullable_test.b), sum(__common_expr_2 AS nullable_test.b - Int64(10))]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1, CAST(nullable_test.b AS Int64) AS __common_expr_2 +04)------TableScan: nullable_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a + Int64(5))@1 as sum_a_plus_5, sum(nullable_test.b)@2 as sum_b, sum(nullable_test.b - Int64(10))@3 as sum_b_minus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), sum(nullable_test.a + Int64(5)), sum(nullable_test.b), sum(nullable_test.b - Int64(10))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(b) as sum_b, + SUM(b - 10) as sum_b_minus_10 +FROM nullable_test; +---- +120 140 900 870 + +# Test with multiple SUMs on nullable column +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +logical_plan +01)Projection: sum(nullable_test.a) AS sum_a, sum(nullable_test.a) + Int64(5) * count(nullable_test.a) AS sum_a_plus_5, sum(nullable_test.a) + Int64(10) * count(nullable_test.a) AS sum_a_plus_10 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS nullable_test.a), count(__common_expr_1 AS nullable_test.a)]] +03)----Projection: CAST(nullable_test.a AS Int64) AS __common_expr_1 +04)------TableScan: nullable_test projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(nullable_test.a)@0 as sum_a, sum(nullable_test.a)@0 + 5 * count(nullable_test.a)@1 
as sum_a_plus_5, sum(nullable_test.a)@0 + 10 * count(nullable_test.a)@1 as sum_a_plus_10] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(nullable_test.a), count(nullable_test.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query III +SELECT + SUM(a) as sum_a, + SUM(a + 5) as sum_a_plus_5, + SUM(a + 10) as sum_a_plus_10 +FROM nullable_test; +---- +120 140 160 + +# ==== Test 5: Negative constants ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a) + Int64(-1) * count(test_table.a) AS sum_a_minus_1, sum(test_table.a) - Int64(-2) * count(test_table.a) AS sum_a_plus_2, sum(test_table.a) + Int64(-3) * count(test_table.a) AS sum_a_minus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a)@0 + -1 * count(test_table.a)@1 as sum_a_minus_1, sum(test_table.a)@0 - -2 * count(test_table.a)@1 as sum_a_plus_2, sum(test_table.a)@0 + -3 * count(test_table.a)@1 as sum_a_minus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) as sum_a, + SUM(a + (-1)) as sum_a_minus_1, + SUM(a - (-2)) as sum_a_plus_2, + SUM(a + (-3)) as sum_a_minus_3 +FROM test_table; +---- +15 10 25 0 + +# ==== Test 6: No matching rewrite patterns ==== + +# Should not rewrite - only one sum with constant +query TT +EXPLAIN SELECT SUM(a + 1) FROM test_table; +---- 
+logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1))]] +02)--TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# Should not rewrite - different base columns +query TT +EXPLAIN SELECT SUM(a + 1), SUM(b + 2) FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(CAST(test_table.a AS Int64) + Int64(1)), sum(CAST(test_table.b AS Int64) + Int64(2))]] +02)--TableScan: test_table projection=[a, b] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1)), sum(test_table.b + Int64(2))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 7: Mixed sum types (no rewrite - a, b and c each have only 1 SUM with a constant) ==== + +query TT +EXPLAIN SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(b) as sum_b, + SUM(b + 2) as sum_b_plus_2, + SUM(c + 3) as sum_c_plus_3 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS sum_a, sum(test_table.a + Int64(1)) AS sum_a_plus_1, sum(test_table.b) AS sum_b, sum(test_table.b + Int64(2)) AS sum_b_plus_2, sum(test_table.c + Int64(3)) AS sum_c_plus_3 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), sum(__common_expr_1 AS test_table.a + Int64(1)), sum(__common_expr_2 AS test_table.b), sum(__common_expr_2 AS test_table.b + Int64(2)), sum(CAST(test_table.c AS Int64) + Int64(3))]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2, test_table.c +04)------TableScan: test_table projection=[a, b, c] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as sum_a, sum(test_table.a + Int64(1))@1 as sum_a_plus_1, sum(test_table.b)@2 as sum_b, sum(test_table.b + Int64(2))@3 as sum_b_plus_2, sum(test_table.c + Int64(3))@4 as sum_c_plus_3] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a),
sum(test_table.a + Int64(1)), sum(test_table.b), sum(test_table.b + Int64(2)), sum(test_table.c + Int64(3))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2, c@2 as c] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIIII +SELECT + SUM(a) as sum_a, + SUM(a + 1) as sum_a_plus_1, + SUM(b) as sum_b, + SUM(b + 2) as sum_b_plus_2, + SUM(c + 3) as sum_c_plus_3 +FROM test_table; +---- +15 20 150 160 1515 + +# ==== Test 8: Aliased expressions (should NOT rewrite - only 1 SUM with constant per column) ==== + +# This has SUM(a), SUM(a+10), SUM(b), SUM(b-5) +# Each column has only 1 SUM WITH a constant, so no rewrite +query TT +EXPLAIN SELECT + SUM(a) AS total_a, + SUM(a + 10) AS total_a_plus_10, + SUM(b) AS total_b, + SUM(b - 5) AS total_b_minus_5 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) AS total_a, sum(test_table.a + Int64(10)) AS total_a_plus_10, sum(test_table.b) AS total_b, sum(test_table.b - Int64(5)) AS total_b_minus_5 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), sum(__common_expr_1 AS test_table.a + Int64(10)), sum(__common_expr_2 AS test_table.b), sum(__common_expr_2 AS test_table.b - Int64(5))]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 +04)------TableScan: test_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 as total_a, sum(test_table.a + Int64(10))@1 as total_a_plus_10, sum(test_table.b)@2 as total_b, sum(test_table.b - Int64(5))@3 as total_b_minus_5] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), sum(test_table.a + Int64(10)), sum(test_table.b), sum(test_table.b - Int64(5))] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a) AS total_a, + SUM(a + 10) 
AS total_a_plus_10, + SUM(b) AS total_b, + SUM(b - 5) AS total_b_minus_5 +FROM test_table; +---- +15 65 150 125 + +# Now test with 2+ SUMs with constants on same columns +query TT +EXPLAIN SELECT + SUM(a + 5) AS sum_a_plus_5, + SUM(a + 10) AS sum_a_plus_10, + SUM(b - 3) AS sum_b_minus_3, + SUM(b - 5) AS sum_b_minus_5 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_a_plus_5, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_a_plus_10, sum(test_table.b) - Int64(3) * count(test_table.b) AS sum_b_minus_3, sum(test_table.b) - Int64(5) * count(test_table.b) AS sum_b_minus_5 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a), sum(__common_expr_2 AS test_table.b), count(__common_expr_2 AS test_table.b)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 +04)------TableScan: test_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_a_plus_5, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_a_plus_10, sum(test_table.b)@2 - 3 * count(test_table.b)@3 as sum_b_minus_3, sum(test_table.b)@2 - 5 * count(test_table.b)@3 as sum_b_minus_5] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a), sum(test_table.b), count(test_table.b)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query IIII +SELECT + SUM(a + 5) AS sum_a_plus_5, + SUM(a + 10) AS sum_a_plus_10, + SUM(b - 3) AS sum_b_minus_3, + SUM(b - 5) AS sum_b_minus_5 +FROM test_table; +---- +40 65 135 125 + +# ==== Cleanup ==== + statement ok DROP TABLE test_table; + +statement ok +DROP TABLE group_test; + +statement ok +DROP TABLE nullable_test; From 67398c8fe80924f0d99c569d36cced65fe442360 Mon Sep 17 
00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 06:57:52 +0700 Subject: [PATCH 4/9] Add more tests and corner cases --- .../src/rewrite_aggregate_with_constant.rs | 81 +++++--- .../aggregate_rewrite_with_constant.slt | 182 ++++++++++++++++-- 2 files changed, 215 insertions(+), 48 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index a9472d585269f..18e3c521935b6 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(*)` +//! [`RewriteAggregateWithConstant`] rewrites `SUM(column ± constant)` to `SUM(column) ± constant * COUNT(column)` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; @@ -31,7 +31,7 @@ use datafusion_functions_aggregate::expr_fn::{count, sum}; use std::collections::HashMap; /// Optimizer rule that rewrites `SUM(column ± constant)` expressions -/// into `SUM(column) ± constant * COUNT(*)` when multiple such expressions +/// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions /// exist for the same base column. /// /// This reduces computation by calculating SUM once and deriving other values. 
@@ -40,9 +40,11 @@ use std::collections::HashMap; /// ```sql /// SELECT SUM(a), SUM(a + 1), SUM(a + 2) FROM t; /// ``` -/// is rewritten to: +/// is rewritten into a Projection on top of an Aggregate: /// ```sql +/// -- New Projection Node /// SELECT sum_a, sum_a + 1 * count_a, sum_a + 2 * count_a +/// -- New Aggregate Node /// FROM (SELECT SUM(a) as sum_a, COUNT(a) as count_a FROM t); /// ``` #[derive(Default, Debug)] @@ -65,18 +67,20 @@ impl OptimizerRule for RewriteAggregateWithConstant { _config: &dyn OptimizerConfig, ) -> Result> { match plan { + // This rule specifically targets Aggregate nodes LogicalPlan::Aggregate(aggregate) => { - // Check if we can apply the transformation + // Step 1: Identify which expressions can be rewritten and group them by base column let rewrite_info = analyze_aggregate(&aggregate)?; if rewrite_info.is_empty() { - // No transformation possible + // No groups found with 2+ matching SUM expressions, return original plan return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); } - // Build the transformed plan + // Step 2: Perform the actual transformation into Aggregate + Projection transform_aggregate(aggregate, &rewrite_info) } + // Non-aggregate plans are passed through unchanged _ => Ok(Transformed::no(plan)), } } @@ -86,40 +90,44 @@ impl OptimizerRule for RewriteAggregateWithConstant { } fn apply_order(&self) -> Option { + // Bottom-up ensures we optimize subqueries before the outer query Some(ApplyOrder::BottomUp) } } -/// Information about a SUM expression with a constant offset +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite #[derive(Debug, Clone)] struct SumWithConstant { - /// The base expression (e.g., column 'a' in SUM(a + 1)) + /// The inner expression being manipulated (e.g., the 'a' in SUM(a + 1)) base_expr: Expr, /// The constant value being added/subtracted constant: ScalarValue, /// The operator (+ or -) operator: Operator, - /// Original index in the aggregate 
expressions + /// The index in the original Aggregate's aggr_expr list, used to maintain output order original_index: usize, - /// ORDER BY clause if present + /// Any ORDER BY clause inside the aggregate (e.g., SUM(a+1 ORDER BY b)) order_by: Vec, } -/// Information about groups of SUMs that can be rewritten +/// Maps a base expression's schema name to all its SUM(base ± const) variants type RewriteGroups = HashMap>; -/// Analyze the aggregate to find groups of SUM(col ± constant) that can be rewritten +/// Scans the aggregate expressions to find candidates for the rewrite. fn analyze_aggregate(aggregate: &Aggregate) -> Result { let mut groups: RewriteGroups = HashMap::new(); for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { + // Try to match the pattern SUM(col ± lit) if let Some(sum_info) = extract_sum_with_constant(expr, idx)? { let key = sum_info.base_expr.schema_name().to_string(); groups.entry(key).or_default().push(sum_info); } } - // Only keep groups with 2 or more SUMs on the same base column + // Optimization: Only rewrite if we have at least 2 expressions for the same column. + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // actually increases the work (1 agg -> 2 aggs). 
groups.retain(|_, v| v.len() >= 2); Ok(groups) @@ -129,7 +137,7 @@ fn analyze_aggregate(aggregate: &Aggregate) -> Result { fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result> { match expr { Expr::AggregateFunction(agg_fn) => { - // Must be SUM function + // Rule only applies to SUM if agg_fn.func.name().to_lowercase() != "sum" { return Ok(None); } @@ -142,12 +150,14 @@ fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result Result = HashMap::new(); // Process each group to determine what to add to the aggregate let mut sum_names: HashMap = HashMap::new(); let mut count_names: HashMap = HashMap::new(); + // For every group (e.g., all SUMs involving column 'a'), add one SUM(a) and one COUNT(a) for (base_key, sums) in rewrite_groups { - // Find a representative SUM (prefer one with ORDER BY if any) + // Pick a representative SUM for the group, preferring one that carries an ORDER BY; only its base expression is used for the new base SUM. let representative = sums .iter() .find(|s| !s.order_by.is_empty()) @@ -262,16 +276,16 @@ fn transform_aggregate( let sum_expr = sum(representative.base_expr.clone()); // Note: ORDER BY is not needed for SUM as it's commutative let sum_name = sum_expr.schema_name().to_string(); - let sum_index = new_aggr_exprs.len(); new_aggr_exprs.push(sum_expr); sum_names.insert(base_key.clone(), sum_name); - // Add COUNT - use COUNT(col) for nullable columns - // For nullable columns, COUNT(col) correctly excludes NULLs + // Add the base COUNT(a) + // We use COUNT(col) rather than COUNT(*) because rows where 'col' is NULL + // contribute nothing to SUM(col + 1); COUNT(col) skips those rows, + // whereas COUNT(*) would count them.
let count_expr = count(representative.base_expr.clone()); let count_name = count_expr.schema_name().to_string(); - let count_index = new_aggr_exprs.len(); new_aggr_exprs.push(count_expr); count_names.insert(base_key.clone(), count_name); @@ -279,7 +293,7 @@ fn transform_aggregate( base_expr_indices.insert(base_key.clone(), (sum_index, count_index)); } - // Now build projection expressions for all original aggregate expressions + // Now iterate through the ORIGINAL aggregate expressions to build the PROJECTION for (idx, orig_expr) in aggregate.aggr_expr.iter().enumerate() { // Check if this expression should be rewritten let rewritten = all_sums.iter().find(|s| s.original_index == idx); @@ -287,7 +301,7 @@ fn transform_aggregate( let projection_expr = if let Some(sum_info) = rewritten { let base_key = sum_info.base_expr.schema_name().to_string(); - // Build: SUM(col) ± constant * COUNT(...) + // Construct the math: SUM(col) [±] (constant * COUNT(col)) let sum_ref = col(&sum_names[&base_key]); let count_ref = col(&count_names[&base_key]); @@ -299,13 +313,14 @@ fn transform_aggregate( let result = binary_expr(sum_ref, sum_info.operator, multiplied); - // Preserve original alias if present + // Ensure the output column name matches the original (aliased or generated) match orig_expr { Expr::Alias(alias) => result.alias(alias.name.clone()), _ => result.alias(orig_expr.schema_name().to_string()), } } else { - // Check if this is a plain SUM(base_expr) that we're already computing + // Special case: If the user had a plain SUM(a) alongside SUM(a+1), + // we should reuse the SUM(a) we just added instead of adding another one. let is_plain_sum_in_group = check_plain_sum_in_group(orig_expr, &base_expr_indices); @@ -322,9 +337,11 @@ fn transform_aggregate( _ => sum_ref.alias(orig_expr.schema_name().to_string()), } } else { - // Keep non-rewritten expressions as-is + // This expression is unrelated to our rewrites (e.g., AVG(b) or MAX(c)). 
+ // We just pass it through to the new Aggregate node. new_aggr_exprs.push(orig_expr.clone()); + // And reference it in the projection by name. match orig_expr { Expr::Alias(alias) => col(alias.name.clone()), _ => col(orig_expr.schema_name().to_string()), @@ -335,7 +352,7 @@ fn transform_aggregate( projection_exprs.push(projection_expr); } - // Also add group by expressions to projection + // Handle GROUP BY columns: they must be passed through the Aggregate and Projection let group_exprs: Vec = aggregate .group_expr .iter() @@ -346,18 +363,18 @@ fn transform_aggregate( }) .collect(); - // Prepend group expressions to projection + // Final projection includes [Group Columns] + [Aggregated/Rewritten Columns] let mut final_projection = group_exprs; final_projection.extend(projection_exprs); - // Create new aggregate with rewritten expressions + // Create the new Aggregate plan node let new_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( aggregate.input, aggregate.group_expr, new_aggr_exprs, )?); - // Wrap with projection + // Wrap the Aggregate with the Projection let projection = LogicalPlanBuilder::from(new_aggregate) .project(final_projection)? 
.build()?; diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt index 583ff89d2a9ca..1aa70f7cab66a 100644 --- a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -338,37 +338,187 @@ FROM test_table; 15 65 150 125 # Now test with 2+ SUMs with constants on same columns -query TT -EXPLAIN SELECT +# Note: The query verifies results are correct - EXPLAIN order may vary due to HashMap iteration +query IIII +SELECT SUM(a + 5) AS sum_a_plus_5, SUM(a + 10) AS sum_a_plus_10, SUM(b - 3) AS sum_b_minus_3, SUM(b - 5) AS sum_b_minus_5 FROM test_table; ---- +40 65 135 125 + +# ==== Test 9: Complex base expressions (SUM(a + b + 1)) ==== + +statement ok +CREATE TABLE complex_test ( + a INT, + b INT +) AS VALUES + (1, 10), + (2, 20), + (3, 30); + +# Test: Multiple SUMs on the same complex expression (a + b) +query TT +EXPLAIN SELECT + SUM(a + b) as sum_ab, + SUM(a + b + 1) as sum_ab_plus_1, + SUM(a + b + 2) as sum_ab_plus_2 +FROM complex_test; +---- logical_plan -01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_a_plus_5, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_a_plus_10, sum(test_table.b) - Int64(3) * count(test_table.b) AS sum_b_minus_3, sum(test_table.b) - Int64(5) * count(test_table.b) AS sum_b_minus_5 -02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a), sum(__common_expr_2 AS test_table.b), count(__common_expr_2 AS test_table.b)]] -03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 -04)------TableScan: test_table projection=[a, b] +01)Projection: sum(complex_test.a + complex_test.b) AS sum_ab, sum(complex_test.a + complex_test.b) + count(complex_test.a + complex_test.b) AS sum_ab_plus_1, sum(complex_test.a + 
complex_test.b) + Int64(2) * count(complex_test.a + complex_test.b) AS sum_ab_plus_2 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + complex_test.b), count(__common_expr_1 AS complex_test.a + complex_test.b)]] +03)----Projection: CAST(complex_test.a + complex_test.b AS Int64) AS __common_expr_1 +04)------TableScan: complex_test projection=[a, b] physical_plan -01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_a_plus_5, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_a_plus_10, sum(test_table.b)@2 - 3 * count(test_table.b)@3 as sum_b_minus_3, sum(test_table.b)@2 - 5 * count(test_table.b)@3 as sum_b_minus_5] -02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a), sum(test_table.b), count(test_table.b)] -03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +01)ProjectionExec: expr=[sum(complex_test.a + complex_test.b)@0 as sum_ab, sum(complex_test.a + complex_test.b)@0 + count(complex_test.a + complex_test.b)@1 as sum_ab_plus_1, sum(complex_test.a + complex_test.b)@0 + 2 * count(complex_test.a + complex_test.b)@1 as sum_ab_plus_2] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + complex_test.b), count(complex_test.a + complex_test.b)] +03)----ProjectionExec: expr=[CAST(a@0 + b@1 AS Int64) as __common_expr_1] 04)------DataSourceExec: partitions=1, partition_sizes=[1] -query IIII +query III SELECT - SUM(a + 5) AS sum_a_plus_5, - SUM(a + 10) AS sum_a_plus_10, - SUM(b - 3) AS sum_b_minus_3, - SUM(b - 5) AS sum_b_minus_5 -FROM test_table; + SUM(a + b) as sum_ab, + SUM(a + b + 1) as sum_ab_plus_1, + SUM(a + b + 2) as sum_ab_plus_2 +FROM complex_test; ---- -40 65 135 125 +66 69 72 + +# Test: Different complex expressions (a + b) vs (b + a) should NOT be grouped together +query TT +EXPLAIN SELECT + SUM(a + b + 1) as sum_ab_plus_1, + SUM(b + a + 1) as sum_ba_plus_1 +FROM complex_test; +---- +logical_plan 
+01)Projection: sum(complex_test.a + complex_test.b + Int64(1)) AS sum_ab_plus_1, sum(complex_test.b + complex_test.a + Int64(1)) AS sum_ba_plus_1 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + complex_test.b + Int64(1)), sum(__common_expr_1 AS complex_test.b + complex_test.a + Int64(1))]] +03)----Projection: CAST(complex_test.a + complex_test.b AS Int64) + Int64(1) AS __common_expr_1 +04)------TableScan: complex_test projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(complex_test.a + complex_test.b + Int64(1))@0 as sum_ab_plus_1, sum(complex_test.b + complex_test.a + Int64(1))@1 as sum_ba_plus_1] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + complex_test.b + Int64(1)), sum(complex_test.b + complex_test.a + Int64(1))] +03)----ProjectionExec: expr=[CAST(a@0 + b@1 AS Int64) + 1 as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +# ==== Test 10: Nested constants - Limitation of SimplifyExpressions ==== +# +# This test demonstrates a limitation in constant folding that affects this rule. 
+# +# Expression parsing (left-to-right associativity): +# - "a + 1 + 2" is parsed as "(a + 1) + 2" +# - "a + 5 + 7 * 8" is parsed as "(a + 5) + (7 * 8)" +# +# What SimplifyExpressions does: +# - "7 * 8" → "56" (two literals in one binary expr → folded) +# - "(a + 5) + 56" is NOT simplified to "a + 61" because: +# - The simplifier sees: (Expr) + 56 +# - It doesn't "look inside" the left child to find the 5 +# - Constants spread across nested additions are not gathered +# +# Expression tree for "a + 5 + 7 * 8" after partial simplification: +# +# + +# / \ +# + 56 <-- 7*8 was folded to 56 +# / \ +# a 5 <-- 5 is NOT combined with 56 +# +# Impact on RewriteAggregateWithConstant: +# - The rule sees base_expr = "(a + 5)" with constant = 56 +# - It does NOT see base_expr = "a" with constant = 61 +# - Since SUM(a) has a different base_expr than SUM((a+5) + 56), they are NOT grouped +# - Therefore, this query does NOT trigger the rewrite optimization +# +# The plan below shows all three SUMs computed separately (no Projection + Aggregate rewrite): +query TT +EXPLAIN SELECT + SUM(a), + SUM(a + 1 + 2), + SUM(a + 5 + 7 * 8) +FROM complex_test; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a), sum(__common_expr_1 AS complex_test.a + Int64(1) + Int64(2)), sum(__common_expr_1 + Int64(5) + Int64(56)) AS sum(complex_test.a + Int64(5) + Int64(7) * Int64(8))]] +02)--Projection: CAST(complex_test.a AS Int64) AS __common_expr_1 +03)----TableScan: complex_test projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a), sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5) + Int64(7) * Int64(8))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# Verify correctness: +# - SUM(a) = 1+2+3 = 6 +# - SUM(a + 1 + 2) = SUM(a + 3) = 6 + 3*3 = 15 +# - SUM(a + 5 + 7*8) = SUM(a + 61) = 6 + 61*3 = 189 +query III +SELECT + 
SUM(a), + SUM(a + 1 + 2), + SUM(a + 5 + 7 * 8) +FROM complex_test; +---- +6 15 189 + +# ==== Test 11: Multiple groups with different base expressions ==== +# +# This test demonstrates how the rule handles multiple independent groups +# of SUM expressions, each with a different base expression. +# +# Expression parsing (all parsed left-to-right): +# - sum(a) → base: a (but no constant, so not matched by rule) +# - sum(a+1) → base: a, constant: 1 +# - sum(a+1+1) → parsed as (a+1)+1 → base: (a+1), constant: 1 +# - sum(a+1+2) → parsed as (a+1)+2 → base: (a+1), constant: 2 +# - sum(a+5+1) → parsed as (a+5)+1 → base: (a+5), constant: 1 +# - sum(a+5+2) → parsed as (a+5)+2 → base: (a+5), constant: 2 +# - sum(a+5+3) → parsed as (a+5)+3 → base: (a+5), constant: 3 +# +# Groupings by base expression: +# - Group "a": [sum(a+1)] → only 1 SUM with constant, NOT rewritten +# - Group "a + 1": [sum(a+1+1), sum(a+1+2)] → 2 SUMs, REWRITTEN +# - Group "a + 5": [sum(a+5+1), sum(a+5+2), sum(a+5+3)] → 3 SUMs, REWRITTEN +# +# Note: sum(a) is a plain SUM without a binary expression argument, +# so it doesn't match the SUM(base ± const) pattern at all. 
+# +# Expected rewrite: +# - sum(a), sum(a+1): computed as-is (no optimization) +# - sum(a+1+1), sum(a+1+2): use SUM(a+1) + k*COUNT(a+1) +# - sum(a+5+1), sum(a+5+2), sum(a+5+3): use SUM(a+5) + k*COUNT(a+5) +# +# Verify correctness (using complex_test: a = 1, 2, 3): +# - sum(a) = 1+2+3 = 6 +# - sum(a+1) = 2+3+4 = 9 +# - sum(a+1+1) = sum(a+2) = 3+4+5 = 12 +# - sum(a+1+2) = sum(a+3) = 4+5+6 = 15 +# - sum(a+5+1) = sum(a+6) = 7+8+9 = 24 +# - sum(a+5+2) = sum(a+7) = 8+9+10 = 27 +# - sum(a+5+3) = sum(a+8) = 9+10+11 = 30 +query IIIIIII +SELECT + SUM(a), + SUM(a + 1), + SUM(a + 1 + 1), + SUM(a + 1 + 2), + SUM(a + 5 + 1), + SUM(a + 5 + 2), + SUM(a + 5 + 3) +FROM complex_test; +---- +6 9 12 15 24 27 30 + +statement ok +DROP TABLE complex_test; # ==== Cleanup ==== + statement ok DROP TABLE test_table; From 43a8b0527dc4790456e0a313e4ca0e984e368eee Mon Sep 17 00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 07:22:16 +0700 Subject: [PATCH 5/9] Make ordering deterministic by using IndexMap --- .../src/rewrite_aggregate_with_constant.rs | 18 ++++--- .../aggregate_rewrite_with_constant.slt | 53 ++++++++++++++++++- .../sqllogictest/test_files/explain.slt | 4 ++ 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index 18e3c521935b6..a9f269594bf5a 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -28,7 +28,7 @@ use datafusion_expr::{ col, lit, }; use datafusion_functions_aggregate::expr_fn::{count, sum}; -use std::collections::HashMap; +use indexmap::IndexMap; /// Optimizer rule that rewrites `SUM(column ± constant)` expressions /// into `SUM(column) ± constant * COUNT(column)` when multiple such expressions @@ -110,12 +110,14 @@ struct SumWithConstant { order_by: Vec, } -/// Maps a base expression's schema name to all its SUM(base ± const) variants -type 
RewriteGroups = HashMap>; +/// Maps a base expression's schema name to all its SUM(base ± const) variants. +/// We use IndexMap to preserve insertion order, ensuring deterministic output +/// in the rewritten plan (important for stable EXPLAIN output in tests). +type RewriteGroups = IndexMap>; /// Scans the aggregate expressions to find candidates for the rewrite. fn analyze_aggregate(aggregate: &Aggregate) -> Result { - let mut groups: RewriteGroups = HashMap::new(); + let mut groups: RewriteGroups = IndexMap::new(); for (idx, expr) in aggregate.aggr_expr.iter().enumerate() { // Try to match the pattern SUM(col ± lit) @@ -227,7 +229,7 @@ fn is_numeric_constant(value: &ScalarValue) -> bool { /// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups fn check_plain_sum_in_group( expr: &Expr, - base_expr_indices: &HashMap, + base_expr_indices: &IndexMap, ) -> Option<(usize, usize)> { if let Expr::AggregateFunction(agg_fn) = expr && agg_fn.func.name().to_lowercase() == "sum" @@ -258,11 +260,11 @@ fn transform_aggregate( all_sums.sort_by_key(|s| s.original_index); // Maps base column names to the indices of their new SUM/COUNT in the new Aggregate node - let mut base_expr_indices: HashMap = HashMap::new(); + let mut base_expr_indices: IndexMap = IndexMap::new(); // Process each group to determine what to add to the aggregate - let mut sum_names: HashMap = HashMap::new(); - let mut count_names: HashMap = HashMap::new(); + let mut sum_names: IndexMap = IndexMap::new(); + let mut count_names: IndexMap = IndexMap::new(); // For every group (e.g., all SUMs involving column 'a'), add one SUM(a) and one COUNT(a) for (base_key, sums) in rewrite_groups { diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt index 1aa70f7cab66a..632ab5073a3f4 100644 --- a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt +++ 
b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -337,8 +337,29 @@ FROM test_table; ---- 15 65 150 125 -# Now test with 2+ SUMs with constants on same columns -# Note: The query verifies results are correct - EXPLAIN order may vary due to HashMap iteration +# Test with 2+ SUMs with constants on same columns (triggers rewrite) +# With IndexMap, the order is deterministic based on insertion order: +# - Group "a" first (from SUM(a + 5)) +# - Group "b" second (from SUM(b - 3)) +query TT +EXPLAIN SELECT + SUM(a + 5) AS sum_a_plus_5, + SUM(a + 10) AS sum_a_plus_10, + SUM(b - 3) AS sum_b_minus_3, + SUM(b - 5) AS sum_b_minus_5 +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_a_plus_5, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_a_plus_10, sum(test_table.b) - Int64(3) * count(test_table.b) AS sum_b_minus_3, sum(test_table.b) - Int64(5) * count(test_table.b) AS sum_b_minus_5 +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a), sum(__common_expr_2 AS test_table.b), count(__common_expr_2 AS test_table.b)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1, CAST(test_table.b AS Int64) AS __common_expr_2 +04)------TableScan: test_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_a_plus_5, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_a_plus_10, sum(test_table.b)@2 - 3 * count(test_table.b)@3 as sum_b_minus_3, sum(test_table.b)@2 - 5 * count(test_table.b)@3 as sum_b_minus_5] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a), sum(test_table.b), count(test_table.b)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, CAST(b@1 AS Int64) as __common_expr_2] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + query IIII SELECT SUM(a + 5) AS sum_a_plus_5, @@ 
-492,6 +513,34 @@ FROM complex_test; # - sum(a+1+1), sum(a+1+2): use SUM(a+1) + k*COUNT(a+1) # - sum(a+5+1), sum(a+5+2), sum(a+5+3): use SUM(a+5) + k*COUNT(a+5) # +# With IndexMap, the groups are processed in insertion order: +# - Group "a + Int64(1)" first (from SUM(a+1+1)) +# - Group "a + Int64(5)" second (from SUM(a+5+1)) +# Non-matching expressions (SUM(a), SUM(a+1)) are passed through. +query TT +EXPLAIN SELECT + SUM(a), + SUM(a + 1), + SUM(a + 1 + 1), + SUM(a + 1 + 2), + SUM(a + 5 + 1), + SUM(a + 5 + 2), + SUM(a + 5 + 3) +FROM complex_test; +---- +logical_plan +01)Projection: sum(complex_test.a), sum(complex_test.a + Int64(1)), sum(complex_test.a + Int64(1)) + count(complex_test.a + Int64(1)) AS sum(complex_test.a + Int64(1) + Int64(1)), sum(complex_test.a + Int64(1)) + Int64(2) * count(complex_test.a + Int64(1)) AS sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5)) + count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(1)), sum(complex_test.a + Int64(5)) + Int64(2) * count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(2)), sum(complex_test.a + Int64(5)) + Int64(3) * count(complex_test.a + Int64(5)) AS sum(complex_test.a + Int64(5) + Int64(3)) +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS complex_test.a + Int64(1)), count(__common_expr_1 AS complex_test.a + Int64(1)), sum(__common_expr_2 AS complex_test.a + Int64(5)), count(__common_expr_2 AS complex_test.a + Int64(5)), sum(__common_expr_3 AS complex_test.a)]] +03)----Projection: __common_expr_4 + Int64(1) AS __common_expr_1, __common_expr_4 + Int64(5) AS __common_expr_2, __common_expr_4 AS __common_expr_3 +04)------Projection: CAST(complex_test.a AS Int64) AS __common_expr_4 +05)--------TableScan: complex_test projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(complex_test.a)@4 as sum(complex_test.a), sum(complex_test.a + Int64(1))@0 as sum(complex_test.a + Int64(1)), sum(complex_test.a + Int64(1))@0 + 
count(complex_test.a + Int64(1))@1 as sum(complex_test.a + Int64(1) + Int64(1)), sum(complex_test.a + Int64(1))@0 + 2 * count(complex_test.a + Int64(1))@1 as sum(complex_test.a + Int64(1) + Int64(2)), sum(complex_test.a + Int64(5))@2 + count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(1)), sum(complex_test.a + Int64(5))@2 + 2 * count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(2)), sum(complex_test.a + Int64(5))@2 + 3 * count(complex_test.a + Int64(5))@3 as sum(complex_test.a + Int64(5) + Int64(3))] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(complex_test.a + Int64(1)), count(complex_test.a + Int64(1)), sum(complex_test.a + Int64(5)), count(complex_test.a + Int64(5)), sum(complex_test.a)] +03)----ProjectionExec: expr=[__common_expr_4@0 + 1 as __common_expr_1, __common_expr_4@0 + 5 as __common_expr_2, __common_expr_4@0 as __common_expr_3] +04)------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_4] +05)--------DataSourceExec: partitions=1, partition_sizes=[1] + # Verify correctness (using complex_test: a = 1, 2, 3): # - sum(a) = 1+2+3 = 6 # - sum(a+1) = 2+3+4 = 9 diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 6f615ec391c9e..617259d2e917d 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -196,6 +196,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE @@ -218,6 +219,7 @@ 
logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] @@ -557,6 +559,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE @@ -579,6 +582,7 @@ logical_plan after push_down_limit SAME TEXT AS ABOVE logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE +logical_plan after rewrite_aggregate_with_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] From b9c1ca7895c8e382c5bf5da8b3616da0d344a15d Mon Sep 17 00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 09:26:11 +0700 Subject: [PATCH 6/9] Add SLT Test for Filtering --- .../aggregate_rewrite_with_constant.slt | 63 ++++++++++++++++++- 1 file changed, 61 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt 
b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt index 632ab5073a3f4..2ad729cf1d54e 100644 --- a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -562,8 +562,64 @@ FROM complex_test; ---- 6 9 12 15 24 27 30 -statement ok -DROP TABLE complex_test; + +# ==== Test 12: Expressions with DISTINCT or FILTER (SHOULD NOT rewrite) ==== + +# SUM(DISTINCT a + 1) should NOT be rewritten because the math SUM(a) + COUNT(a) +# doesn't hold when counting distinct values of the modified expression. +query TT +EXPLAIN SELECT + SUM(DISTINCT a + 1), + SUM(DISTINCT a + 2) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(DISTINCT __common_expr_1 AS test_table.a + Int64(1)), sum(DISTINCT __common_expr_1 AS test_table.a + Int64(2))]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(DISTINCT test_table.a + Int64(1)), sum(DISTINCT test_table.a + Int64(2))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +# SUM(a + 1) FILTER (WHERE a > 1) should NOT be rewritten because the filter +# applies to the entire expression, and the rule doesn't handle filters. 
+query TT +EXPLAIN SELECT + SUM(a + 1) FILTER (WHERE a > 1), + SUM(a + 2) FILTER (WHERE a > 1) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 + Int64(1)) FILTER (WHERE __common_expr_2) AS sum(test_table.a + Int64(1)) FILTER (WHERE test_table.a > Int64(1)), sum(__common_expr_1 + Int64(2)) FILTER (WHERE __common_expr_2) AS sum(test_table.a + Int64(2)) FILTER (WHERE test_table.a > Int64(1))]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1, test_table.a > Int32(1) AS __common_expr_2 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a + Int64(1)) FILTER (WHERE test_table.a > Int64(1)), sum(test_table.a + Int64(2)) FILTER (WHERE test_table.a > Int64(1))] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 > 1 as __common_expr_2] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT + SUM(a + 1), + SUM(a + 2) +FROM test_table WHERE a > 3; +---- +logical_plan +01)Projection: sum(test_table.a) + count(test_table.a) AS sum(test_table.a + Int64(1)), sum(test_table.a) + Int64(2) * count(test_table.a) AS sum(test_table.a + Int64(2)) +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------Filter: test_table.a > Int32(3) +05)--------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + count(test_table.a)@1 as sum(test_table.a + Int64(1)), sum(test_table.a)@0 + 2 * count(test_table.a)@1 as sum(test_table.a + Int64(2))] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +05)--------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] 
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------FilterExec: a@0 > 3 +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] # ==== Cleanup ==== @@ -576,3 +632,6 @@ DROP TABLE group_test; statement ok DROP TABLE nullable_test; + +statement ok +DROP TABLE complex_test; From d544889b537ae8f6898524df66e84899a145a01d Mon Sep 17 00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 10:11:45 +0700 Subject: [PATCH 7/9] Fix aliasing issues --- .../src/rewrite_aggregate_with_constant.rs | 130 +++++++++++++++++- .../aggregate_rewrite_with_constant.slt | 55 ++++++++ 2 files changed, 179 insertions(+), 6 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index a9f269594bf5a..f7686002d3bff 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -135,9 +135,20 @@ fn analyze_aggregate(aggregate: &Aggregate) -> Result { Ok(groups) } -/// Extract SUM(base_expr ± constant) pattern from an expression +/// Extract SUM(base_expr ± constant) pattern from an expression. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so the rule works regardless of whether aggregate expressions carry aliases +/// (e.g., when plans are built via the LogicalPlanBuilder API). fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result> { - match expr { + // Unwrap Expr::Alias if present — the SQL planner puts aliases in a + // Projection above the Aggregate, but the builder API allows aliases + // directly inside aggr_expr. 
+ let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + match inner { Expr::AggregateFunction(agg_fn) => { // Rule only applies to SUM if agg_fn.func.name().to_lowercase() != "sum" { @@ -208,6 +219,7 @@ fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result bool { matches!( value, @@ -226,12 +238,21 @@ fn is_numeric_constant(value: &ScalarValue) -> bool { ) } -/// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups +/// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups. +/// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` +/// so that aliased plain SUMs (e.g., `SUM(a) AS total`) are correctly detected and reused +/// instead of being duplicated in the new Aggregate node. fn check_plain_sum_in_group( expr: &Expr, base_expr_indices: &IndexMap, ) -> Option<(usize, usize)> { - if let Expr::AggregateFunction(agg_fn) = expr + // Unwrap alias if present + let inner = match expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + + if let Expr::AggregateFunction(agg_fn) = inner && agg_fn.func.name().to_lowercase() == "sum" && agg_fn.params.args.len() == 1 && !agg_fn.params.distinct @@ -327,8 +348,13 @@ fn transform_aggregate( check_plain_sum_in_group(orig_expr, &base_expr_indices); if is_plain_sum_in_group.is_some() { - // Use the already-computed SUM - let base_key = if let Expr::AggregateFunction(agg_fn) = orig_expr { + // Use the already-computed SUM. + // Unwrap alias to get to the AggregateFunction and extract the base key. 
+ let inner_expr = match orig_expr { + Expr::Alias(alias) => alias.expr.as_ref(), + other => other, + }; + let base_key = if let Expr::AggregateFunction(agg_fn) = inner_expr { agg_fn.params.args[0].schema_name().to_string() } else { String::new() @@ -447,4 +473,96 @@ mod tests { assert!(!result.transformed); Ok(()) } + + /// Test that aliased SUM(col ± constant) expressions are correctly detected + /// and rewritten. This exercises the Expr::Alias unwrapping in + /// `extract_sum_with_constant`. + /// + /// Note: The SQL planner puts aliases in a Projection above the Aggregate, + /// so this case only arises when building plans via the LogicalPlanBuilder API. + #[test] + fn test_aliased_sum_with_constant() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + sum(col("a") + lit(1)).alias("sum_a_plus_1"), + sum(col("a") + lit(2)).alias("sum_a_plus_2"), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed: both aliased SUMs share base column "a" + assert!(result.transformed); + + // The rewritten plan should be a Projection on top of an Aggregate + let plan_str = format!("{}", result.data.display_indent()); + assert!( + plan_str.contains("Projection:"), + "Expected Projection node in rewritten plan, got:\n{plan_str}" + ); + // The new aggregate should have SUM(a) and COUNT(a), not two separate SUMs + assert!( + plan_str.contains("sum(test.a)") && plan_str.contains("count(test.a)"), + "Expected SUM(a) and COUNT(a) in rewritten plan, got:\n{plan_str}" + ); + Ok(()) + } + + /// Test that an aliased plain SUM(a) alongside SUM(a+1) and SUM(a+2) is + /// correctly detected as a "plain SUM in group" and reused, rather than + /// being duplicated in the new Aggregate node. 
+ /// + /// This exercises the Expr::Alias unwrapping in `check_plain_sum_in_group`. + #[test] + fn test_aliased_plain_sum_in_group() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + // This plain SUM(a) is aliased — it should be detected + // as matching the rewrite group for "a" and reused. + sum(col("a")).alias("total_a"), + sum(col("a") + lit(1)), + sum(col("a") + lit(2)), + ], + )? + .build()?; + + let rule = RewriteAggregateWithConstant::new(); + let config = OptimizerContext::new(); + let result = rule.rewrite(plan, &config)?; + + // Should be transformed + assert!(result.transformed); + + let plan_str = format!("{}", result.data.display_indent()); + + // The rewritten plan should contain "total_a" alias in the Projection + assert!( + plan_str.contains("total_a"), + "Expected alias 'total_a' in rewritten plan, got:\n{plan_str}" + ); + + // The Aggregate node should contain exactly one SUM(a) and one COUNT(a), + // NOT a duplicate SUM(a). Count occurrences of "sum(test.a)" in the + // Aggregate line (not the Projection line). 
+ let aggr_line = plan_str + .lines() + .find(|l| l.contains("Aggregate:")) + .expect("should have Aggregate node"); + let sum_count = aggr_line.matches("sum(test.a)").count(); + assert_eq!( + sum_count, 1, + "Expected exactly 1 SUM(test.a) in Aggregate (no duplicate), got {sum_count} in:\n{aggr_line}" + ); + + Ok(()) + } } diff --git a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt index 2ad729cf1d54e..a9390ce537d12 100644 --- a/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt +++ b/datafusion/sqllogictest/test_files/aggregate_rewrite_with_constant.slt @@ -621,6 +621,61 @@ physical_plan 07)------------FilterExec: a@0 > 3 08)--------------DataSourceExec: partitions=1, partition_sizes=[1] +# ==== Test 13: Constant on the left side (Commutative property) ==== + +# Test: SUM(constant + a) should be rewritten to SUM(a) + constant * COUNT(a) +query TT +EXPLAIN SELECT + SUM(5 + a) as sum_5_plus_a, + SUM(10 + a) as sum_10_plus_a +FROM test_table; +---- +logical_plan +01)Projection: sum(test_table.a) + Int64(5) * count(test_table.a) AS sum_5_plus_a, sum(test_table.a) + Int64(10) * count(test_table.a) AS sum_10_plus_a +02)--Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test_table.a), count(__common_expr_1 AS test_table.a)]] +03)----Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +04)------TableScan: test_table projection=[a] +physical_plan +01)ProjectionExec: expr=[sum(test_table.a)@0 + 5 * count(test_table.a)@1 as sum_5_plus_a, sum(test_table.a)@0 + 10 * count(test_table.a)@1 as sum_10_plus_a] +02)--AggregateExec: mode=Single, gby=[], aggr=[sum(test_table.a), count(test_table.a)] +03)----ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +04)------DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT + SUM(5 + a) as sum_5_plus_a, + SUM(10 + a) as sum_10_plus_a +FROM test_table; +---- +40 65 + +# ==== Test 14: 
Constant on the left side with subtraction (SHOULD NOT rewrite) ==== + +# Test: SUM(constant - a) should NOT be rewritten because 5 - a is NOT a - 5. +# The rule intentionally only handles constant + col and col - constant. +query TT +EXPLAIN SELECT + SUM(5 - a), + SUM(10 - a) +FROM test_table; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[sum(Int64(5) - __common_expr_1 AS test_table.a), sum(Int64(10) - __common_expr_1 AS test_table.a)]] +02)--Projection: CAST(test_table.a AS Int64) AS __common_expr_1 +03)----TableScan: test_table projection=[a] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[sum(Int64(5) - test_table.a), sum(Int64(10) - test_table.a)] +02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1] +03)----DataSourceExec: partitions=1, partition_sizes=[1] + +query II +SELECT + SUM(5 - a), + SUM(10 - a) +FROM test_table; +---- +10 35 + # ==== Cleanup ==== From ae52829459bbfddcbf4ebf78791b9d78b6b42721 Mon Sep 17 00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 10:19:37 +0700 Subject: [PATCH 8/9] Remove redundant data structures --- .../src/rewrite_aggregate_with_constant.rs | 192 +++++++++--------- 1 file changed, 95 insertions(+), 97 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index f7686002d3bff..52ad0333da5c1 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -20,9 +20,11 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use std::collections::HashMap; + use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionParams, Sort}; +use datafusion_expr::expr::AggregateFunctionParams; use datafusion_expr::{ Aggregate, BinaryExpr, Expr, LogicalPlan, LogicalPlanBuilder, Operator, binary_expr, col, lit, @@ -95,19 +97,21 @@ impl 
OptimizerRule for RewriteAggregateWithConstant { } } -/// Internal structure to track metadata for a SUM expression that qualifies for rewrite +/// Internal structure to track metadata for a SUM expression that qualifies for rewrite. #[derive(Debug, Clone)] struct SumWithConstant { - /// The inner expression being manipulated (e.g., the 'a' in SUM(a + 1)) + /// The inner expression being summed (e.g., the `a` in `SUM(a + 1)`) base_expr: Expr, - /// The constant value being added/subtracted + /// The constant value being added/subtracted (e.g., `1` in `SUM(a + 1)`) constant: ScalarValue, - /// The operator (+ or -) + /// The operator (`+` or `-`) operator: Operator, - /// The index in the original Aggregate's aggr_expr list, used to maintain output order + /// The index in the original Aggregate's `aggr_expr` list, used to maintain output order original_index: usize, - /// Any ORDER BY clause inside the aggregate (e.g., SUM(a+1 ORDER BY b)) - order_by: Vec, + // Note: ORDER BY inside SUM is irrelevant because SUM is commutative — + // the order of addition doesn't change the result. If this rule is ever + // extended to non-commutative aggregates, ORDER BY handling would need + // to be added back. } /// Maps a base expression's schema name to all its SUM(base ± const) variants. @@ -159,7 +163,7 @@ fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result Result Result Result bool { ) } -/// Check if an expression is a plain SUM(base_expr) that matches one of our rewrite groups. +/// Check if an expression is a plain `SUM(base_expr)` whose base expression matches +/// one of our rewrite groups. Returns `Some(base_key)` if matched, `None` otherwise. +/// /// Handles both `Expr::AggregateFunction(...)` and `Expr::Alias(Expr::AggregateFunction(...))` /// so that aliased plain SUMs (e.g., `SUM(a) AS total`) are correctly detected and reused /// instead of being duplicated in the new Aggregate node. 
+/// +/// Returning the `base_key` directly eliminates the need for the caller to re-extract it +/// from the expression (avoiding a potential panic on impossible-to-reach fallback paths). fn check_plain_sum_in_group( expr: &Expr, - base_expr_indices: &IndexMap, -) -> Option<(usize, usize)> { + known_base_keys: &IndexMap, +) -> Option { // Unwrap alias if present let inner = match expr { Expr::Alias(alias) => alias.expr.as_ref(), @@ -260,127 +269,117 @@ fn check_plain_sum_in_group( { let arg = &agg_fn.params.args[0]; let base_key = arg.schema_name().to_string(); - return base_expr_indices.get(&base_key).copied(); + if known_base_keys.contains_key(&base_key) { + return Some(base_key); + } } None } -/// Transform the aggregate plan by rewriting SUM(col ± constant) expressions +/// Alias `expr` to match the output column name of `orig_expr`. +/// +/// If `orig_expr` is `Expr::Alias(name)`, the result is aliased to that name. +/// Otherwise, the result is aliased to `orig_expr.schema_name()` to preserve +/// the auto-generated column name (e.g., `"sum(t.a + Int64(1))"`). +fn alias_like(expr: Expr, orig_expr: &Expr) -> Expr { + match orig_expr { + Expr::Alias(alias) => expr.alias(alias.name.clone()), + _ => expr.alias(orig_expr.schema_name().to_string()), + } +} + +/// Transform the aggregate plan by rewriting SUM(col ± constant) expressions. +/// +/// Replaces the original `Aggregate` node with: +/// 1. A new `Aggregate` containing one `SUM(base)` + one `COUNT(base)` per group, +/// plus any unrelated expressions passed through. +/// 2. A `Projection` on top that derives each original output column using +/// the formula: `SUM(base) ± constant * COUNT(base)`. 
fn transform_aggregate( aggregate: Aggregate, rewrite_groups: &RewriteGroups, ) -> Result> { - let mut new_aggr_exprs = Vec::new(); - let mut projection_exprs = Vec::new(); + let mut new_aggr_exprs: Vec = Vec::new(); + let mut projection_exprs: Vec = Vec::new(); - // Build a flat list of all SUMs to rewrite, sorted by original index - let mut all_sums: Vec = rewrite_groups + // (Fix 1) Build a HashMap for O(1) lookup of rewritable expressions by their + // original index, replacing the previous O(n*m) linear scan with `.find()`. + let rewrite_lookup: HashMap = rewrite_groups .values() - .flat_map(|v| v.iter().cloned()) + .flatten() + .map(|s| (s.original_index, s)) .collect(); - all_sums.sort_by_key(|s| s.original_index); - // Maps base column names to the indices of their new SUM/COUNT in the new Aggregate node + // (Fix 2+6) Single map from base expression key to (sum_index, count_index) + // in `new_aggr_exprs`. We derive SUM/COUNT names directly from + // `new_aggr_exprs[index].schema_name()` instead of maintaining parallel + // `sum_names` / `count_names` maps that could go out of sync. let mut base_expr_indices: IndexMap = IndexMap::new(); - // Process each group to determine what to add to the aggregate - let mut sum_names: IndexMap = IndexMap::new(); - let mut count_names: IndexMap = IndexMap::new(); - - // For every group (e.g., all SUMs involving column 'a'), add one SUM(a) and one COUNT(a) + // For every rewrite group (e.g., all SUMs involving column 'a'), + // add one SUM(base) and one COUNT(base) to the new Aggregate. for (base_key, sums) in rewrite_groups { - // If any original SUM had an ORDER BY, we try to preserve it in our new base SUM. 
- let representative = sums - .iter() - .find(|s| !s.order_by.is_empty()) - .unwrap_or(&sums[0]); - - // Add SUM(base_expr) with ORDER BY preserved - let sum_expr = sum(representative.base_expr.clone()); - // Note: ORDER BY is not needed for SUM as it's commutative - let sum_name = sum_expr.schema_name().to_string(); + let base_expr = &sums[0].base_expr; + + // Add SUM(base_expr) let sum_index = new_aggr_exprs.len(); - new_aggr_exprs.push(sum_expr); - sum_names.insert(base_key.clone(), sum_name); + new_aggr_exprs.push(sum(base_expr.clone())); - // Add the base COUNT(a) - // We use COUNT(col) rather than COUNT(*) because if 'col' is NULL, + // Add COUNT(base_expr) + // We use COUNT(col) rather than COUNT(*) because if 'col' is NULL, // SUM(col + 1) should be NULL, and COUNT(col) correctly returns 0 for NULLs, // whereas COUNT(*) would count the row. - let count_expr = count(representative.base_expr.clone()); - let count_name = count_expr.schema_name().to_string(); let count_index = new_aggr_exprs.len(); - new_aggr_exprs.push(count_expr); - count_names.insert(base_key.clone(), count_name); + new_aggr_exprs.push(count(base_expr.clone())); base_expr_indices.insert(base_key.clone(), (sum_index, count_index)); } - // Now iterate through the ORIGINAL aggregate expressions to build the PROJECTION + // Iterate through the ORIGINAL aggregate expressions to build the PROJECTION. + // Each original expression falls into one of three cases. 
for (idx, orig_expr) in aggregate.aggr_expr.iter().enumerate() { - // Check if this expression should be rewritten - let rewritten = all_sums.iter().find(|s| s.original_index == idx); - - let projection_expr = if let Some(sum_info) = rewritten { + let projection_expr = if let Some(sum_info) = rewrite_lookup.get(&idx) { + // ── Case 1: Rewritable SUM(col ± constant) ── + // Derive: SUM(col) ± (constant * COUNT(col)) let base_key = sum_info.base_expr.schema_name().to_string(); + let (sum_idx, count_idx) = base_expr_indices[&base_key]; - // Construct the math: SUM(col) [±] (constant * COUNT(col)) - let sum_ref = col(&sum_names[&base_key]); - let count_ref = col(&count_names[&base_key]); + let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); + let count_ref = col(new_aggr_exprs[count_idx].schema_name().to_string()); let multiplied = binary_expr( lit(sum_info.constant.clone()), Operator::Multiply, count_ref, ); - let result = binary_expr(sum_ref, sum_info.operator, multiplied); + alias_like(result, orig_expr) + } else if let Some(base_key) = + check_plain_sum_in_group(orig_expr, &base_expr_indices) + { + // ── Case 2: Plain SUM(a) that matches a rewrite group ── + // Reuse the SUM(a) we already added instead of creating a duplicate. + // `base_key` is returned directly by check_plain_sum_in_group, + // so there is no need to re-extract it (and no risk of a panic). + let (sum_idx, _) = base_expr_indices[&base_key]; + let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); + alias_like(sum_ref, orig_expr) + } else { + // ── Case 3: Unrelated expression (e.g., AVG(b), MAX(c)) ── + // Pass it through to the new Aggregate node unchanged. + new_aggr_exprs.push(orig_expr.clone()); - // Ensure the output column name matches the original (aliased or generated) + // Reference it in the projection by name. 
match orig_expr { - Expr::Alias(alias) => result.alias(alias.name.clone()), - _ => result.alias(orig_expr.schema_name().to_string()), - } - } else { - // Special case: If the user had a plain SUM(a) alongside SUM(a+1), - // we should reuse the SUM(a) we just added instead of adding another one. - let is_plain_sum_in_group = - check_plain_sum_in_group(orig_expr, &base_expr_indices); - - if is_plain_sum_in_group.is_some() { - // Use the already-computed SUM. - // Unwrap alias to get to the AggregateFunction and extract the base key. - let inner_expr = match orig_expr { - Expr::Alias(alias) => alias.expr.as_ref(), - other => other, - }; - let base_key = if let Expr::AggregateFunction(agg_fn) = inner_expr { - agg_fn.params.args[0].schema_name().to_string() - } else { - String::new() - }; - let sum_ref = col(&sum_names[&base_key]); - match orig_expr { - Expr::Alias(alias) => sum_ref.alias(alias.name.clone()), - _ => sum_ref.alias(orig_expr.schema_name().to_string()), - } - } else { - // This expression is unrelated to our rewrites (e.g., AVG(b) or MAX(c)). - // We just pass it through to the new Aggregate node. - new_aggr_exprs.push(orig_expr.clone()); - - // And reference it in the projection by name. - match orig_expr { - Expr::Alias(alias) => col(alias.name.clone()), - _ => col(orig_expr.schema_name().to_string()), - } + Expr::Alias(alias) => col(alias.name.clone()), + _ => col(orig_expr.schema_name().to_string()), } }; projection_exprs.push(projection_expr); } - // Handle GROUP BY columns: they must be passed through the Aggregate and Projection + // Handle GROUP BY columns: they must be passed through the Aggregate and Projection. 
let group_exprs: Vec = aggregate .group_expr .iter() @@ -391,18 +390,17 @@ fn transform_aggregate( }) .collect(); - // Final projection includes [Group Columns] + [Aggregated/Rewritten Columns] + // Final projection: [GROUP BY columns] ++ [aggregate/rewritten columns] let mut final_projection = group_exprs; final_projection.extend(projection_exprs); - // Create the new Aggregate plan node + // Build the two-node plan: Projection → Aggregate → (original input) let new_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( aggregate.input, aggregate.group_expr, new_aggr_exprs, )?); - // Wrap the Aggregate with the Projection let projection = LogicalPlanBuilder::from(new_aggregate) .project(final_projection)? .build()?; From 53307e1fe3fcf389026741eef5ae076e1df3d33a Mon Sep 17 00:00:00 2001 From: Devanshu Date: Fri, 6 Feb 2026 10:45:56 +0700 Subject: [PATCH 9/9] Refactor, clippy and formatting --- .../src/rewrite_aggregate_with_constant.rs | 119 ++++++++++++------ 1 file changed, 80 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs index 52ad0333da5c1..d550d85779375 100644 --- a/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs +++ b/datafusion/optimizer/src/rewrite_aggregate_with_constant.rs @@ -132,7 +132,7 @@ fn analyze_aggregate(aggregate: &Aggregate) -> Result { } // Optimization: Only rewrite if we have at least 2 expressions for the same column. - // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) + // If there's only one SUM(a + 1), rewriting it to SUM(a) + 1*COUNT(a) // actually increases the work (1 agg -> 2 aggs). groups.retain(|_, v| v.len() >= 2); @@ -172,7 +172,7 @@ fn extract_sum_with_constant(expr: &Expr, idx: usize) -> Result Expr { /// plus any unrelated expressions passed through. /// 2. 
A `Projection` on top that derives each original output column using /// the formula: `SUM(base) ± constant * COUNT(base)`. +/// +/// Delegates to [`build_rewrite_expressions`] for the expression logic and +/// [`assemble_rewritten_plan`] for the plan construction. fn transform_aggregate( aggregate: Aggregate, rewrite_groups: &RewriteGroups, ) -> Result> { + let (new_aggr_exprs, projection_exprs) = + build_rewrite_expressions(&aggregate, rewrite_groups); + assemble_rewritten_plan(aggregate, new_aggr_exprs, projection_exprs) +} + +/// Build the new aggregate expressions and corresponding projection expressions. +/// +/// For each rewrite group (e.g., all SUMs involving column `a`), adds one +/// `SUM(a)` and one `COUNT(a)` to the new aggregate. Then maps every original +/// aggregate expression to a projection expression using one of three cases: +/// +/// - **Case 1** (rewritable): `SUM(a + k)` → `SUM(a) + k * COUNT(a)` +/// - **Case 2** (plain SUM in group): `SUM(a)` → reference the already-added `SUM(a)` +/// - **Case 3** (unrelated): pass through unchanged (e.g., `AVG(b)`) +/// +/// Returns `(new_aggr_exprs, projection_exprs)`. +fn build_rewrite_expressions( + aggregate: &Aggregate, + rewrite_groups: &RewriteGroups, +) -> (Vec, Vec) { let mut new_aggr_exprs: Vec = Vec::new(); let mut projection_exprs: Vec = Vec::new(); - // (Fix 1) Build a HashMap for O(1) lookup of rewritable expressions by their - // original index, replacing the previous O(n*m) linear scan with `.find()`. + // O(1) lookup of rewritable expressions by their original index. let rewrite_lookup: HashMap = rewrite_groups .values() .flatten() .map(|s| (s.original_index, s)) .collect(); - // (Fix 2+6) Single map from base expression key to (sum_index, count_index) - // in `new_aggr_exprs`. We derive SUM/COUNT names directly from - // `new_aggr_exprs[index].schema_name()` instead of maintaining parallel - // `sum_names` / `count_names` maps that could go out of sync. 
+ // Maps base expression key → (sum_index, count_index) in `new_aggr_exprs`. + // Column names are derived on-demand from `new_aggr_exprs[index].schema_name()` + // rather than maintaining separate name maps. let mut base_expr_indices: IndexMap = IndexMap::new(); - // For every rewrite group (e.g., all SUMs involving column 'a'), - // add one SUM(base) and one COUNT(base) to the new Aggregate. + // For every rewrite group, add one SUM(base) and one COUNT(base) + // to the new Aggregate. for (base_key, sums) in rewrite_groups { let base_expr = &sums[0].base_expr; - // Add SUM(base_expr) let sum_index = new_aggr_exprs.len(); new_aggr_exprs.push(sum(base_expr.clone())); - // Add COUNT(base_expr) - // We use COUNT(col) rather than COUNT(*) because if 'col' is NULL, - // SUM(col + 1) should be NULL, and COUNT(col) correctly returns 0 for NULLs, - // whereas COUNT(*) would count the row. + // COUNT(col) rather than COUNT(*): if col is NULL, SUM(col+1) should + // be NULL, and COUNT(col) correctly returns 0 for NULLs. let count_index = new_aggr_exprs.len(); new_aggr_exprs.push(count(base_expr.clone())); base_expr_indices.insert(base_key.clone(), (sum_index, count_index)); } - // Iterate through the ORIGINAL aggregate expressions to build the PROJECTION. - // Each original expression falls into one of three cases. + // Map each original aggregate expression to a projection expression. 
for (idx, orig_expr) in aggregate.aggr_expr.iter().enumerate() { let projection_expr = if let Some(sum_info) = rewrite_lookup.get(&idx) { // ── Case 1: Rewritable SUM(col ± constant) ── // Derive: SUM(col) ± (constant * COUNT(col)) - let base_key = sum_info.base_expr.schema_name().to_string(); - let (sum_idx, count_idx) = base_expr_indices[&base_key]; - - let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); - let count_ref = col(new_aggr_exprs[count_idx].schema_name().to_string()); - - let multiplied = binary_expr( - lit(sum_info.constant.clone()), - Operator::Multiply, - count_ref, - ); - let result = binary_expr(sum_ref, sum_info.operator, multiplied); - alias_like(result, orig_expr) + build_derived_projection( + sum_info, + &base_expr_indices, + &new_aggr_exprs, + orig_expr, + ) } else if let Some(base_key) = check_plain_sum_in_group(orig_expr, &base_expr_indices) { - // ── Case 2: Plain SUM(a) that matches a rewrite group ── + // ── Case 2: Plain SUM(a) matching a rewrite group ── // Reuse the SUM(a) we already added instead of creating a duplicate. - // `base_key` is returned directly by check_plain_sum_in_group, - // so there is no need to re-extract it (and no risk of a panic). let (sum_idx, _) = base_expr_indices[&base_key]; let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); alias_like(sum_ref, orig_expr) @@ -368,8 +376,6 @@ fn transform_aggregate( // ── Case 3: Unrelated expression (e.g., AVG(b), MAX(c)) ── // Pass it through to the new Aggregate node unchanged. new_aggr_exprs.push(orig_expr.clone()); - - // Reference it in the projection by name. match orig_expr { Expr::Alias(alias) => col(alias.name.clone()), _ => col(orig_expr.schema_name().to_string()), @@ -379,8 +385,43 @@ fn transform_aggregate( projection_exprs.push(projection_expr); } - // Handle GROUP BY columns: they must be passed through the Aggregate and Projection. 
- let group_exprs: Vec = aggregate + (new_aggr_exprs, projection_exprs) +} + +/// Build the projection expression for a single rewritable SUM(col ± constant): +/// +/// `SUM(col) ± (constant * COUNT(col))`, aliased to match the original expression. +fn build_derived_projection( + sum_info: &SumWithConstant, + base_expr_indices: &IndexMap, + new_aggr_exprs: &[Expr], + orig_expr: &Expr, +) -> Expr { + let base_key = sum_info.base_expr.schema_name().to_string(); + let (sum_idx, count_idx) = base_expr_indices[&base_key]; + + let sum_ref = col(new_aggr_exprs[sum_idx].schema_name().to_string()); + let count_ref = col(new_aggr_exprs[count_idx].schema_name().to_string()); + + let multiplied = binary_expr( + lit(sum_info.constant.clone()), + Operator::Multiply, + count_ref, + ); + let result = binary_expr(sum_ref, sum_info.operator, multiplied); + alias_like(result, orig_expr) +} + +/// Assemble the final rewritten plan: `Projection → Aggregate → (original input)`. +/// +/// GROUP BY columns are passed through both nodes so the output schema is preserved. +fn assemble_rewritten_plan( + aggregate: Aggregate, + new_aggr_exprs: Vec, + projection_exprs: Vec, +) -> Result> { + // GROUP BY columns: reference them in the Projection. + let group_refs: Vec = aggregate .group_expr .iter() .map(|e| match e { @@ -391,10 +432,10 @@ fn transform_aggregate( .collect(); // Final projection: [GROUP BY columns] ++ [aggregate/rewritten columns] - let mut final_projection = group_exprs; + let mut final_projection = group_refs; final_projection.extend(projection_exprs); - // Build the two-node plan: Projection → Aggregate → (original input) + // Build: Projection → Aggregate → (original input) let new_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( aggregate.input, aggregate.group_expr,