diff --git a/datafusion/spark/src/function/map/mod.rs b/datafusion/spark/src/function/map/mod.rs
index 2f596b19b422f..c9ebed6f612e1 100644
--- a/datafusion/spark/src/function/map/mod.rs
+++ b/datafusion/spark/src/function/map/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod map_from_arrays;
 pub mod map_from_entries;
+pub mod str_to_map;
 mod utils;
 
 use datafusion_expr::ScalarUDF;
@@ -25,6 +26,7 @@ use std::sync::Arc;
 
 make_udf_function!(map_from_arrays::MapFromArrays, map_from_arrays);
 make_udf_function!(map_from_entries::MapFromEntries, map_from_entries);
+make_udf_function!(str_to_map::SparkStrToMap, str_to_map);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -40,8 +42,14 @@ pub mod expr_fn {
         "Creates a map from array<struct<key, value>>.",
         arg1
     ));
+
+    export_functions!((
+        str_to_map,
+        "Creates a map after splitting the text into key/value pairs using delimiters.",
+        text pair_delim key_value_delim
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
-    vec![map_from_arrays(), map_from_entries()]
+    vec![map_from_arrays(), map_from_entries(), str_to_map()]
 }
diff --git a/datafusion/spark/src/function/map/str_to_map.rs b/datafusion/spark/src/function/map/str_to_map.rs
new file mode 100644
index 0000000000000..b722fb7abd6b2
--- /dev/null
+++ b/datafusion/spark/src/function/map/str_to_map.rs
@@ -0,0 +1,266 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayRef, MapBuilder, MapFieldNames, StringArrayType, StringBuilder,
+};
+use arrow::buffer::NullBuffer;
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion_common::cast::{
+    as_large_string_array, as_string_array, as_string_view_array,
+};
+use datafusion_common::{Result, exec_err, internal_err};
+use datafusion_expr::{
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    TypeSignature, Volatility,
+};
+
+use crate::function::map::utils::map_type_from_key_value_types;
+
+const DEFAULT_PAIR_DELIM: &str = ",";
+const DEFAULT_KV_DELIM: &str = ":";
+
+/// Spark-compatible `str_to_map` expression
+/// <https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map>
+///
+/// Creates a map from a string by splitting on delimiters.
+/// str_to_map(text[, pairDelim[, keyValueDelim]]) -> Map
+///
+/// - text: The input string
+/// - pairDelim: Delimiter between key-value pairs (default: ',')
+/// - keyValueDelim: Delimiter between key and value (default: ':')
+///
+/// # Duplicate Key Handling
+/// Uses EXCEPTION behavior (Spark 3.0+ default): errors on duplicate keys.
+/// See `spark.sql.mapKeyDedupPolicy`:
+/// <https://spark.apache.org/docs/latest/configuration.html>
+///
+/// TODO: Support configurable `spark.sql.mapKeyDedupPolicy` (LAST_WIN) in a follow-up PR.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkStrToMap {
+    signature: Signature,
+}
+
+impl Default for SparkStrToMap {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkStrToMap {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    // str_to_map(text)
+                    TypeSignature::String(1),
+                    // str_to_map(text, pairDelim)
+                    TypeSignature::String(2),
+                    // str_to_map(text, pairDelim, keyValueDelim)
+                    TypeSignature::String(3),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkStrToMap {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "str_to_map"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
+        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
+        let map_type = map_type_from_key_value_types(&DataType::Utf8, &DataType::Utf8);
+        Ok(Arc::new(Field::new(self.name(), map_type, nullable)))
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let arrays: Vec<ArrayRef> = ColumnarValue::values_to_arrays(&args.args)?;
+        let result = str_to_map_inner(&arrays)?;
+        Ok(ColumnarValue::Array(result))
+    }
+}
+
+fn str_to_map_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        1 => match args[0].data_type() {
+            DataType::Utf8 => str_to_map_impl(as_string_array(&args[0])?, None, None),
+            DataType::LargeUtf8 => {
+                str_to_map_impl(as_large_string_array(&args[0])?, None, None)
+            }
+            DataType::Utf8View => {
+                str_to_map_impl(as_string_view_array(&args[0])?, None, None)
+            }
+            other => exec_err!(
+                "Unsupported data type {other:?} for str_to_map, \
+                 expected Utf8, LargeUtf8, or Utf8View"
+            ),
+        },
+        2 => match (args[0].data_type(), args[1].data_type()) {
+            (DataType::Utf8, DataType::Utf8) => str_to_map_impl(
+                as_string_array(&args[0])?,
+                Some(as_string_array(&args[1])?),
+                None,
+            ),
+            (DataType::LargeUtf8, DataType::LargeUtf8) => str_to_map_impl(
+                as_large_string_array(&args[0])?,
+                Some(as_large_string_array(&args[1])?),
+                None,
+            ),
+            (DataType::Utf8View, DataType::Utf8View) => str_to_map_impl(
+                as_string_view_array(&args[0])?,
+                Some(as_string_view_array(&args[1])?),
+                None,
+            ),
+            (t1, t2) => exec_err!(
+                "Unsupported data types ({t1:?}, {t2:?}) for str_to_map, \
+                 expected matching Utf8, LargeUtf8, or Utf8View"
+            ),
+        },
+        3 => match (
+            args[0].data_type(),
+            args[1].data_type(),
+            args[2].data_type(),
+        ) {
+            (DataType::Utf8, DataType::Utf8, DataType::Utf8) => str_to_map_impl(
+                as_string_array(&args[0])?,
+                Some(as_string_array(&args[1])?),
+                Some(as_string_array(&args[2])?),
+            ),
+            (DataType::LargeUtf8, DataType::LargeUtf8, DataType::LargeUtf8) => {
+                str_to_map_impl(
+                    as_large_string_array(&args[0])?,
+                    Some(as_large_string_array(&args[1])?),
+                    Some(as_large_string_array(&args[2])?),
+                )
+            }
+            (DataType::Utf8View, DataType::Utf8View, DataType::Utf8View) => {
+                str_to_map_impl(
+                    as_string_view_array(&args[0])?,
+                    Some(as_string_view_array(&args[1])?),
+                    Some(as_string_view_array(&args[2])?),
+                )
+            }
+            (t1, t2, t3) => exec_err!(
+                "Unsupported data types ({t1:?}, {t2:?}, {t3:?}) for str_to_map, \
+                 expected matching Utf8, LargeUtf8, or Utf8View"
+            ),
+        },
+        n => exec_err!("str_to_map expects 1-3 arguments, got {n}"),
+    }
+}
+
+fn str_to_map_impl<'a, V: StringArrayType<'a> + Copy>(
+    text_array: V,
+    pair_delim_array: Option<V>,
+    kv_delim_array: Option<V>,
+) -> Result<ArrayRef> {
+    let num_rows = text_array.len();
+
+    // Precompute combined null buffer from all input arrays.
+    // NullBuffer::union performs a bitmap-level AND, which is more efficient
+    // than checking per-row nullability inline.
+    let text_nulls = text_array.nulls().cloned();
+    let pair_nulls = pair_delim_array.and_then(|a| a.nulls().cloned());
+    let kv_nulls = kv_delim_array.and_then(|a| a.nulls().cloned());
+    let combined_nulls = [text_nulls.as_ref(), pair_nulls.as_ref(), kv_nulls.as_ref()]
+        .into_iter()
+        .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
+
+    // Use field names matching map_type_from_key_value_types: "key" and "value"
+    let field_names = MapFieldNames {
+        entry: "entries".to_string(),
+        key: "key".to_string(),
+        value: "value".to_string(),
+    };
+    let mut map_builder = MapBuilder::new(
+        Some(field_names),
+        StringBuilder::new(),
+        StringBuilder::new(),
+    );
+
+    let mut seen_keys = HashSet::new();
+    for row_idx in 0..num_rows {
+        if combined_nulls.as_ref().is_some_and(|n| n.is_null(row_idx)) {
+            map_builder.append(false)?;
+            continue;
+        }
+
+        // Per-row delimiter extraction
+        let pair_delim =
+            pair_delim_array.map_or(DEFAULT_PAIR_DELIM, |a| a.value(row_idx));
+        let kv_delim = kv_delim_array.map_or(DEFAULT_KV_DELIM, |a| a.value(row_idx));
+
+        let text = text_array.value(row_idx);
+        if text.is_empty() {
+            // Empty string -> map with empty key and NULL value (Spark behavior)
+            map_builder.keys().append_value("");
+            map_builder.values().append_null();
+            map_builder.append(true)?;
+            continue;
+        }
+
+        seen_keys.clear();
+        for pair in text.split(pair_delim) {
+            if pair.is_empty() {
+                continue;
+            }
+
+            let mut kv_iter = pair.splitn(2, kv_delim);
+            let key = kv_iter.next().unwrap_or("");
+            let value = kv_iter.next();
+
+            // TODO: Support LAST_WIN policy via spark.sql.mapKeyDedupPolicy config
+            // EXCEPTION policy: error on duplicate keys (Spark 3.0+ default)
+            if !seen_keys.insert(key) {
+                return exec_err!(
+                    "Duplicate map key '{key}' was found, please check the input data. \
                     If you want to remove the duplicated keys, you can set \
                     spark.sql.mapKeyDedupPolicy to \"LAST_WIN\" so that the key \
                     inserted at last takes precedence."
+                );
+            }
+
+            map_builder.keys().append_value(key);
+            match value {
+                Some(v) => map_builder.values().append_value(v),
+                None => map_builder.values().append_null(),
+            }
+        }
+        map_builder.append(true)?;
+    }
+
+    Ok(Arc::new(map_builder.finish()))
+}
diff --git a/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt
new file mode 100644
index 0000000000000..30d1672aef0ae
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/spark/map/str_to_map.slt
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Tests for Spark-compatible str_to_map function
+# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map
+#
+# Test cases derived from Spark test("StringToMap"):
+# https://github.com/apache/spark/blob/v4.0.0/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala#L525-L618
+
+# s0: Basic test with default delimiters
+query ?
+SELECT str_to_map('a:1,b:2,c:3');
+----
+{a: 1, b: 2, c: 3}
+
+# s1: Preserve spaces in values
+query ?
+SELECT str_to_map('a: ,b:2');
+----
+{a:  , b: 2}
+
+# s2: Custom key-value delimiter '='
+query ?
+SELECT str_to_map('a=1,b=2,c=3', ',', '=');
+----
+{a: 1, b: 2, c: 3}
+
+# s3: Empty string returns map with empty key and NULL value
+query ?
+SELECT str_to_map('', ',', '=');
+----
+{: NULL}
+
+# s4: Custom pair delimiter '_'
+query ?
+SELECT str_to_map('a:1_b:2_c:3', '_', ':');
+----
+{a: 1, b: 2, c: 3}
+
+# s5: Single key without value returns NULL value
+query ?
+SELECT str_to_map('a');
+----
+{a: NULL}
+
+# s6: Custom delimiters '&' and '='
+query ?
+SELECT str_to_map('a=1&b=2&c=3', '&', '=');
+----
+{a: 1, b: 2, c: 3}
+
+# Duplicate keys: EXCEPTION policy (Spark 3.0+ default)
+# TODO: Add LAST_WIN policy tests when spark.sql.mapKeyDedupPolicy config is supported
+statement error
+Duplicate map key
+SELECT str_to_map('a:1,b:2,a:3');
+
+# Additional tests (DataFusion-specific)
+
+# NULL input returns NULL
+query ?
+SELECT str_to_map(NULL, ',', ':');
+----
+NULL
+
+# Explicit 3-arg form
+query ?
+SELECT str_to_map('a:1,b:2,c:3', ',', ':');
+----
+{a: 1, b: 2, c: 3}
+
+# Missing key-value delimiter results in NULL value
+query ?
+SELECT str_to_map('a,b:2', ',', ':');
+----
+{a: NULL, b: 2}
+
+# Multi-row test
+query ?
+SELECT str_to_map(col) FROM (VALUES ('a:1,b:2'), ('x:9'), (NULL)) AS t(col);
+----
+{a: 1, b: 2}
+{x: 9}
+NULL
+
+# Multi-row with custom delimiter
+query ?
+SELECT str_to_map(col, ',', '=') FROM (VALUES ('a=1,b=2'), ('x=9'), (NULL)) AS t(col);
+----
+{a: 1, b: 2}
+{x: 9}
+NULL
+
+# Per-row delimiters: each row can have different delimiters
+query ?
+SELECT str_to_map(col1, col2, col3) FROM (VALUES ('a=1,b=2', ',', '='), ('x#9', ',', '#'), (NULL, ',', '=')) AS t(col1, col2, col3);
+----
+{a: 1, b: 2}
+{x: 9}
+NULL
\ No newline at end of file