From 1a25cfe9aa6d7b02ea9447a0b5f46f0504e3a54b Mon Sep 17 00:00:00 2001
From: duanhao-jk
Date: Fri, 23 Jan 2026 14:37:04 +0800
Subject: [PATCH 1/5] add case when function

---
 .../datafusion-ext-functions/src/lib.rs       |   2 +
 .../src/spark_case_when.rs                    | 318 ++++++++++++++++++
 .../spark/sql/auron/NativeConverters.scala    |  33 +-
 3 files changed, 337 insertions(+), 16 deletions(-)
 create mode 100644 native-engine/datafusion-ext-functions/src/spark_case_when.rs

diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index a65dc0d44..071329cea 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -19,6 +19,7 @@ use datafusion::{common::Result, logical_expr::ScalarFunctionImplementation};
 use datafusion_ext_commons::df_unimplemented_err;
 
 mod brickhouse;
+mod spark_case_when;
 mod spark_bround;
 mod spark_check_overflow;
 mod spark_crypto;
@@ -85,6 +86,7 @@ pub fn create_auron_ext_function(
             Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
         }
         "Spark_IsNaN" => Arc::new(spark_isnan::spark_isnan),
+        "Spark_CaseWhen" => Arc::new(spark_case_when::spark_case_when),
         _ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
     })
 }
diff --git a/native-engine/datafusion-ext-functions/src/spark_case_when.rs b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
new file mode 100644
index 000000000..22fe0954a
--- /dev/null
+++ b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
@@ -0,0 +1,318 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::*,
+    compute::kernels::zip::zip,
+    datatypes::DataType,
+};
+use datafusion::{
+    common::Result,
+    physical_plan::ColumnarValue,
+};
+use datafusion_ext_commons::df_execution_err;
+
+/// CASE WHEN function implementation
+/// 
+/// Syntax: case_when(condition1, value1, condition2, value2, ..., else_value)
+/// 
+/// Arguments:
+/// - Arguments come in (condition, value) pairs, with an optional else_value at the end
+/// - An odd argument count includes a trailing else_value; an even count has none
+/// - If no else_value is provided and no condition matches, returns NULL
+/// 
+/// Example:
+/// - case_when(x > 10, 'big', x > 5, 'medium', 'small')
+/// - case_when(x IS NULL, 0, x)
+pub fn spark_case_when(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.is_empty() {
+        return df_execution_err!("case_when requires at least 1 argument (else value)");
+    }
+
+    // Special case: only one argument means it's the else value
+    if args.len() == 1 {
+        return Ok(args[0].clone());
+    }
+
+    // Determine if we have an else value (odd number of args means we do)
+    let has_else = args.len() % 2 == 1;
+    let num_conditions = if has_else {
+        (args.len() - 1) / 2
+    } else {
+        args.len() / 2
+    };
+
+    // Get the batch size
+    let batch_size = match &args[0] {
+        ColumnarValue::Array(array) => array.len(),
+        ColumnarValue::Scalar(_) => {
+            // If all inputs are scalars, find the first array to determine size
+            let mut size = 1;
+            for arg in args {
+                if let ColumnarValue::Array(array) = arg {
+                    size = array.len();
+                    break;
+                }
+            }
+            size
+        }
+    };
+
+    // Convert all inputs to arrays
+    let mut conditions = Vec::with_capacity(num_conditions);
+    let mut values = Vec::with_capacity(num_conditions);
+
+    for i in 0..num_conditions {
+        let condition_idx = i * 2;
+        let value_idx = i * 2 + 1;
+
+        let condition_array = args[condition_idx].clone().into_array(batch_size)?;
+        let value_array = args[value_idx].clone().into_array(batch_size)?;
+
+        // Verify condition is boolean
+        if condition_array.data_type() != &DataType::Boolean {
+            return df_execution_err!(
+                "case_when condition at position {} must be boolean, got {:?}",
+                condition_idx,
+                condition_array.data_type()
+            );
+        }
+
+        conditions.push(as_boolean_array(&condition_array).clone());
+        values.push(value_array);
+    }
+
+    // Get else value if present
+    let else_array = if has_else {
+        Some(args[args.len() - 1].clone().into_array(batch_size)?)
+    } else {
+        None
+    };
+
+    // Determine output data type (from first value)
+    let output_type = values[0].data_type().clone();
+
+    // Build the result array
+    let result = evaluate_case_when(&conditions, &values, else_array.as_ref(), batch_size, &output_type)?;
+
+    // If all inputs were scalars, return a scalar
+    if batch_size == 1 && args.iter().all(|arg| matches!(arg, ColumnarValue::Scalar(_))) {
+        let scalar = datafusion::common::ScalarValue::try_from_array(&result, 0)?;
+        Ok(ColumnarValue::Scalar(scalar))
+    } else {
+        Ok(ColumnarValue::Array(result))
+    }
+}
+
+/// Evaluate the case when logic
+fn evaluate_case_when(
+    conditions: &[BooleanArray],
+    values: &[ArrayRef],
+    else_value: Option<&ArrayRef>,
+    batch_size: usize,
+    output_type: &DataType,
+) -> Result<ArrayRef> {
+    use arrow::array::new_null_array;
+
+    // Initialize result with nulls or else value
+    let mut result: ArrayRef = if let Some(else_array) = else_value {
+        else_array.clone()
+    } else {
+        new_null_array(output_type, batch_size)
+    };
+
+    // Process conditions in reverse order so earlier conditions take precedence
+    for i in (0..conditions.len()).rev() {
+        let condition = &conditions[i];
+        let value = &values[i];
+
+        // Use arrow's zip kernel to select between current result and value based on condition
+        result = zip(condition, value, &result)?;
+    }
+
+    Ok(result)
+}
+
+#[cfg(test)]
+mod test {
+    use std::sync::Arc;
+
+    use arrow::array::{ArrayRef, BooleanArray, Float64Array, Int32Array, StringArray};
+    use datafusion::{common::ScalarValue, logical_expr::ColumnarValue};
+
+    use super::*;
+
+    #[test]
+    fn test_case_when_simple() -> Result<()> {
+        // case_when(x > 5, 'big', 'small')
+        let condition = Arc::new(BooleanArray::from(vec![true, false, true, false]));
+        let value_true = Arc::new(StringArray::from(vec!["big", "big", "big", "big"]));
+        let value_else = Arc::new(StringArray::from(vec!["small", "small", "small", "small"]));
+
+        let result = spark_case_when(&[
+            ColumnarValue::Array(condition),
+            ColumnarValue::Array(value_true),
+            ColumnarValue::Array(value_else),
+        ])?;
+
+        let expected = StringArray::from(vec!["big", "small", "big", "small"]);
+        let result_array = result.into_array(4)?;
+
+        assert_eq!(
+            result_array.as_any().downcast_ref::<StringArray>().unwrap(),
+            &expected
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_multiple_conditions() -> Result<()> {
+        // case_when(x > 10, 100, x > 5, 50, 0)
+        let x = vec![15, 8, 3, 12, 5];
+        
+        let condition1 = Arc::new(BooleanArray::from(
+            x.iter().map(|&v| v > 10).collect::<Vec<_>>()
+        ));
+        let value1 = Arc::new(Int32Array::from(vec![100, 100, 100, 100, 100]));
+        
+        let condition2 = Arc::new(BooleanArray::from(
+            x.iter().map(|&v| v > 5).collect::<Vec<_>>()
+        ));
+        let value2 = Arc::new(Int32Array::from(vec![50, 50, 50, 50, 50]));
+        
+        let else_value = Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0]));
+
+        let result = spark_case_when(&[
+            ColumnarValue::Array(condition1),
+            ColumnarValue::Array(value1),
+            ColumnarValue::Array(condition2),
+            ColumnarValue::Array(value2),
+            ColumnarValue::Array(else_value),
+        ])?;
+
+        let expected = Int32Array::from(vec![100, 50, 0, 100, 0]);
+        let result_array = result.into_array(5)?;
+
+        assert_eq!(
+            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            &expected
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_no_else() -> Result<()> {
+        // case_when(x > 5, 100) - no else, should return NULL for non-matching
+        let condition = Arc::new(BooleanArray::from(vec![true, false, true, false]));
+        let value = Arc::new(Int32Array::from(vec![100, 100, 100, 100]));
+
+        let result = spark_case_when(&[
+            ColumnarValue::Array(condition),
+            ColumnarValue::Array(value),
+        ])?;
+
+        let expected = Int32Array::from(vec![Some(100), None, Some(100), None]);
+        let result_array = result.into_array(4)?;
+
+        assert_eq!(
+            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            &expected
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_with_nulls() -> Result<()> {
+        // Test handling of NULL conditions
+        let condition = Arc::new(BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]));
+        let value = Arc::new(Int32Array::from(vec![10, 10, 10, 10]));
+        let else_value = Arc::new(Int32Array::from(vec![20, 20, 20, 20]));
+
+        let result = spark_case_when(&[
+            ColumnarValue::Array(condition),
+            ColumnarValue::Array(value),
+            ColumnarValue::Array(else_value),
+        ])?;
+
+        // NULL conditions should be treated as false
+        let expected = Int32Array::from(vec![10, 20, 20, 10]);
+        let result_array = result.into_array(4)?;
+
+        assert_eq!(
+            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            &expected
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_scalar() -> Result<()> {
+        // Test with scalar inputs
+        let condition = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true)));
+        let value_true = ColumnarValue::Scalar(ScalarValue::Float64(Some(1.5)));
+        let value_else = ColumnarValue::Scalar(ScalarValue::Float64(Some(2.5)));
+
+        let result = spark_case_when(&[condition, value_true, value_else])?;
+
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
+                assert_eq!(v, 1.5);
+            }
+            _ => panic!("Expected scalar float64"),
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_only_else() -> Result<()> {
+        // Only one argument (else value)
+        let else_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(42)));
+
+        let result = spark_case_when(&[else_value.clone()])?;
+
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => {
+                assert_eq!(v, 42);
+            }
+            _ => panic!("Expected scalar int32"),
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_case_when_mixed_scalar_array() -> Result<()> {
+        // Mix of scalar and array inputs
+        let condition = Arc::new(BooleanArray::from(vec![true, false, true]));
+        let value_true = ColumnarValue::Scalar(ScalarValue::Int32(Some(100)));
+        let value_else = Arc::new(Int32Array::from(vec![1, 2, 3]));
+
+        let result = spark_case_when(&[
+            ColumnarValue::Array(condition),
+            value_true,
+            ColumnarValue::Array(value_else),
+        ])?;
+
+        let expected = Int32Array::from(vec![100, 2, 100]);
+        let result_array = result.into_array(3)?;
+
+        assert_eq!(
+            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            &expected
+        );
+        Ok(())
+    }
+}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index 29b9386ae..56976fb3f 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -1025,27 +1025,28 @@ object NativeConverters extends Logging {
         convertExprWithFallback(caseWhen, isPruningExpr, fallback)
 
       case e @ CaseWhen(branches, elseValue) =>
-        val caseExpr = pb.PhysicalCaseNode.newBuilder()
-        val whenThens = branches.map { case (w, t) =>
-          val casted = t match {
+        // Flatten branches into: condition1, value1, condition2, value2, ..., elseValue
+        // Cast values to match output data type
+        val flattenedArgs = branches.flatMap { case (when, then) =>
+          val castedThen = then match {
             case t if t.dataType != e.dataType => Cast(t, e.dataType)
             case t => t
           }
-          pb.PhysicalWhenThen
-            .newBuilder()
-            .setWhenExpr(convertExprWithFallback(w, isPruningExpr, fallback))
-            .setThenExpr(convertExprWithFallback(casted, isPruningExpr, fallback))
-            .build()
+          Seq(when, castedThen)
         }
-        caseExpr.addAllWhenThenExpr(whenThens.asJava)
-        elseValue.foreach { el =>
-          val casted = el match {
-            case el if el.dataType != e.dataType => Cast(el, e.dataType)
-            case el => el
-          }
-          caseExpr.setElseExpr(convertExprWithFallback(casted, isPruningExpr, fallback))
+
+        // Add else value if present (cast to output type)
+        val allArgs = elseValue match {
+          case Some(el) =>
+            val castedElse = el match {
+              case e1 if e1.dataType != e.dataType => Cast(e1, e.dataType)
+              case e1 => e1
+            }
+            flattenedArgs :+ castedElse
+          case None => flattenedArgs
        }
-        pb.PhysicalExprNode.newBuilder().setCase(caseExpr).build()
+
+        buildExtScalarFunction("Spark_CaseWhen", allArgs, e.dataType)
 
       // expressions for DecimalPrecision rule
       case UnscaledValue(_1) if decimalArithOpEnabled =>
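[Reviewer note, not part of the patch: a minimal sketch of the flattened argument
convention this patch introduces, called through the public dispatcher rather than
the private module. It assumes create_auron_ext_function takes just the function
name and returns the Arc'd implementation, as the lib.rs match above suggests; the
main wrapper and input values are illustrative only.]

use std::sync::Arc;

use arrow::array::{BooleanArray, Int32Array};
use datafusion::{common::Result, physical_plan::ColumnarValue};
use datafusion_ext_functions::create_auron_ext_function;

fn main() -> Result<()> {
    // CASE WHEN x > 10 THEN 100 WHEN x > 5 THEN 50 ELSE 0 END, flattened by
    // NativeConverters into (cond1, val1, cond2, val2, else).
    let x = [15, 8, 3, 12, 5];
    let gt10 = BooleanArray::from(x.iter().map(|&v| v > 10).collect::<Vec<_>>());
    let gt5 = BooleanArray::from(x.iter().map(|&v| v > 5).collect::<Vec<_>>());

    let case_when = create_auron_ext_function("Spark_CaseWhen")?;
    let result = case_when(&[
        ColumnarValue::Array(Arc::new(gt10)),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![100; 5]))),
        ColumnarValue::Array(Arc::new(gt5)),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![50; 5]))),
        ColumnarValue::Array(Arc::new(Int32Array::from(vec![0; 5]))),
    ])?;

    // Row 0 (x = 15) matches both conditions, but the first branch wins,
    // matching Spark semantics: [100, 50, 0, 100, 0]
    println!("{:?}", result.into_array(5)?);
    Ok(())
}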
From 50d66e33cd75563276ce09723b4f397cf1fbb77a Mon Sep 17 00:00:00 2001
From: duanhao-jk
Date: Fri, 23 Jan 2026 14:44:36 +0800
Subject: [PATCH 2/5] checkStyle fix

---
 .../datafusion-ext-functions/src/lib.rs       |  2 +-
 .../src/spark_case_when.rs                    | 90 ++++++++++++-------
 2 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index 071329cea..8c7e970eb 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -19,8 +19,8 @@ use datafusion::{common::Result, logical_expr::ScalarFunctionImplementation};
 use datafusion_ext_commons::df_unimplemented_err;
 
 mod brickhouse;
-mod spark_case_when;
 mod spark_bround;
+mod spark_case_when;
 mod spark_check_overflow;
 mod spark_crypto;
 mod spark_dates;
diff --git a/native-engine/datafusion-ext-functions/src/spark_case_when.rs b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
index 22fe0954a..83504d195 100644
--- a/native-engine/datafusion-ext-functions/src/spark_case_when.rs
+++ b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
@@ -15,26 +15,19 @@
 
 use std::sync::Arc;
 
-use arrow::{
-    array::*,
-    compute::kernels::zip::zip,
-    datatypes::DataType,
-};
-use datafusion::{
-    common::Result,
-    physical_plan::ColumnarValue,
-};
+use arrow::{array::*, compute::kernels::zip::zip, datatypes::DataType};
+use datafusion::{common::Result, physical_plan::ColumnarValue};
 use datafusion_ext_commons::df_execution_err;
 
 /// CASE WHEN function implementation
-/// 
+///
 /// Syntax: case_when(condition1, value1, condition2, value2, ..., else_value)
-/// 
+///
 /// Arguments:
 /// - Arguments come in (condition, value) pairs, with an optional else_value at the end
 /// - An odd argument count includes a trailing else_value; an even count has none
 /// - If no else_value is provided and no condition matches, returns NULL
-/// 
+///
 /// Example:
 /// - case_when(x > 10, 'big', x > 5, 'medium', 'small')
 /// - case_when(x IS NULL, 0, x)
@@ -107,10 +100,20 @@ pub fn spark_case_when(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     let output_type = values[0].data_type().clone();
 
     // Build the result array
-    let result = evaluate_case_when(&conditions, &values, else_array.as_ref(), batch_size, &output_type)?;
+    let result = evaluate_case_when(
+        &conditions,
+        &values,
+        else_array.as_ref(),
+        batch_size,
+        &output_type,
+    )?;
 
     // If all inputs were scalars, return a scalar
-    if batch_size == 1 && args.iter().all(|arg| matches!(arg, ColumnarValue::Scalar(_))) {
+    if batch_size == 1
+        && args
+            .iter()
+            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
+    {
         let scalar = datafusion::common::ScalarValue::try_from_array(&result, 0)?;
         Ok(ColumnarValue::Scalar(scalar))
     } else {
@@ -140,7 +143,8 @@ fn evaluate_case_when(
         let condition = &conditions[i];
         let value = &values[i];
 
-        // Use arrow's zip kernel to select between current result and value based on condition
+        // Use arrow's zip kernel to select between current result and value based on
+        // condition
         result = zip(condition, value, &result)?;
     }
 
@@ -173,7 +177,10 @@ mod test {
         let result_array = result.into_array(4)?;
 
         assert_eq!(
-            result_array.as_any().downcast_ref::<StringArray>().unwrap(),
+            result_array
+                .as_any()
+                .downcast_ref::<StringArray>()
+                .expect("Failed to downcast to StringArray"),
             &expected
         );
         Ok(())
@@ -183,17 +190,17 @@ mod test {
     fn test_case_when_multiple_conditions() -> Result<()> {
         // case_when(x > 10, 100, x > 5, 50, 0)
         let x = vec![15, 8, 3, 12, 5];
-        
+
         let condition1 = Arc::new(BooleanArray::from(
-            x.iter().map(|&v| v > 10).collect::<Vec<_>>()
+            x.iter().map(|&v| v > 10).collect::<Vec<_>>(),
         ));
         let value1 = Arc::new(Int32Array::from(vec![100, 100, 100, 100, 100]));
-        
+
         let condition2 = Arc::new(BooleanArray::from(
-            x.iter().map(|&v| v > 5).collect::<Vec<_>>()
+            x.iter().map(|&v| v > 5).collect::<Vec<_>>(),
         ));
         let value2 = Arc::new(Int32Array::from(vec![50, 50, 50, 50, 50]));
-        
+
         let else_value = Arc::new(Int32Array::from(vec![0, 0, 0, 0, 0]));
 
         let result = spark_case_when(&[
@@ -208,7 +215,10 @@ mod test {
         let result_array = result.into_array(5)?;
 
         assert_eq!(
-            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            result_array
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("Failed to downcast to Int32Array"),
             &expected
         );
         Ok(())
@@ -220,16 +230,17 @@ mod test {
         let condition = Arc::new(BooleanArray::from(vec![true, false, true, false]));
         let value = Arc::new(Int32Array::from(vec![100, 100, 100, 100]));
 
-        let result = spark_case_when(&[
-            ColumnarValue::Array(condition),
-            ColumnarValue::Array(value),
-        ])?;
+        let result =
+            spark_case_when(&[ColumnarValue::Array(condition), ColumnarValue::Array(value)])?;
 
         let expected = Int32Array::from(vec![Some(100), None, Some(100), None]);
         let result_array = result.into_array(4)?;
 
         assert_eq!(
-            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            result_array
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("Failed to downcast to Int32Array"),
             &expected
         );
         Ok(())
@@ -238,7 +249,12 @@ mod test {
     #[test]
     fn test_case_when_with_nulls() -> Result<()> {
         // Test handling of NULL conditions
-        let condition = Arc::new(BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]));
+        let condition = Arc::new(BooleanArray::from(vec![
+            Some(true),
+            None,
+            Some(false),
+            Some(true),
+        ]));
         let value = Arc::new(Int32Array::from(vec![10, 10, 10, 10]));
         let else_value = Arc::new(Int32Array::from(vec![20, 20, 20, 20]));
 
@@ -253,7 +269,10 @@ mod test {
         let result_array = result.into_array(4)?;
 
         assert_eq!(
-            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            result_array
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("Failed to downcast to Int32Array"),
             &expected
         );
         Ok(())
@@ -272,7 +291,9 @@ mod test {
             ColumnarValue::Scalar(ScalarValue::Float64(Some(v))) => {
                 assert_eq!(v, 1.5);
             }
-            _ => panic!("Expected scalar float64"),
+            _ => {
+                return df_execution_err!("Expected scalar float64");
+            }
         }
         Ok(())
     }
@@ -288,7 +309,9 @@ mod test {
             ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => {
                 assert_eq!(v, 42);
             }
-            _ => panic!("Expected scalar int32"),
+            _ => {
+                return df_execution_err!("Expected scalar int32");
+            }
         }
         Ok(())
     }
@@ -310,7 +333,10 @@ mod test {
         let result_array = result.into_array(3)?;
 
         assert_eq!(
-            result_array.as_any().downcast_ref::<Int32Array>().unwrap(),
+            result_array
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("Failed to downcast to Int32Array"),
             &expected
        );
         Ok(())
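[Reviewer note, not part of the patch series: the first-match-wins behaviour relies
on folding arrow's zip kernel over the branches in reverse, as the comment in
evaluate_case_when describes. A self-contained sketch of just that fold, mirroring
the call shape used in the patch; the input values are made up:]

use std::sync::Arc;

use arrow::{
    array::{ArrayRef, BooleanArray, Int32Array},
    compute::kernels::zip::zip,
    error::Result,
};

fn main() -> Result<()> {
    let conditions = vec![
        BooleanArray::from(vec![true, false, false]), // branch 1
        BooleanArray::from(vec![true, true, false]),  // branch 2, overlaps on row 0
    ];
    let values: Vec<ArrayRef> = vec![
        Arc::new(Int32Array::from(vec![1, 1, 1])),
        Arc::new(Int32Array::from(vec![2, 2, 2])),
    ];
    // Start from the else value, then fold the branches in reverse: branch 2
    // writes first, branch 1 overwrites row 0, so earlier branches take precedence.
    let mut result: ArrayRef = Arc::new(Int32Array::from(vec![9, 9, 9]));
    for i in (0..conditions.len()).rev() {
        result = zip(&conditions[i], &values[i], &result)?;
    }
    println!("{:?}", result); // expected values: [1, 2, 9]
    Ok(())
}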
From 7705293299fa2c68128d3510bfb9505f8e11313b Mon Sep 17 00:00:00 2001
From: duanhao-jk
Date: Fri, 23 Jan 2026 16:51:50 +0800
Subject: [PATCH 3/5] cc

---
 native-engine/datafusion-ext-functions/src/spark_case_when.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/native-engine/datafusion-ext-functions/src/spark_case_when.rs b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
index 83504d195..c917f6a6b 100644
--- a/native-engine/datafusion-ext-functions/src/spark_case_when.rs
+++ b/native-engine/datafusion-ext-functions/src/spark_case_when.rs
@@ -30,7 +30,7 @@ use datafusion_ext_commons::df_execution_err;
 ///
 /// Example:
 /// - case_when(x > 10, 'big', x > 5, 'medium', 'small')
-/// - case_when(x IS NULL, 0, x)
+/// - case_when(x IS NULL, 0, x )
 pub fn spark_case_when(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     if args.is_empty() {
         return df_execution_err!("case_when requires at least 1 argument (else value)");

From 18e71a675b6086dc65ea7c4318149dd7e629aec2 Mon Sep 17 00:00:00 2001
From: duanhao-jk
Date: Fri, 23 Jan 2026 17:42:58 +0800
Subject: [PATCH 4/5] cc

---
 dev/auron-it/src/main/resources/tpcds-queries/q99.sql | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/auron-it/src/main/resources/tpcds-queries/q99.sql b/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
index f1a3d4d2b..81824622e 100755
--- a/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
+++ b/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
@@ -23,7 +23,7 @@ SELECT
 FROM
   catalog_sales, warehouse, ship_mode, call_center, date_dim
 WHERE
-  d_month_seq BETWEEN 1200 AND 1200 + 11
+  d_month_seq BETWEEN 1200 AND 1200 + 11 and cs_sold_date_sk>0
   AND cs_ship_date_sk = d_date_sk
   AND cs_warehouse_sk = w_warehouse_sk
   AND cs_ship_mode_sk = sm_ship_mode_sk

From 877f63e7a8ca702c8c571f3f5f2e6e5f594ca757 Mon Sep 17 00:00:00 2001
From: duanhao-jk
Date: Fri, 23 Jan 2026 18:34:43 +0800
Subject: [PATCH 5/5] cc

---
 dev/auron-it/src/main/resources/tpcds-queries/q99.sql | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dev/auron-it/src/main/resources/tpcds-queries/q99.sql b/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
index 81824622e..ef96290e0 100755
--- a/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
+++ b/dev/auron-it/src/main/resources/tpcds-queries/q99.sql
@@ -23,7 +23,7 @@ SELECT
 FROM
   catalog_sales, warehouse, ship_mode, call_center, date_dim
 WHERE
-  d_month_seq BETWEEN 1200 AND 1200 + 11 and cs_sold_date_sk>0
+  d_month_seq BETWEEN 1200 AND 1200 + 11 and cs_sold_date_sk
   AND cs_ship_date_sk = d_date_sk
   AND cs_warehouse_sk = w_warehouse_sk
   AND cs_ship_mode_sk = sm_ship_mode_sk