diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs
index 19c251c1a149a..bb6587c0b976f 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -16,7 +16,21 @@
 // under the License.
 
 //! Functionality used both on logical and physical plans
-
+//!
+//! ## About `rehash`
+//! Many helpers in this module take a `rehash: bool` argument.
+//!
+//! Conceptually, `hashes_buffer` is an **accumulator** of per-row hash values. When hashing a
+//! *single* column, the hasher should **initialize** each row's hash. When hashing *multiple*
+//! columns (e.g. for partitioning or joins), subsequent columns should **mix** their value hash
+//! into the existing row hash using [`combine_hashes`].
+//!
+//! - `rehash = false`: initialize/overwrite the row hash for this column
+//! - `rehash = true`: combine this column into an existing row hash
+//!
+//! [`create_hashes`] sets `rehash` to `false` for the first column and `true` for all following
+//! columns, which avoids an unnecessary `combine_hashes` call on the first column.
+//!
 use ahash::RandomState;
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::array::*;
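To make that contract concrete, here is a minimal usage sketch (not part of the patch; the arrays and seeds are illustrative, while `create_hashes` is this module's public entry point):

```rust
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion_common::hash_utils::create_hashes;

fn main() -> datafusion_common::Result<()> {
    let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let b: ArrayRef = Arc::new(StringArray::from(vec!["x", "y", "z"]));
    let random_state = RandomState::with_seeds(0, 0, 0, 0);

    // One accumulator slot per row: column `a` initializes each slot
    // (rehash = false internally), column `b` is mixed into the existing
    // value (rehash = true internally).
    let mut hashes = vec![0u64; 3];
    create_hashes(&[a, b], &random_state, &mut hashes)?;

    println!("{hashes:?}");
    Ok(())
}
```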
@@ -132,8 +146,8 @@ where
 }
 
 #[cfg(not(feature = "force_hash_collisions"))]
-fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) {
-    if mul_col {
+fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], rehash: bool) {
+    if rehash {
         hashes_buffer.iter_mut().for_each(|hash| {
             // stable hash for null value
             *hash = combine_hashes(random_state.hash_one(1), *hash);
@@ -400,10 +414,10 @@ fn update_hash_for_dict_key(
     dict_hashes: &[u64],
     dict_values: &dyn Array,
     idx: usize,
-    multi_col: bool,
+    rehash: bool,
 ) {
     if dict_values.is_valid(idx) {
-        if multi_col {
+        if rehash {
             *hash = combine_hashes(dict_hashes[idx], *hash);
         } else {
             *hash = dict_hashes[idx];
@@ -418,7 +432,7 @@ fn hash_dictionary(
     array: &DictionaryArray<K>,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
-    multi_col: bool,
+    rehash: bool,
 ) -> Result<()> {
     // Hash each dictionary value once, and then use that computed
     // hash for each key value to avoid a potentially expensive
@@ -436,7 +450,7 @@ fn hash_dictionary(
             &dict_hashes,
             dict_values.as_ref(),
             idx,
-            multi_col,
+            rehash,
         );
     } // no update for Null key
 }
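The dictionary path hashes each distinct dictionary value once and then reuses that hash for every key that references it, rather than re-hashing a possibly large value per row. A dependency-free sketch of the lookup step (`dictionary_row_hashes` is a hypothetical helper, not this module's API; real keys are typed Arrow indices):

```rust
// Hypothetical sketch: reuse one precomputed hash per distinct dictionary
// value instead of re-hashing the value for every row that references it.
fn dictionary_row_hashes(dict_value_hashes: &[u64], keys: &[Option<usize>]) -> Vec<u64> {
    keys.iter()
        .map(|key| match key {
            Some(k) => dict_value_hashes[*k], // look up the precomputed hash
            None => 0, // null key: row hash left untouched (0 stands for a zeroed slot)
        })
        .collect()
}

fn main() {
    // Three distinct values, five rows referencing them by key.
    let value_hashes = [11, 22, 33];
    let keys = [Some(0), Some(2), Some(0), None, Some(1)];
    assert_eq!(dictionary_row_hashes(&value_hashes, &keys), [11, 33, 11, 0, 22]);
}
```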
@@ -448,6 +462,7 @@ fn hash_struct_array(
     array: &StructArray,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()> {
     let nulls = array.nulls();
     let row_len = array.len();
@@ -462,9 +477,14 @@ fn hash_struct_array(
     let mut values_hashes = vec![0u64; row_len];
     create_hashes(array.columns(), random_state, &mut values_hashes)?;
 
-    for i in valid_row_indices {
-        let hash = &mut hashes_buffer[i];
-        *hash = combine_hashes(*hash, values_hashes[i]);
+    if rehash {
+        for i in valid_row_indices {
+            hashes_buffer[i] = combine_hashes(values_hashes[i], hashes_buffer[i]);
+        }
+    } else {
+        for i in valid_row_indices {
+            hashes_buffer[i] = values_hashes[i];
+        }
     }
 
     Ok(())
@@ -476,6 +496,7 @@ fn hash_map_array(
     array: &MapArray,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()> {
     let nulls = array.nulls();
     let offsets = array.offsets();
@@ -484,22 +505,49 @@ fn hash_map_array(
     let mut values_hashes = vec![0u64; array.entries().len()];
     create_hashes(array.entries().columns(), random_state, &mut values_hashes)?;
 
-    // Combine the hashes for entries on each row with each other and previous hash for that row
-    if let Some(nulls) = nulls {
+    // Combine the hashes for the entries on each row with each other.
+    // When `rehash = true`, combine the per-row map hash into the existing accumulator.
+    if rehash {
+        if let Some(nulls) = nulls {
+            for (i, (start, stop)) in
+                offsets.iter().zip(offsets.iter().skip(1)).enumerate()
+            {
+                if nulls.is_valid(i) {
+                    let mut row_hash = 0u64;
+                    for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] {
+                        row_hash = combine_hashes(*values_hash, row_hash);
+                    }
+                    hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+                }
+            }
+        } else {
+            for (i, (start, stop)) in
+                offsets.iter().zip(offsets.iter().skip(1)).enumerate()
+            {
+                let mut row_hash = 0u64;
+                for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] {
+                    row_hash = combine_hashes(*values_hash, row_hash);
+                }
+                hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+            }
+        }
+    } else if let Some(nulls) = nulls {
         for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate()
         {
             if nulls.is_valid(i) {
-                let hash = &mut hashes_buffer[i];
+                let mut row_hash = 0u64;
                 for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] {
-                    *hash = combine_hashes(*hash, *values_hash);
+                    row_hash = combine_hashes(*values_hash, row_hash);
                 }
+                hashes_buffer[i] = row_hash;
             }
         }
     } else {
         for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate()
         {
-            let hash = &mut hashes_buffer[i];
+            let mut row_hash = 0u64;
             for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] {
-                *hash = combine_hashes(*hash, *values_hash);
+                row_hash = combine_hashes(*values_hash, row_hash);
             }
+            hashes_buffer[i] = row_hash;
         }
@@ -511,6 +559,7 @@ fn hash_list_array(
     array: &GenericListArray<OffsetSize>,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()>
 where
     OffsetSize: OffsetSizeTrait,
@@ -528,16 +577,50 @@ where
         &mut values_hashes,
     )?;
-    if array.null_count() > 0 {
+    // Compute the per-row list hash by folding the element hashes, then either
+    // initialize or combine once per row depending on `rehash`.
+    if rehash {
+        if array.null_count() > 0 {
+            for (i, (start, stop)) in
+                array.value_offsets().iter().tuple_windows().enumerate()
+            {
+                if array.is_valid(i) {
+                    let mut row_hash = 0u64;
+                    for values_hash in &values_hashes[(*start - first_offset).as_usize()
+                        ..(*stop - first_offset).as_usize()]
+                    {
+                        row_hash = combine_hashes(*values_hash, row_hash);
+                    }
+                    hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+                }
+            }
+        } else {
+            for ((start, stop), hash) in array
+                .value_offsets()
+                .iter()
+                .tuple_windows()
+                .zip(hashes_buffer.iter_mut())
+            {
+                let mut row_hash = 0u64;
+                for values_hash in &values_hashes[(*start - first_offset).as_usize()
+                    ..(*stop - first_offset).as_usize()]
+                {
+                    row_hash = combine_hashes(*values_hash, row_hash);
+                }
+                *hash = combine_hashes(row_hash, *hash);
+            }
+        }
+    } else if array.null_count() > 0 {
         for (i, (start, stop)) in
             array.value_offsets().iter().tuple_windows().enumerate()
         {
             if array.is_valid(i) {
-                let hash = &mut hashes_buffer[i];
+                let mut row_hash = 0u64;
                 for values_hash in &values_hashes[(*start - first_offset).as_usize()
                     ..(*stop - first_offset).as_usize()]
                 {
-                    *hash = combine_hashes(*hash, *values_hash);
+                    row_hash = combine_hashes(*values_hash, row_hash);
                 }
+                hashes_buffer[i] = row_hash;
             }
         }
     } else {
@@ -547,11 +630,13 @@ where
             .tuple_windows()
             .zip(hashes_buffer.iter_mut())
         {
+            let mut row_hash = 0u64;
             for values_hash in &values_hashes
                 [(*start - first_offset).as_usize()..(*stop - first_offset).as_usize()]
             {
-                *hash = combine_hashes(*hash, *values_hash);
+                row_hash = combine_hashes(*values_hash, row_hash);
             }
+            *hash = row_hash;
         }
     }
     Ok(())
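The struct, map, and list hashers above now share one shape: fold a row's element hashes into a local `row_hash`, then write it out once, either initializing the accumulator slot (`rehash = false`) or mixing into it (`rehash = true`). A dependency-free sketch of that fold (`fold_row` is a hypothetical helper; `mix` stands in for this module's `combine_hashes`):

```rust
// Hypothetical sketch of the per-row fold used by the nested-type hashers.
// `mix` is a stand-in for `combine_hashes` (the real constants differ).
fn mix(l: u64, r: u64) -> u64 {
    l.rotate_left(31) ^ r.wrapping_mul(0x9E37_79B9_7F4A_7C15)
}

/// Fold one row's element hashes, then initialize or combine the slot per `rehash`.
fn fold_row(element_hashes: &[u64], slot: &mut u64, rehash: bool) {
    let mut row_hash = 0u64;
    for &eh in element_hashes {
        row_hash = mix(eh, row_hash);
    }
    *slot = if rehash { mix(row_hash, *slot) } else { row_hash };
}

fn main() {
    let mut a = 0u64;
    fold_row(&[1, 2, 3], &mut a, false); // first column: initialize
    let mut b = 0u64;
    fold_row(&[1, 2, 3], &mut b, false);
    assert_eq!(a, b); // deterministic per row

    fold_row(&[9], &mut a, true); // later column: mix into the slot
    println!("combined row hash: {a:#x}");
}
```

Folding into a local `row_hash` rather than directly into the buffer slot is what lets a single loop body serve both the initialize and combine cases. Note that a zero-element fold leaves `row_hash = 0`, which is consistent with the existing `null vs empty list` assertion in the tests below.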
@@ -562,6 +647,7 @@ fn hash_list_view_array(
     array: &GenericListViewArray<OffsetSize>,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()>
 where
     OffsetSize: OffsetSizeTrait,
@@ -572,25 +658,51 @@ where
     let nulls = array.nulls();
     let mut values_hashes = vec![0u64; values.len()];
     create_hashes([values], random_state, &mut values_hashes)?;
-    if let Some(nulls) = nulls {
+    if rehash {
+        if let Some(nulls) = nulls {
+            for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
+                if nulls.is_valid(i) {
+                    let start = offset.as_usize();
+                    let end = start + size.as_usize();
+                    let mut row_hash = 0u64;
+                    for values_hash in &values_hashes[start..end] {
+                        row_hash = combine_hashes(*values_hash, row_hash);
+                    }
+                    hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+                }
+            }
+        } else {
+            for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
+                let start = offset.as_usize();
+                let end = start + size.as_usize();
+                let mut row_hash = 0u64;
+                for values_hash in &values_hashes[start..end] {
+                    row_hash = combine_hashes(*values_hash, row_hash);
+                }
+                hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+            }
+        }
+    } else if let Some(nulls) = nulls {
         for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
             if nulls.is_valid(i) {
-                let hash = &mut hashes_buffer[i];
                 let start = offset.as_usize();
                 let end = start + size.as_usize();
+                let mut row_hash = 0u64;
                 for values_hash in &values_hashes[start..end] {
-                    *hash = combine_hashes(*hash, *values_hash);
+                    row_hash = combine_hashes(*values_hash, row_hash);
                 }
+                hashes_buffer[i] = row_hash;
             }
         }
     } else {
         for (i, (offset, size)) in offsets.iter().zip(sizes.iter()).enumerate() {
-            let hash = &mut hashes_buffer[i];
             let start = offset.as_usize();
             let end = start + size.as_usize();
+            let mut row_hash = 0u64;
             for values_hash in &values_hashes[start..end] {
-                *hash = combine_hashes(*hash, *values_hash);
+                row_hash = combine_hashes(*values_hash, row_hash);
             }
+            hashes_buffer[i] = row_hash;
         }
     }
     Ok(())
@@ -601,6 +713,7 @@ fn hash_union_array(
     array: &UnionArray,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()> {
     use std::collections::HashMap;
 
@@ -618,13 +731,24 @@ fn hash_union_array(
         child_hashes.insert(type_id, child_hash_buffer);
     }
 
-    #[expect(clippy::needless_range_loop)]
-    for i in 0..array.len() {
-        let type_id = array.type_id(i);
-        let child_offset = array.value_offset(i);
+    if rehash {
+        #[expect(clippy::needless_range_loop)]
+        for i in 0..array.len() {
+            let type_id = array.type_id(i);
+            let child_offset = array.value_offset(i);
+
+            let child_hash = child_hashes.get(&type_id).expect("invalid type_id");
+            hashes_buffer[i] = combine_hashes(child_hash[child_offset], hashes_buffer[i]);
+        }
+    } else {
+        #[expect(clippy::needless_range_loop)]
+        for i in 0..array.len() {
+            let type_id = array.type_id(i);
+            let child_offset = array.value_offset(i);
 
-        let child_hash = child_hashes.get(&type_id).expect("invalid type_id");
-        hashes_buffer[i] = combine_hashes(hashes_buffer[i], child_hash[child_offset]);
+            let child_hash = child_hashes.get(&type_id).expect("invalid type_id");
+            hashes_buffer[i] = child_hash[child_offset];
+        }
     }
 
     Ok(())
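For unions, the code above hashes every child array up front and then, per row, selects the precomputed hash for that row's `type_id` and child offset. A simplified sketch of the selection step (hypothetical helper with dense offsets; real offsets come from `array.value_offset(i)`):

```rust
use std::collections::HashMap;

// Hypothetical sketch: per row, pick the precomputed hash of the child array
// selected by `type_id`, at that row's offset into the child.
fn union_row_hashes(
    child_hashes: &HashMap<i8, Vec<u64>>,
    rows: &[(i8, usize)], // (type_id, offset into that child)
) -> Vec<u64> {
    rows.iter()
        .map(|(type_id, offset)| child_hashes[type_id][*offset])
        .collect()
}

fn main() {
    let mut child_hashes = HashMap::new();
    child_hashes.insert(0i8, vec![10, 20]); // e.g. an Int32 child
    child_hashes.insert(1i8, vec![77]); // e.g. a Utf8 child
    let rows = [(0i8, 0usize), (1, 0), (0, 1)];
    assert_eq!(union_row_hashes(&child_hashes, &rows), [10, 77, 20]);
}
```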
@@ -635,29 +759,56 @@ fn hash_fixed_list_array(
     array: &FixedSizeListArray,
     random_state: &RandomState,
     hashes_buffer: &mut [u64],
+    rehash: bool,
 ) -> Result<()> {
     let values = array.values();
     let value_length = array.value_length() as usize;
     let nulls = array.nulls();
     let mut values_hashes = vec![0u64; values.len()];
     create_hashes([values], random_state, &mut values_hashes)?;
-    if let Some(nulls) = nulls {
+    if rehash {
+        if let Some(nulls) = nulls {
+            for i in 0..array.len() {
+                if nulls.is_valid(i) {
+                    let mut row_hash = 0u64;
+                    for values_hash in
+                        &values_hashes[i * value_length..(i + 1) * value_length]
+                    {
+                        row_hash = combine_hashes(*values_hash, row_hash);
+                    }
+                    hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+                }
+            }
+        } else {
+            for i in 0..array.len() {
+                let mut row_hash = 0u64;
+                for values_hash in
+                    &values_hashes[i * value_length..(i + 1) * value_length]
+                {
+                    row_hash = combine_hashes(*values_hash, row_hash);
+                }
+                hashes_buffer[i] = combine_hashes(row_hash, hashes_buffer[i]);
+            }
+        }
+    } else if let Some(nulls) = nulls {
         for i in 0..array.len() {
             if nulls.is_valid(i) {
-                let hash = &mut hashes_buffer[i];
+                let mut row_hash = 0u64;
                 for values_hash in
                     &values_hashes[i * value_length..(i + 1) * value_length]
                 {
-                    *hash = combine_hashes(*hash, *values_hash);
+                    row_hash = combine_hashes(*values_hash, row_hash);
                 }
+                hashes_buffer[i] = row_hash;
             }
         }
     } else {
         for i in 0..array.len() {
-            let hash = &mut hashes_buffer[i];
+            let mut row_hash = 0u64;
             for values_hash in &values_hashes[i * value_length..(i + 1) * value_length]
             {
-                *hash = combine_hashes(*hash, *values_hash);
+                row_hash = combine_hashes(*values_hash, row_hash);
             }
+            hashes_buffer[i] = row_hash;
         }
     }
     Ok(())
@@ -762,35 +913,35 @@ fn hash_single_array(
         }
         DataType::Struct(_) => {
             let array = as_struct_array(array)?;
-            hash_struct_array(array, random_state, hashes_buffer)?;
+            hash_struct_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::List(_) => {
             let array = as_list_array(array)?;
-            hash_list_array(array, random_state, hashes_buffer)?;
+            hash_list_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::LargeList(_) => {
             let array = as_large_list_array(array)?;
-            hash_list_array(array, random_state, hashes_buffer)?;
+            hash_list_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::ListView(_) => {
             let array = as_list_view_array(array)?;
-            hash_list_view_array(array, random_state, hashes_buffer)?;
+            hash_list_view_array(array, random_state, hashes_buffer, rehash)?;
        }
         DataType::LargeListView(_) => {
             let array = as_large_list_view_array(array)?;
-            hash_list_view_array(array, random_state, hashes_buffer)?;
+            hash_list_view_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::Map(_, _) => {
             let array = as_map_array(array)?;
-            hash_map_array(array, random_state, hashes_buffer)?;
+            hash_map_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::FixedSizeList(_,_) => {
             let array = as_fixed_size_list_array(array)?;
-            hash_fixed_list_array(array, random_state, hashes_buffer)?;
+            hash_fixed_list_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::Union(_, _) => {
             let array = as_union_array(array)?;
-            hash_union_array(array, random_state, hashes_buffer)?;
+            hash_union_array(array, random_state, hashes_buffer, rehash)?;
         }
         DataType::RunEndEncoded(_, _) => downcast_run_array! {
             array => hash_run_array(array, random_state, hashes_buffer, rehash)?,
@@ -872,7 +1023,10 @@ where
     T: AsDynArray,
 {
     for (i, array) in arrays.into_iter().enumerate() {
-        // combine hashes with `combine_hashes` for all columns besides the first
+        // `hashes_buffer` is a per-row accumulator.
+        //
+        // First column: initialize hashes (no need to call `combine_hashes`)
+        // Subsequent columns: combine with the existing per-row hash
        let rehash = i >= 1;
         hash_single_array(array.as_dyn_array(), random_state, hashes_buffer, rehash)?;
     }
@@ -1194,6 +1348,50 @@ mod tests {
         assert_eq!(hashes[1], hashes[6]); // null vs empty list
     }
 
+    #[test]
+    #[cfg(not(feature = "force_hash_collisions"))]
+    fn create_multi_column_hash_with_list_array() -> Result<()> {
+        // Validate that nested types participate in multi-column hashing.
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(3), None, Some(5)]),
+            None,
+            Some(vec![Some(0), Some(1), Some(2)]),
+            Some(vec![]),
+        ];
+        let list_array: ArrayRef =
+            Arc::new(ListArray::from_iter_primitive::<Int32Type, _, _>(data));
+        let extra_col: ArrayRef =
+            Arc::new(Int32Array::from(vec![10, 11, 12, 12, 11, 10, 13]));
+
+        let random_state = RandomState::with_seeds(0, 0, 0, 0);
+
+        let mut one_col_hashes = vec![0; list_array.len()];
+        create_hashes(
+            &[Arc::clone(&list_array)],
+            &random_state,
+            &mut one_col_hashes,
+        )?;
+
+        let mut two_col_hashes = vec![0; list_array.len()];
+        create_hashes(
+            &[Arc::clone(&list_array), Arc::clone(&extra_col)],
+            &random_state,
+            &mut two_col_hashes,
+        )?;
+
+        assert_ne!(one_col_hashes, two_col_hashes);
+
+        // Rows that agree on both columns (extra_col repeats 10/11/12 to
+        // preserve the list equalities) should still hash equally
+        assert_eq!(two_col_hashes[0], two_col_hashes[5]);
+        assert_eq!(two_col_hashes[1], two_col_hashes[4]);
+        assert_eq!(two_col_hashes[2], two_col_hashes[3]);
+
+        Ok(())
+    }
+
     #[test]
     #[cfg(not(feature = "force_hash_collisions"))]
     fn create_hashes_for_sliced_list_arrays() {