diff --git a/encodings/alp/src/alp/compute/between.rs b/encodings/alp/src/alp/compute/between.rs index 8721dd6a8dd..ca9b36ffce3 100644 --- a/encodings/alp/src/alp/compute/between.rs +++ b/encodings/alp/src/alp/compute/between.rs @@ -45,8 +45,8 @@ impl BetweenKernel for ALPVTable { match_each_alp_float_ptype!(array.ptype(), |F| { between_impl::( array, - F::try_from(lower)?, - F::try_from(upper)?, + F::try_from(&lower)?, + F::try_from(&upper)?, nullability, options, ) diff --git a/encodings/alp/src/alp/compute/compare.rs b/encodings/alp/src/alp/compute/compare.rs index 352a31899ea..6f04b58841b 100644 --- a/encodings/alp/src/alp/compute/compare.rs +++ b/encodings/alp/src/alp/compute/compare.rs @@ -15,7 +15,7 @@ use vortex_array::register_kernel; use vortex_dtype::NativePType; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_scalar::PrimitiveScalar; +use vortex_error::vortex_err; use vortex_scalar::Scalar; use crate::ALPArray; @@ -42,7 +42,13 @@ impl CompareKernel for ALPVTable { } if let Some(const_scalar) = rhs.as_constant() { - let pscalar = PrimitiveScalar::try_from(&const_scalar)?; + let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| { + vortex_err!( + "ALP Compare RHS had the wrong type {}, expected {}", + const_scalar, + const_scalar.dtype() + ) + })?; match_each_alp_float_ptype!(pscalar.ptype(), |T| { match pscalar.typed_value::() { diff --git a/encodings/alp/src/alp/ops.rs b/encodings/alp/src/alp/ops.rs index d17b0dddf02..af1a3a72bb6 100644 --- a/encodings/alp/src/alp/ops.rs +++ b/encodings/alp/src/alp/ops.rs @@ -22,10 +22,8 @@ impl OperationsVTable for ALPVTable { let encoded_val = array.encoded().scalar_at(index)?; Ok(match_each_alp_float_ptype!(array.ptype(), |T| { - let encoded_val: ::ALPInt = encoded_val - .as_ref() - .try_into() - .vortex_expect("invalid ALPInt"); + let encoded_val: ::ALPInt = + (&encoded_val).try_into().vortex_expect("invalid ALPInt"); Scalar::primitive( ::decode_single(encoded_val, array.exponents()), array.dtype().nullability(), diff --git a/encodings/alp/src/alp_rd/compute/take.rs b/encodings/alp/src/alp_rd/compute/take.rs index 7b123f19eca..ade967e9971 100644 --- a/encodings/alp/src/alp_rd/compute/take.rs +++ b/encodings/alp/src/alp_rd/compute/take.rs @@ -9,7 +9,6 @@ use vortex_array::arrays::TakeExecute; use vortex_array::compute::fill_null; use vortex_error::VortexResult; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::ALPRDArray; use crate::ALPRDVTable; @@ -36,7 +35,7 @@ impl TakeExecute for ALPRDVTable { .transpose()?; let right_parts = fill_null( &array.right_parts().take(indices.to_array())?, - &Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)), + &Scalar::zero_value(array.right_parts().dtype()), )?; Ok(Some( diff --git a/encodings/datetime-parts/src/compute/rules.rs b/encodings/datetime-parts/src/compute/rules.rs index 9702927c9ad..aa31b57bedc 100644 --- a/encodings/datetime-parts/src/compute/rules.rs +++ b/encodings/datetime-parts/src/compute/rules.rs @@ -171,7 +171,7 @@ fn try_extract_days_constant(array: &ArrayRef) -> Option { fn is_constant_zero(array: &ArrayRef) -> bool { array .as_opt::() - .is_some_and(|c| c.scalar().is_zero()) + .is_some_and(|c| c.scalar().is_zero() == Some(true)) } #[cfg(test)] diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs index ca7ff934be7..2ff776c36b3 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs 
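// A minimal sketch of the migration pattern in the hunks above, assuming the APIs exactly as
// they appear in this diff: `as_primitive_opt()` replaces `PrimitiveScalar::try_from(&scalar)`,
// and `Scalar::zero_value(dtype)` replaces building a zero via `ScalarValue::from(0)`. The
// function and argument names here are hypothetical placeholders, not code from the PR.
use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;

fn migrated_rhs_handling(rhs_scalar: &Scalar, fill_dtype: &DType) -> VortexResult<Scalar> {
    // Prefer the Option-returning accessor over a TryFrom conversion, attaching context on miss.
    let primitive = rhs_scalar
        .as_primitive_opt()
        .ok_or_else(|| vortex_err!("expected a primitive scalar, got {}", rhs_scalar.dtype()))?;
    let _ptype = primitive.ptype();
    // Zero-valued fill scalars now come from a dedicated constructor.
    Ok(Scalar::zero_value(fill_dtype))
}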
+++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/compute/compare.rs @@ -46,9 +46,9 @@ impl CompareKernel for DecimalBytePartsVTable { .vortex_expect("checked for null in entry func"); match decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype()) - .map(|value| Scalar::new(scalar_type.clone(), value)) { - Ok(encoded_scalar) => { + Ok(value) => { + let encoded_scalar = Scalar::try_new(scalar_type, Some(value))?; let encoded_const = ConstantArray::new(encoded_scalar, rhs.len()); compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some) } @@ -165,7 +165,10 @@ mod tests { ) .unwrap() .to_array(); - let rhs = ConstantArray::new(Scalar::new(dtype, DecimalValue::I64(400).into()), lhs.len()); + let rhs = ConstantArray::new( + Scalar::try_new(dtype, Some(DecimalValue::I64(400).into())).unwrap(), + lhs.len(), + ); let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap(); @@ -215,10 +218,11 @@ mod tests { .to_array(); // This cannot be converted to a i32. let rhs = ConstantArray::new( - Scalar::new( + Scalar::try_new( dtype.clone(), - DecimalValue::I128(-9999999999999965304).into(), - ), + Some(DecimalValue::I128(-9999999999999965304).into()), + ) + .unwrap(), lhs.len(), ); @@ -236,7 +240,7 @@ mod tests { // This cannot be converted to a i32. let rhs = ConstantArray::new( - Scalar::new(dtype, DecimalValue::I128(9999999999999965304).into()), + Scalar::try_new(dtype, Some(DecimalValue::I128(9999999999999965304).into())).unwrap(), lhs.len(), ); diff --git a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs index 0e4de81842d..4f6c61268ca 100644 --- a/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs +++ b/encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs @@ -44,6 +44,7 @@ use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use vortex_session::VortexSession; use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS; @@ -285,10 +286,10 @@ impl OperationsVTable for DecimalBytePartsVTable { let primitive_scalar = scalar.as_primitive(); // TODO(joe): extend this to support multiple parts. 
let value = primitive_scalar.as_::().vortex_expect("non-null"); - Ok(Scalar::new( + Scalar::try_new( array.dtype.clone(), - DecimalValue::I64(value).into(), - )) + Some(ScalarValue::Decimal(DecimalValue::I64(value))), + ) } } @@ -319,6 +320,7 @@ mod tests { use vortex_dtype::Nullability; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; + use vortex_scalar::ScalarValue; use crate::DecimalBytePartsArray; @@ -339,11 +341,15 @@ mod tests { assert_eq!(Scalar::null(dtype.clone()), array.scalar_at(0).unwrap()); assert_eq!( - Scalar::new(dtype.clone(), DecimalValue::I64(200).into()), + Scalar::try_new( + dtype.clone(), + Some(ScalarValue::Decimal(DecimalValue::I64(200))) + ) + .unwrap(), array.scalar_at(1).unwrap() ); assert_eq!( - Scalar::new(dtype, DecimalValue::I64(400).into()), + Scalar::try_new(dtype, Some(ScalarValue::Decimal(DecimalValue::I64(400)))).unwrap(), array.scalar_at(2).unwrap() ); } diff --git a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs index cfb010a46e3..ff375b5aa76 100644 --- a/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs +++ b/encodings/fastlanes/src/bitpacking/array/bitpack_decompress.rs @@ -262,7 +262,7 @@ mod tests { .iter() .enumerate() .for_each(|(i, v)| { - let scalar: u16 = unpack_single(&compressed, i).try_into().unwrap(); + let scalar: u16 = (&unpack_single(&compressed, i)).try_into().unwrap(); assert_eq!(scalar, *v); }); } diff --git a/encodings/fastlanes/src/for/array/for_compress.rs b/encodings/fastlanes/src/for/array/for_compress.rs index 4b20acc5b77..4a6e263123b 100644 --- a/encodings/fastlanes/src/for/array/for_compress.rs +++ b/encodings/fastlanes/src/for/array/for_compress.rs @@ -175,10 +175,7 @@ mod test { .iter() .enumerate() .for_each(|(i, v)| { - assert_eq!( - *v, - i8::try_from(compressed.scalar_at(i).unwrap().as_ref()).unwrap() - ); + assert_eq!(*v, i8::try_from(&compressed.scalar_at(i).unwrap()).unwrap()); }); assert_arrays_eq!(decompressed, array); Ok(()) diff --git a/encodings/fastlanes/src/for/compute/compare.rs b/encodings/fastlanes/src/for/compute/compare.rs index ee65ab27fc7..6ae51867ad8 100644 --- a/encodings/fastlanes/src/for/compute/compare.rs +++ b/encodings/fastlanes/src/for/compute/compare.rs @@ -19,7 +19,6 @@ use vortex_error::VortexError; use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_scalar::PValue; -use vortex_scalar::PrimitiveScalar; use vortex_scalar::Scalar; use crate::FoRArray; @@ -33,7 +32,7 @@ impl CompareKernel for FoRVTable { operator: Operator, ) -> VortexResult> { if let Some(constant) = rhs.as_constant() - && let Ok(constant) = PrimitiveScalar::try_from(&constant) + && let Some(constant) = constant.as_primitive_opt() { match_each_integer_ptype!(constant.ptype(), |T| { return compare_constant( diff --git a/encodings/fastlanes/src/for/vtable/mod.rs b/encodings/fastlanes/src/for/vtable/mod.rs index 8808b34589f..1f9824fd43d 100644 --- a/encodings/fastlanes/src/for/vtable/mod.rs +++ b/encodings/fastlanes/src/for/vtable/mod.rs @@ -2,13 +2,10 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use std::fmt::Debug; -use std::fmt::Formatter; use vortex_array::ArrayRef; -use vortex_array::DeserializeMetadata; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::SerializeMetadata; use vortex_array::buffer::BufferHandle; use vortex_array::serde::ArrayChildren; use vortex_array::vtable; @@ -41,7 +38,7 @@ vtable!(FoR); impl VTable for FoRVTable { type Array 
= FoRArray; - type Metadata = ScalarValueMetadata; + type Metadata = Scalar; type ArrayVTable = Self; type OperationsVTable = Self; @@ -68,22 +65,21 @@ impl VTable for FoRVTable { } fn metadata(array: &FoRArray) -> VortexResult { - Ok(ScalarValueMetadata( - array.reference_scalar().value().clone(), - )) + Ok(array.reference_scalar().clone()) } fn serialize(metadata: Self::Metadata) -> VortexResult>> { - Ok(Some(metadata.serialize())) + Ok(Some(ScalarValue::to_proto_bytes(metadata.value()))) } fn deserialize( bytes: &[u8], - _dtype: &DType, + dtype: &DType, _len: usize, _session: &VortexSession, ) -> VortexResult { - ScalarValueMetadata::deserialize(bytes) + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + Scalar::try_new(dtype.clone(), scalar_value) } fn build( @@ -101,9 +97,8 @@ impl VTable for FoRVTable { } let encoded = children.get(0, dtype, len)?; - let reference = Scalar::new(dtype.clone(), metadata.0.clone()); - FoRArray::try_new(encoded, reference) + FoRArray::try_new(encoded, metadata.clone()) } fn reduce_parent( @@ -134,27 +129,3 @@ pub struct FoRVTable; impl FoRVTable { pub const ID: ArrayId = ArrayId::new_ref("fastlanes.for"); } - -#[derive(Clone)] -pub struct ScalarValueMetadata(pub ScalarValue); - -impl SerializeMetadata for ScalarValueMetadata { - fn serialize(self) -> Vec { - self.0.to_protobytes() - } -} - -impl DeserializeMetadata for ScalarValueMetadata { - type Output = ScalarValueMetadata; - - fn deserialize(metadata: &[u8]) -> VortexResult { - let scalar_value = ScalarValue::from_protobytes(metadata)?; - Ok(ScalarValueMetadata(scalar_value)) - } -} - -impl Debug for ScalarValueMetadata { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", &self.0) - } -} diff --git a/encodings/fastlanes/src/rle/vtable/operations.rs b/encodings/fastlanes/src/rle/vtable/operations.rs index 10af0cb6f00..2ea20a3b8a1 100644 --- a/encodings/fastlanes/src/rle/vtable/operations.rs +++ b/encodings/fastlanes/src/rle/vtable/operations.rs @@ -27,7 +27,7 @@ impl OperationsVTable for RLEVTable { .values() .scalar_at(value_idx_offset + chunk_relative_idx)?; - Ok(Scalar::new(array.dtype().clone(), scalar.into_value())) + Scalar::try_new(array.dtype().clone(), scalar.into_value()) } } diff --git a/encodings/fsst/src/compute/compare.rs b/encodings/fsst/src/compute/compare.rs index 9119725c655..24ff9e9f3a1 100644 --- a/encodings/fsst/src/compute/compare.rs +++ b/encodings/fsst/src/compute/compare.rs @@ -111,9 +111,9 @@ fn compare_fsst_constant( _ => unreachable!("FSSTArray can only have string or binary data type"), }; - let encoded_scalar = Scalar::new( - DType::Binary(left.dtype().nullability() | right.dtype().nullability()), - encoded_buffer.into(), + let encoded_scalar = Scalar::binary( + encoded_buffer, + left.dtype().nullability() | right.dtype().nullability(), ); let rhs = ConstantArray::new(encoded_scalar, left.len()); diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 53d835e3306..17b840c1170 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -16,7 +16,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::FSSTArray; use crate::FSSTVTable; @@ -41,10 +40,7 @@ impl TakeExecute for FSSTVTable { .map_err(|_| vortex_err!("take for codes must return varbin array"))?, fill_null( &array.uncompressed_lengths().take(indices.to_array())?, - &Scalar::new( - 
array.uncompressed_lengths_dtype().clone(), - ScalarValue::from(0), - ), + &Scalar::zero_value(&array.uncompressed_lengths_dtype().clone()), )?, )? .into_array(), diff --git a/encodings/fsst/src/ops.rs b/encodings/fsst/src/ops.rs index 9b69b743a36..c6d247441b6 100644 --- a/encodings/fsst/src/ops.rs +++ b/encodings/fsst/src/ops.rs @@ -16,8 +16,7 @@ impl OperationsVTable for FSSTVTable { let compressed = array.codes().scalar_at(index)?; let binary_datum = compressed.as_binary().value().vortex_expect("non-null"); - let decoded_buffer = - ByteBuffer::from(array.decompressor().decompress(binary_datum.as_slice())); + let decoded_buffer = ByteBuffer::from(array.decompressor().decompress(binary_datum)); Ok(varbin_scalar(decoded_buffer, array.dtype())) } } diff --git a/encodings/runend/src/array.rs b/encodings/runend/src/array.rs index 242a3d12c9b..90f271cdc7d 100644 --- a/encodings/runend/src/array.rs +++ b/encodings/runend/src/array.rs @@ -222,13 +222,13 @@ impl RunEndArray { // Validate the offset and length are valid for the given ends and values if offset != 0 && length != 0 { - let first_run_end: usize = ends.scalar_at(0)?.as_ref().try_into()?; + let first_run_end = usize::try_from(&ends.scalar_at(0)?)?; if first_run_end <= offset { vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}"); } } - let last_run_end: usize = ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?; + let last_run_end = usize::try_from(&ends.scalar_at(ends.len() - 1)?)?; let min_required_end = offset + length; if last_run_end < min_required_end { vortex_bail!("Last run end {last_run_end} must be >= offset+length {min_required_end}"); @@ -302,7 +302,7 @@ impl RunEndArray { let length: usize = if ends.is_empty() { 0 } else { - ends.scalar_at(ends.len() - 1)?.as_ref().try_into()? + usize::try_from(&ends.scalar_at(ends.len() - 1)?)? }; Self::try_new_offset_length(ends, values, 0, length) diff --git a/encodings/runend/src/compute/cast.rs b/encodings/runend/src/compute/cast.rs index 35316316972..ccf30cca9d9 100644 --- a/encodings/runend/src/compute/cast.rs +++ b/encodings/runend/src/compute/cast.rs @@ -76,27 +76,19 @@ mod tests { // RunEnd encoding should expand to [100, 100, 100, 200, 200, 100, 100, 100, 300, 300] assert_eq!(decoded.len(), 10); assert_eq!( - TryInto::::try_into(decoded.scalar_at(0).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(&decoded.scalar_at(0).unwrap()).unwrap(), 100i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(3).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(&decoded.scalar_at(3).unwrap()).unwrap(), 200i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(5).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(&decoded.scalar_at(5).unwrap()).unwrap(), 100i64 ); assert_eq!( - TryInto::::try_into(decoded.scalar_at(8).unwrap().as_ref()) - .ok() - .unwrap(), + TryInto::::try_into(&decoded.scalar_at(8).unwrap()).unwrap(), 300i64 ); } diff --git a/encodings/sequence/src/array.rs b/encodings/sequence/src/array.rs index 2a9d0983c6f..28e83e10ccc 100644 --- a/encodings/sequence/src/array.rs +++ b/encodings/sequence/src/array.rs @@ -245,28 +245,26 @@ impl VTable for SequenceVTable { let ptype = dtype.as_ptype(); // We go via scalar to cast the scalar values into the correct PType - let base = Scalar::new( - DType::Primitive(ptype, NonNullable), + let base = Scalar::from_proto_value( metadata .0 .base .as_ref() - .ok_or_else(|| vortex_err!("base required"))? 
- .try_into()?, - ) + .ok_or_else(|| vortex_err!("base required"))?, + &DType::Primitive(ptype, NonNullable), + )? .as_primitive() .pvalue() .vortex_expect("non-nullable primitive"); - let multiplier = Scalar::new( - DType::Primitive(ptype, NonNullable), + let multiplier = Scalar::from_proto_value( metadata .0 .multiplier .as_ref() - .ok_or_else(|| vortex_err!("base required"))? - .try_into()?, - ) + .ok_or_else(|| vortex_err!("multiplier required"))?, + &DType::Primitive(ptype, NonNullable), + )? .as_primitive() .pvalue() .vortex_expect("non-nullable primitive"); @@ -355,10 +353,10 @@ impl BaseArrayVTable for SequenceVTable { impl OperationsVTable for SequenceVTable { fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult { - Ok(Scalar::new( + Scalar::try_new( array.dtype().clone(), - ScalarValue::from(array.index_value(index)), - )) + Some(ScalarValue::Primitive(array.index_value(index))), + ) } } @@ -423,7 +421,7 @@ mod tests { assert_eq!( scalar, - Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64)) + Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap() ) } diff --git a/encodings/sequence/src/compute/cast.rs b/encodings/sequence/src/compute/cast.rs index 68725d9e6a3..20cfebcf46f 100644 --- a/encodings/sequence/src/compute/cast.rs +++ b/encodings/sequence/src/compute/cast.rs @@ -48,14 +48,14 @@ impl CastKernel for SequenceVTable { // For type changes, we need to cast the base and multiplier if array.ptype() != *target_ptype { // Create scalars from PValues and cast them - let base_scalar = Scalar::new( + let base_scalar = Scalar::try_new( DType::Primitive(array.ptype(), Nullability::NonNullable), - ScalarValue::from(array.base()), - ); - let multiplier_scalar = Scalar::new( + Some(ScalarValue::Primitive(array.base())), + )?; + let multiplier_scalar = Scalar::try_new( DType::Primitive(array.ptype(), Nullability::NonNullable), - ScalarValue::from(array.multiplier()), - ); + Some(ScalarValue::Primitive(array.multiplier())), + )?; let new_base_scalar = base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?; diff --git a/encodings/sequence/src/compute/compare.rs b/encodings/sequence/src/compute/compare.rs index 6fa11d534a5..fc652216a15 100644 --- a/encodings/sequence/src/compute/compare.rs +++ b/encodings/sequence/src/compute/compare.rs @@ -9,7 +9,6 @@ use vortex_array::compute::CompareKernel; use vortex_array::compute::Operator; use vortex_array::validity::Validity; use vortex_buffer::BitBuffer; -use vortex_dtype::DType; use vortex_dtype::NativePType; use vortex_dtype::Nullability; use vortex_dtype::match_each_integer_ptype; @@ -58,11 +57,7 @@ impl CompareKernel for SequenceVTable { Ok(Some(BoolArray::new(buffer, validity).to_array())) } else { Ok(Some( - ConstantArray::new( - Scalar::new(DType::Bool(nullability), false.into()), - lhs.len(), - ) - .to_array(), + ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).to_array(), )) } } diff --git a/encodings/sequence/src/kernel.rs b/encodings/sequence/src/kernel.rs index 982f33cf51b..6dacc4d01ca 100644 --- a/encodings/sequence/src/kernel.rs +++ b/encodings/sequence/src/kernel.rs @@ -129,22 +129,15 @@ fn compare_eq_neq( find_intersection_scalar(array.base(), array.multiplier(), array.len, constant) else { return Ok(Some( - ConstantArray::new( - Scalar::new(DType::Bool(nullability), not_match_val.into()), - array.len, - ) - .into_array(), + ConstantArray::new(Scalar::bool(not_match_val, nullability), array.len).into_array(), )); }; let idx = set_idx as u64; let len 
= array.len as u64; if len == 1 && set_idx == 0 { - let result_array = ConstantArray::new( - Scalar::new(DType::Bool(nullability), match_val.into()), - array.len, - ) - .to_array(); + let result_array = + ConstantArray::new(Scalar::bool(match_val, nullability), array.len).to_array(); return Ok(Some(result_array)); } @@ -186,16 +179,12 @@ fn compare_ordering( ); let result_array = match transition { - Transition::AllTrue => ConstantArray::new( - Scalar::new(DType::Bool(nullability), true.into()), - array.len, - ) - .to_array(), - Transition::AllFalse => ConstantArray::new( - Scalar::new(DType::Bool(nullability), false.into()), - array.len, - ) - .to_array(), + Transition::AllTrue => { + ConstantArray::new(Scalar::bool(true, nullability), array.len).to_array() + } + Transition::AllFalse => { + ConstantArray::new(Scalar::bool(false, nullability), array.len).to_array() + } Transition::FalseToTrue(idx) => { // [0..idx) is false, [idx..len) is true let ends = buffer![idx as u64, array.len as u64].into_array(); @@ -362,10 +351,11 @@ mod tests { fn test_sequence_gte_constant() -> VortexResult<()> { let seq = SequenceArray::typed_new(0i64, 1, NonNullable, 10)?.to_array(); let constant = ConstantArray::new( - Scalar::new( + Scalar::try_new( DType::Primitive(PType::I64, Nullability::Nullable), - 5i64.into(), - ), + Some(5i64.into()), + ) + .unwrap(), 10, ) .to_array(); diff --git a/encodings/sparse/src/canonical.rs b/encodings/sparse/src/canonical.rs index 222c4befeb0..162fb1d0515 100644 --- a/encodings/sparse/src/canonical.rs +++ b/encodings/sparse/src/canonical.rs @@ -97,12 +97,12 @@ pub(super) fn execute_sparse(array: &SparseArray) -> VortexResult { }) } dtype @ DType::Utf8(..) => { - let fill_value = array.fill_scalar().as_utf8().value(); + let fill_value = array.fill_scalar().as_utf8().value().cloned(); let fill_value = fill_value.map(BufferString::into_inner); execute_varbin(array, dtype.clone(), fill_value)? } dtype @ DType::Binary(..) => { - let fill_value = array.fill_scalar().as_binary().value(); + let fill_value = array.fill_scalar().as_binary().value().cloned(); execute_varbin(array, dtype.clone(), fill_value)? 
} DType::List(values_dtype, nullability) => { @@ -369,12 +369,12 @@ fn execute_sparse_struct( unresolved_patches: &Patches, len: usize, ) -> VortexResult { - let (fill_values, top_level_fill_validity) = match fill_struct.fields() { + let (fill_values, top_level_fill_validity) = match fill_struct.fields_iter() { Some(fill_values) => (fill_values.collect::>(), Validity::AllValid), None => ( struct_fields .fields() - .map(Scalar::default_value) + .map(|f| Scalar::default_value(&f)) .collect::>(), Validity::AllInvalid, ), diff --git a/encodings/sparse/src/compute/cast.rs b/encodings/sparse/src/compute/cast.rs index 2db305e80b9..c38ac614ded 100644 --- a/encodings/sparse/src/compute/cast.rs +++ b/encodings/sparse/src/compute/cast.rs @@ -77,7 +77,7 @@ mod tests { buffer![1u64, 3, 5].into_array(), PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(), 8, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); @@ -109,7 +109,7 @@ mod tests { buffer![1u64, 3, 7].into_array(), PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(), 10, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] #[case(SparseArray::try_new( buffer![5u64].into_array(), diff --git a/encodings/sparse/src/compute/filter.rs b/encodings/sparse/src/compute/filter.rs index 4b0046631e4..47f8c805156 100644 --- a/encodings/sparse/src/compute/filter.rs +++ b/encodings/sparse/src/compute/filter.rs @@ -60,7 +60,7 @@ mod tests { buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array() @@ -80,7 +80,7 @@ mod tests { buffer![0u64].into_array(), PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), 1, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); @@ -94,7 +94,7 @@ mod tests { buffer![0_u64, 3, 6].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 7, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array(); @@ -109,7 +109,7 @@ mod tests { buffer![1u64, 3].into_array(), PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), 4, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap(); diff --git a/encodings/sparse/src/compute/mod.rs b/encodings/sparse/src/compute/mod.rs index 38bdb41b68a..c3b6868964b 100644 --- a/encodings/sparse/src/compute/mod.rs +++ b/encodings/sparse/src/compute/mod.rs @@ -14,6 +14,7 @@ mod test { use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; + use vortex_array::assert_arrays_eq; use vortex_array::compute::cast; use vortex_array::compute::conformance::binary_numeric::test_binary_numeric_array; use vortex_array::compute::conformance::mask::test_mask_conformance; @@ -22,6 +23,7 @@ mod test { use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; + use vortex_mask::Mask; use vortex_scalar::Scalar; use crate::SparseArray; @@ -32,12 +34,62 @@ mod test { buffer![2u64, 9, 15].into_array(), PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), 20, - Scalar::null_typed::(), + Scalar::null_native::(), ) .unwrap() .into_array() } + #[rstest] + fn test_filter(array: ArrayRef) { + let mut predicate = vec![false, false, true]; + predicate.extend_from_slice(&[false; 17]); + let mask = Mask::from_iter(predicate); + + let filtered_array = array.filter(mask).unwrap(); + + // Construct expected SparseArray: 
index 2 was kept, which had value 33. + // The new index is 0 (since it's the only element). + let expected = SparseArray::try_new( + buffer![0u64].into_array(), + PrimitiveArray::new(buffer![33_i32], Validity::AllValid).into_array(), + 1, + Scalar::null_native::(), + ) + .unwrap(); + + assert_arrays_eq!(filtered_array, expected); + } + + #[test] + fn true_fill_value() { + let mask = Mask::from_iter([false, true, false, true, false, true, true]); + let array = SparseArray::try_new( + buffer![0_u64, 3, 6].into_array(), + PrimitiveArray::new(buffer![33_i32, 44, 55], Validity::AllValid).into_array(), + 7, + Scalar::null_native::(), + ) + .unwrap() + .into_array(); + + let filtered_array = array.filter(mask).unwrap(); + + // Original indices 0, 3, 6 with values 33, 44, 55. + // Mask keeps indices 1, 3, 5, 6 -> new indices 0, 1, 2, 3. + // Index 3 (value 44) maps to new index 1. + // Index 6 (value 55) maps to new index 3. + let expected = SparseArray::try_new( + buffer![1u64, 3].into_array(), + PrimitiveArray::new(buffer![44_i32, 55], Validity::AllValid).into_array(), + 4, + Scalar::null_native::(), + ) + .unwrap(); + + assert_arrays_eq!(filtered_array, expected); + } + #[rstest] fn test_sparse_binary_numeric(array: ArrayRef) { test_binary_numeric_array(array) @@ -97,7 +149,7 @@ mod tests { buffer![2u64, 5, 8].into_array(), PrimitiveArray::from_option_iter([Some(100i32), Some(200), Some(300)]).into_array(), 10, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] #[case::sparse_i32_value_fill(SparseArray::try_new( buffer![1u64, 3, 7].into_array(), @@ -129,7 +181,7 @@ mod tests { buffer![0u64, 1, 2, 3, 4].into_array(), PrimitiveArray::from_option_iter([Some(10i32), Some(20), Some(30), Some(40), Some(50)]).into_array(), 5, - Scalar::null_typed::() + Scalar::null_native::() ).unwrap())] // Large sparse arrays #[case::sparse_large(SparseArray::try_new( diff --git a/encodings/sparse/src/compute/take.rs b/encodings/sparse/src/compute/take.rs index 05c88ed787d..ff5f4a0c274 100644 --- a/encodings/sparse/src/compute/take.rs +++ b/encodings/sparse/src/compute/take.rs @@ -72,7 +72,7 @@ mod test { fn test_array_fill_value() -> Scalar { // making this const is annoying - Scalar::null_typed::() + Scalar::null_native::() } fn sparse_array() -> ArrayRef { @@ -174,7 +174,7 @@ mod test { buffer![0u64, 37, 47, 99].into_array(), PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(), 100, - Scalar::null_typed::(), + Scalar::null_native::(), ).unwrap())] #[case(SparseArray::try_new( buffer![1u32, 3, 7, 8, 9].into_array(), @@ -188,7 +188,7 @@ mod test { buffer![2u64, 4, 6].into_array(), nullable_values.into_array(), 10, - Scalar::null_typed::(), + Scalar::null_native::(), ).unwrap() })] #[case(SparseArray::try_new( diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index fbfd003c3d1..300bbf65cd7 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -128,10 +128,11 @@ impl VTable for SparseVTable { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } - let fill_value = Scalar::new( - dtype.clone(), - ScalarValue::from_protobytes(&buffers[0].clone().try_to_host_sync()?)?, - ); + + let bytes: &[u8] = &buffers[0].clone().try_to_host_sync()?; + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + + let fill_value = Scalar::try_new(dtype.clone(), scalar_value)?; SparseArray::try_new(patch_indices, patch_values, len, fill_value) } @@ -418,11 +419,8 @@ impl ValidityVTable for SparseVTable { 
impl VisitorVTable for SparseVTable { fn visit_buffers(array: &SparseArray, visitor: &mut dyn ArrayBufferVisitor) { - let fill_value_buffer = array - .fill_value - .value() - .to_protobytes::() - .freeze(); + let fill_value_buffer = + ScalarValue::to_proto_bytes::(array.fill_value.value()).freeze(); visitor.visit_buffer_handle("fill_value", &BufferHandle::new_host(fill_value_buffer)); } @@ -445,7 +443,6 @@ mod test { use vortex_dtype::Nullability; use vortex_dtype::PType; use vortex_error::VortexExpect; - use vortex_scalar::PrimitiveScalar; use vortex_scalar::Scalar; use super::*; @@ -495,8 +492,9 @@ mod test { .unwrap(); assert_eq!( - PrimitiveScalar::try_from(&arr.scalar_at(10).unwrap()) + arr.scalar_at(10) .unwrap() + .as_primitive() .typed_value::(), Some(1234) ); @@ -619,7 +617,8 @@ mod test { let indices = buffer![0u8, 2, 4, 6, 8].into_array(); let values = PrimitiveArray::from_option_iter([Some(0i16), Some(1), None, None, Some(4)]) .into_array(); - let array = SparseArray::try_new(indices, values, 10, Scalar::null_typed::()).unwrap(); + let array = + SparseArray::try_new(indices, values, 10, Scalar::null_native::()).unwrap(); let actual = array.validity_mask().unwrap(); let expected = Mask::from_iter([ true, false, true, false, false, false, false, false, true, false, diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index 9b28770c1aa..823af4c66e2 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -195,7 +195,7 @@ impl OperationsVTable for ZigZagVTable { fn scalar_at(array: &ZigZagArray, index: usize) -> VortexResult { let scalar = array.encoded().scalar_at(index)?; if scalar.is_null() { - return Ok(scalar.reinterpret_cast(array.ptype())); + return scalar.primitive_reinterpret_cast(array.ptype()); } let pscalar = scalar.as_primitive(); diff --git a/encodings/zstd/src/test.rs b/encodings/zstd/src/test.rs index f498218e2e1..05344b24e07 100644 --- a/encodings/zstd/src/test.rs +++ b/encodings/zstd/src/test.rs @@ -84,9 +84,7 @@ fn test_zstd_with_validity_and_multi_frame() { let slice = compressed.slice(176..179).unwrap(); let primitive = slice.to_primitive(); assert_eq!( - TryInto::::try_into(primitive.scalar_at(1).unwrap().as_ref()) - .ok() - .unwrap(), + i32::try_from(&primitive.scalar_at(1).unwrap()).unwrap(), 177 ); assert_eq!( diff --git a/fuzz/src/array/compare.rs b/fuzz/src/array/compare.rs index d5fde3bfd52..bb2c699905c 100644 --- a/fuzz/src/array/compare.rs +++ b/fuzz/src/array/compare.rs @@ -107,27 +107,21 @@ pub fn compare_canonical_array(array: &dyn Array, value: &Scalar, operator: Oper }) } DType::Utf8(_) => array.to_varbinview().with_iterator(|iter| { - let utf8_value = value - .as_utf8() - .value() - .vortex_expect("nulls handled before"); + let utf8_value = value.as_utf8(); compare_to( iter.map(|v| v.map(|b| unsafe { str::from_utf8_unchecked(b) })), - &utf8_value, + utf8_value.value().vortex_expect("nulls handled before"), operator, result_nullability, ) }), DType::Binary(_) => array.to_varbinview().with_iterator(|iter| { - let binary_value = value - .as_binary() - .value() - .vortex_expect("nulls handled before"); + let binary_value = value.as_binary(); compare_to( // Don't understand the lifetime problem here but identity map makes it go away #[allow(clippy::map_identity)] iter.map(|v| v), - &binary_value, + binary_value.value().vortex_expect("nulls handled before"), operator, result_nullability, ) diff --git a/vortex-array/src/array/mod.rs b/vortex-array/src/array/mod.rs index 7af207736ad..95a66056102 100644 --- 
a/vortex-array/src/array/mod.rs +++ b/vortex-array/src/array/mod.rs @@ -452,7 +452,8 @@ impl Array for ArrayAdapter { stat, Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted ) && value.as_ref().as_exact().is_some_and(|v| { - Scalar::new(DType::Bool(Nullability::NonNullable), v.clone()) + Scalar::try_new(DType::Bool(Nullability::NonNullable), Some(v.clone())) + .vortex_expect("A stat that was expected to be a boolean stat was not") .as_bool() .value() .unwrap_or_default() diff --git a/vortex-array/src/arrays/bool/compute/sum.rs b/vortex-array/src/arrays/bool/compute/sum.rs index abcb89f1733..19c435c9d2f 100644 --- a/vortex-array/src/arrays/bool/compute/sum.rs +++ b/vortex-array/src/arrays/bool/compute/sum.rs @@ -3,6 +3,7 @@ use std::ops::BitAnd; +use vortex_dtype::Nullability; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_mask::AllOr; @@ -30,13 +31,15 @@ impl SumKernel for BoolVTable { } }; - let accumulator = accumulator + let acc_value = accumulator .as_primitive() .as_::() .vortex_expect("cannot be null"); - Ok(Scalar::from( - true_count.and_then(|tc| accumulator.checked_add(tc)), - )) + let result = true_count.and_then(|tc| acc_value.checked_add(tc)); + Ok(match result { + Some(v) => Scalar::primitive(v, Nullability::Nullable), + None => Scalar::null_native::(), + }) } } diff --git a/vortex-array/src/arrays/chunked/compute/sum.rs b/vortex-array/src/arrays/chunked/compute/sum.rs index c9bb9389781..6b80c2e2fa8 100644 --- a/vortex-array/src/arrays/chunked/compute/sum.rs +++ b/vortex-array/src/arrays/chunked/compute/sum.rs @@ -30,9 +30,9 @@ mod tests { use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::Nullability; + use vortex_dtype::i256; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; - use vortex_scalar::i256; use crate::array::IntoArray; use crate::arrays::ChunkedArray; diff --git a/vortex-array/src/arrays/constant/array.rs b/vortex-array/src/arrays/constant/array.rs index 9e99b73aef8..5722baa7ab8 100644 --- a/vortex-array/src/arrays/constant/array.rs +++ b/vortex-array/src/arrays/constant/array.rs @@ -56,7 +56,7 @@ mod tests { #[cfg_attr(miri, ignore)] #[test] fn test_constant_metadata() { - let scalar_bytes: Vec = ScalarValue::from(i32::MAX).to_protobytes(); + let scalar_bytes: Vec = ScalarValue::to_proto_bytes(Some(&ScalarValue::from(i32::MAX))); check_metadata( "constant.metadata", ProstMetadata(ConstantMetadata { diff --git a/vortex-array/src/arrays/constant/compute/cast.rs b/vortex-array/src/arrays/constant/compute/cast.rs index 36401af8a36..1875c52f39d 100644 --- a/vortex-array/src/arrays/constant/compute/cast.rs +++ b/vortex-array/src/arrays/constant/compute/cast.rs @@ -37,7 +37,7 @@ mod tests { #[case(ConstantArray::new(Scalar::from(-100i32), 10).into_array())] #[case(ConstantArray::new(Scalar::from(3.5f32), 3).into_array())] #[case(ConstantArray::new(Scalar::from(true), 7).into_array())] - #[case(ConstantArray::new(Scalar::null_typed::(), 4).into_array())] + #[case(ConstantArray::new(Scalar::null_native::(), 4).into_array())] #[case(ConstantArray::new(Scalar::from(255u8), 1).into_array())] fn test_cast_constant_conformance(#[case] array: crate::ArrayRef) { test_cast_conformance(array.as_ref()); diff --git a/vortex-array/src/arrays/constant/compute/fill_null.rs b/vortex-array/src/arrays/constant/compute/fill_null.rs index 9e6fd16cda6..cde536b7a35 100644 --- a/vortex-array/src/arrays/constant/compute/fill_null.rs +++ b/vortex-array/src/arrays/constant/compute/fill_null.rs @@ -37,7 +37,7 @@ mod test { 
#[test] fn test_null() { let actual = fill_null( - &ConstantArray::new(Scalar::from(None::), 3).into_array(), + &ConstantArray::new(Scalar::null_native::(), 3).into_array(), &Scalar::from(1), ) .unwrap(); diff --git a/vortex-array/src/arrays/constant/compute/mod.rs b/vortex-array/src/arrays/constant/compute/mod.rs index 7e31b9f326a..f9e3eb3e6b7 100644 --- a/vortex-array/src/arrays/constant/compute/mod.rs +++ b/vortex-array/src/arrays/constant/compute/mod.rs @@ -29,7 +29,7 @@ mod test { #[test] fn test_mask_constant() { - test_mask_conformance(&ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_mask_conformance(&ConstantArray::new(Scalar::null_native::(), 5).into_array()); test_mask_conformance(&ConstantArray::new(Scalar::from(3u16), 5).into_array()); test_mask_conformance(&ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); test_mask_conformance( @@ -39,7 +39,7 @@ mod test { #[test] fn test_filter_constant() { - test_filter_conformance(&ConstantArray::new(Scalar::null_typed::(), 5).into_array()); + test_filter_conformance(&ConstantArray::new(Scalar::null_native::(), 5).into_array()); test_filter_conformance(&ConstantArray::new(Scalar::from(3u16), 5).into_array()); test_filter_conformance(&ConstantArray::new(Scalar::from(1.0f32 / 0.0f32), 5).into_array()); test_filter_conformance( diff --git a/vortex-array/src/arrays/constant/compute/sum.rs b/vortex-array/src/arrays/constant/compute/sum.rs index 75a412d50c3..3137af22990 100644 --- a/vortex-array/src/arrays/constant/compute/sum.rs +++ b/vortex-array/src/arrays/constant/compute/sum.rs @@ -35,11 +35,15 @@ impl SumKernel for ConstantVTable { .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?; let sum_value = sum_scalar(array.scalar(), array.len(), accumulator)?; - Ok(Scalar::new(sum_dtype, sum_value)) + Scalar::try_new(sum_dtype, sum_value) } } -fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult { +fn sum_scalar( + scalar: &Scalar, + len: usize, + accumulator: &Scalar, +) -> VortexResult> { match scalar.dtype() { DType::Bool(_) => { let count = match scalar.as_bool().value() { @@ -51,14 +55,16 @@ fn sum_scalar(scalar: &Scalar, len: usize, accumulator: &Scalar) -> VortexResult .as_primitive() .as_::() .vortex_expect("cannot be null"); - Ok(ScalarValue::from(accumulator.checked_add(count))) + Ok(accumulator + .checked_add(count) + .map(|v| ScalarValue::Primitive(v.into()))) } DType::Primitive(ptype, _) => { let result = match_each_native_ptype!( ptype, - unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.into() }, - signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.into() }, - floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.into() } + unsigned: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, + signed: |T| { sum_integral::(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) }, + floating: |T| { sum_float(scalar.as_primitive(), len, accumulator)?.map(|v| ScalarValue::Primitive(v.into())) } ); Ok(result) } @@ -75,7 +81,7 @@ fn sum_decimal( array_len: usize, decimal_dtype: DecimalDType, accumulator: &Scalar, -) -> VortexResult { +) -> VortexResult> { let result_dtype = Stat::Sum .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable)) .vortex_expect("decimal supports sum"); @@ -85,43 +91,35 @@ fn sum_decimal( let Some(value) = decimal_scalar.decimal_value() else { // Null value: return null - return 
Ok(ScalarValue::null()); + return Ok(None); }; - // Convert array_len to DecimalValue for multiplication + // Convert array_len to DecimalValue for multiplication. let len_value = DecimalValue::I256(i256::from_i128(array_len as i128)); - // Multiply value * len - let array_sum = value.checked_mul(&len_value).and_then(|result| { - // Check if result fits in the precision - result - .fits_in_precision(*result_decimal_type) - .unwrap_or(false) - .then_some(result) - }); - - // Add accumulator to array_sum - let initial_decimal = DecimalScalar::try_from(accumulator)?; + let Some(array_sum) = value + .checked_mul(&len_value) + .filter(|d| d.fits_in_precision(*result_decimal_type)) + else { + return Ok(None); + }; + + // Add accumulator to array_sum. + let initial_decimal = accumulator.as_decimal(); let initial_dec_value = initial_decimal .decimal_value() .unwrap_or(DecimalValue::I256(i256::ZERO)); - match array_sum { - Some(array_sum_value) => { - let total = array_sum_value - .checked_add(&initial_dec_value) - .and_then(|result| { - result - .fits_in_precision(*result_decimal_type) - .unwrap_or(false) - .then_some(result) - }); - match total { - Some(result_value) => Ok(ScalarValue::from(result_value)), - None => Ok(ScalarValue::null()), // Overflow - } - } - None => Ok(ScalarValue::null()), // Overflow + let total = array_sum + .checked_add(&initial_dec_value) + .and_then(|result| { + result + .fits_in_precision(*result_decimal_type) + .then_some(result) + }); + match total { + Some(result_value) => Ok(Some(ScalarValue::from(result_value))), + None => Ok(None), // Overflow } } @@ -132,7 +130,6 @@ fn sum_integral( ) -> VortexResult> where T: NativePType + CheckedMul + CheckedAdd, - Scalar: From>, { let v = primitive_scalar.as_::(); let array_len = @@ -295,7 +292,7 @@ mod tests { let sum = sum_with_accumulator(array.as_ref(), &Scalar::primitive(acc, Nullable)) .vortex_expect("operation should succeed in test"); assert_eq!( - f64::try_from(sum).vortex_expect("operation should succeed in test"), + f64::try_from(&sum).vortex_expect("operation should succeed in test"), -2048669274505644600000000000f64 ); } diff --git a/vortex-array/src/arrays/constant/compute/take.rs b/vortex-array/src/arrays/constant/compute/take.rs index 3f565a77729..209fc5feb5c 100644 --- a/vortex-array/src/arrays/constant/compute/take.rs +++ b/vortex-array/src/arrays/constant/compute/take.rs @@ -20,13 +20,13 @@ impl TakeReduce for ConstantVTable { fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult> { let result = match indices.validity_mask()?.bit_buffer() { AllOr::All => { - let scalar = Scalar::new( + let scalar = Scalar::try_new( array .scalar() .dtype() .union_nullability(indices.dtype().nullability()), - array.scalar().value().clone(), - ); + array.scalar().value().cloned(), + )?; ConstantArray::new(scalar, indices.len()).into_array() } AllOr::None => ConstantArray::new( @@ -125,7 +125,7 @@ mod tests { #[case(ConstantArray::new(42i32, 5))] #[case(ConstantArray::new(std::f64::consts::PI, 10))] #[case(ConstantArray::new(Scalar::from("hello"), 3))] - #[case(ConstantArray::new(Scalar::null_typed::(), 5))] + #[case(ConstantArray::new(Scalar::null_native::(), 5))] #[case(ConstantArray::new(true, 1))] fn test_take_constant_conformance(#[case] array: ConstantArray) { test_take_conformance(array.as_ref()); diff --git a/vortex-array/src/arrays/constant/vtable/canonical.rs b/vortex-array/src/arrays/constant/vtable/canonical.rs index 42c3391f6fc..f251c372b1e 100644 --- 
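// A minimal sketch of the pattern used in the sum hunks above, assuming the signatures as they
// appear in this diff: `Scalar::try_new` takes an `Option<ScalarValue>`, and a `None` result
// (null input or overflow) is expected to yield a typed null scalar for a nullable dtype.
// `sum_dtype` and `maybe_sum` are hypothetical placeholders, not code from the PR.
use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

fn finish_sum(sum_dtype: DType, maybe_sum: Option<ScalarValue>) -> VortexResult<Scalar> {
    // The fallible constructor validates the value (if any) against the dtype; `None` encodes
    // the null/overflow case that `sum_scalar` and `sum_decimal` now return.
    Scalar::try_new(sum_dtype, maybe_sum)
}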
a/vortex-array/src/arrays/constant/vtable/canonical.rs +++ b/vortex-array/src/arrays/constant/vtable/canonical.rs @@ -14,14 +14,8 @@ use vortex_dtype::match_each_decimal_value_type; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_scalar::BinaryScalar; -use vortex_scalar::BoolScalar; use vortex_scalar::DecimalValue; -use vortex_scalar::ExtScalar; -use vortex_scalar::ListScalar; use vortex_scalar::Scalar; -use vortex_scalar::StructScalar; -use vortex_scalar::Utf8Scalar; use vortex_vector::binaryview::BinaryView; use crate::Canonical; @@ -54,11 +48,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult Canonical::Null(NullArray::new(array.len())), DType::Bool(..) => Canonical::Bool(BoolArray::new( - if BoolScalar::try_from(scalar) - .vortex_expect("must be bool") - .value() - .unwrap_or_default() - { + if scalar.as_bool().value().unwrap_or_default() { BitBuffer::new_set(array.len()) } else { BitBuffer::new_unset(array.len()) @@ -111,9 +101,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { - let value = Utf8Scalar::try_from(scalar) - .vortex_expect("Must be a utf8 scalar") - .value(); + let value = scalar.as_utf8().value(); let const_value = value.as_ref().map(|v| v.as_bytes()); Canonical::VarBinView(constant_canonical_byte_view( const_value, @@ -122,9 +110,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { - let value = BinaryScalar::try_from(scalar) - .vortex_expect("must be a binary scalar") - .value(); + let value = scalar.as_binary().value().cloned(); let const_value = value.as_ref().map(|v| v.as_slice()); Canonical::VarBinView(constant_canonical_byte_view( const_value, @@ -133,18 +119,21 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { - let value = StructScalar::try_from(scalar).vortex_expect("must be struct"); - let fields: Vec<_> = match value.fields() { + let value = scalar.as_struct(); + let fields: Vec<_> = match value.fields_iter() { Some(fields) => fields .into_iter() .map(|s| ConstantArray::new(s, array.len()).into_array()) .collect(), None => { assert!(validity.all_invalid(array.len())?); + // The struct is entirely null, so fields just need placeholder values with the + // correct dtype. We use `default_value` which returns a zero for non-nullable + // dtypes and null for nullable dtypes, preserving each field's nullability. struct_dtype .fields() .map(|dt| { - let scalar = Scalar::default_value(dt); + let scalar = Scalar::default_value(&dt); ConstantArray::new(scalar, array.len()).into_array() }) .collect() @@ -158,7 +147,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult Canonical::List(constant_canonical_list_array(scalar, array.len())), DType::FixedSizeList(element_dtype, list_size, _) => { - let value = ListScalar::try_from(scalar).vortex_expect("must be list"); + let value = scalar.as_list(); Canonical::FixedSizeList(constant_canonical_fixed_size_list_array( value.elements(), @@ -169,7 +158,7 @@ pub(crate) fn constant_canonicalize(array: &ConstantArray) -> VortexResult { - let s = ExtScalar::try_from(scalar).vortex_expect("must be an extension scalar"); + let s = scalar.as_extension(); let storage_scalar = s.storage(); let storage_self = ConstantArray::new(storage_scalar, array.len()).into_array(); @@ -227,7 +216,7 @@ fn constant_canonical_byte_view( /// We basically just project the list scalar value into list view components. 
If the caller wants /// a fully decompressed and non-overlapping array, they can rebuild the array. fn constant_canonical_list_array(scalar: &Scalar, len: usize) -> ListViewArray { - let list = ListScalar::try_from(scalar).vortex_expect("must be list"); + let list = scalar.as_list(); // Since "canonicalize" only applies to the top level array, we can simply have 1 scalar in our // child `elements` and have all list views point to that scalar. diff --git a/vortex-array/src/arrays/constant/vtable/mod.rs b/vortex-array/src/arrays/constant/vtable/mod.rs index 53594a5431f..ca4a36d1fcb 100644 --- a/vortex-array/src/arrays/constant/vtable/mod.rs +++ b/vortex-array/src/arrays/constant/vtable/mod.rs @@ -61,7 +61,8 @@ impl VTable for ConstantVTable { } fn metadata(array: &ConstantArray) -> VortexResult { - let proto_bytes: Vec = array.scalar().value().to_protobytes(); + let constant = &array.scalar(); + let proto_bytes: Vec = ScalarValue::to_proto_bytes(constant.value()); let scalar_value = (proto_bytes.len() <= CONSTANT_INLINE_THRESHOLD).then_some(proto_bytes); Ok(ProstMetadata(ConstantMetadata { scalar_value })) } @@ -92,16 +93,23 @@ impl VTable for ConstantVTable { _children: &dyn ArrayChildren, ) -> VortexResult { // Prefer reading the scalar from inlined metadata to avoid device-to-host copies. - let sv = if let Some(ref proto_bytes) = metadata.scalar_value { - ScalarValue::from_protobytes(proto_bytes)? + let scalar = if let Some(proto_bytes) = &metadata.scalar_value { + let scalar_value = ScalarValue::from_proto_bytes(proto_bytes, dtype)?; + + Scalar::try_new(dtype.clone(), scalar_value) } else { if buffers.len() != 1 { vortex_bail!("Expected 1 buffer, got {}", buffers.len()); } + let buffer = buffers[0].clone().try_to_host_sync()?; - ScalarValue::from_protobytes(&buffer)? 
- }; - let scalar = Scalar::new(dtype.clone(), sv); + let bytes: &[u8] = buffer.as_ref(); + + let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?; + + Scalar::try_new(dtype.clone(), scalar_value) + }?; + Ok(ConstantArray::new(scalar, len)) } diff --git a/vortex-array/src/arrays/constant/vtable/visitor.rs b/vortex-array/src/arrays/constant/vtable/visitor.rs index 28e613ba121..9f1e74553ca 100644 --- a/vortex-array/src/arrays/constant/vtable/visitor.rs +++ b/vortex-array/src/arrays/constant/vtable/visitor.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_buffer::ByteBufferMut; +use vortex_scalar::ScalarValue; use crate::ArrayBufferVisitor; use crate::ArrayChildVisitor; @@ -12,11 +13,7 @@ use crate::vtable::VisitorVTable; impl VisitorVTable for ConstantVTable { fn visit_buffers(array: &ConstantArray, visitor: &mut dyn ArrayBufferVisitor) { - let buffer = array - .scalar - .value() - .to_protobytes::() - .freeze(); + let buffer = ScalarValue::to_proto_bytes::(array.scalar.value()).freeze(); visitor.visit_buffer_handle("scalar", &BufferHandle::new_host(buffer)); } diff --git a/vortex-array/src/arrays/decimal/compute/min_max.rs b/vortex-array/src/arrays/decimal/compute/min_max.rs index da4483a3a69..68450cf7f9d 100644 --- a/vortex-array/src/arrays/decimal/compute/min_max.rs +++ b/vortex-array/src/arrays/decimal/compute/min_max.rs @@ -95,14 +95,16 @@ mod tests { let non_nullable_dtype = decimal.dtype().as_nonnullable(); let expected = MinMaxResult { - min: Scalar::new( + min: Scalar::try_new( non_nullable_dtype.clone(), - ScalarValue::from(DecimalValue::from(100i32)), - ), - max: Scalar::new( + Some(ScalarValue::from(DecimalValue::from(100i32))), + ) + .unwrap(), + max: Scalar::try_new( non_nullable_dtype, - ScalarValue::from(DecimalValue::from(200i32)), - ), + Some(ScalarValue::from(DecimalValue::from(200i32))), + ) + .unwrap(), }; assert_eq!(Some(expected), min_max) diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs index 5b5a192fc58..e091162653e 100644 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ b/vortex-array/src/arrays/decimal/compute/sum.rs @@ -13,7 +13,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_mask::Mask; -use vortex_scalar::DecimalScalar; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; @@ -38,8 +37,8 @@ impl SumKernel for DecimalVTable { .vortex_expect("must be decimal"); // Extract the initial value as a DecimalValue - let initial_decimal = DecimalScalar::try_from(accumulator) - .vortex_expect("must be a decimal") + let initial_decimal = accumulator + .as_decimal() .decimal_value() .vortex_expect("cannot be null"); @@ -129,11 +128,11 @@ mod tests { use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::Nullability; + use vortex_dtype::i256; use vortex_error::VortexExpect; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; - use vortex_scalar::i256; use crate::arrays::DecimalArray; use crate::compute::sum; @@ -149,10 +148,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(600i32)), - ); + Some(ScalarValue::from(DecimalValue::from(600i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -167,10 +167,11 @@ mod tests { let result = 
sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - ScalarValue::from(DecimalValue::from(800i32)), - ); + Some(ScalarValue::from(DecimalValue::from(800i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -185,10 +186,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(150i32)), - ); + Some(ScalarValue::from(DecimalValue::from(150i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -207,10 +209,11 @@ mod tests { // Should use i64 for accumulation since precision increases let expected_sum = near_max as i64 + 500 + 400; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(20, 2), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -228,17 +231,18 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); let expected_sum = (large_val as i128) * 4 + 1; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(29, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } #[test] fn test_sum_overflow_detection() { - use vortex_scalar::i256; + use vortex_dtype::i256; // Create values that will overflow when summed // Use maximum i128 values that will overflow when added @@ -254,10 +258,11 @@ mod tests { // Should use i256 for accumulation let expected_sum = i256::from_i128(max_val) + i256::from_i128(max_val) + i256::from_i128(max_val); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -276,10 +281,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); let expected_sum = (large_pos as i128) + (large_neg as i128) + (large_pos as i128) + 1000; - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(29, 3), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -295,10 +301,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); // Scale should be preserved, precision increased by 10 - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(16, 4), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(91346i32)), - ); + Some(ScalarValue::from(DecimalValue::from(91346i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -310,10 +317,11 @@ mod tests { let result = sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(13, 1), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(42i32)), - ); + Some(ScalarValue::from(DecimalValue::from(42i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -328,10 +336,11 @@ mod tests { let result = 
sum(decimal.as_ref()).unwrap(); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(14, 2), Nullability::Nullable), - ScalarValue::from(DecimalValue::from(300i32)), - ); + Some(ScalarValue::from(DecimalValue::from(300i32))), + ) + .unwrap(); assert_eq!(result, expected); } @@ -353,10 +362,11 @@ mod tests { // Should use i256 for accumulation since 9 * (i128::MAX / 10) fits in i128 but we increase precision let expected_sum = i256::from_i128(large_i128).wrapping_pow(1) * i256::from_i128(9); - let expected = Scalar::new( + let expected = Scalar::try_new( DType::Decimal(DecimalDType::new(48, 0), Nullability::NonNullable), - ScalarValue::from(DecimalValue::from(expected_sum)), - ); + Some(ScalarValue::from(DecimalValue::from(expected_sum))), + ) + .unwrap(); assert_eq!(result, expected); } diff --git a/vortex-array/src/arrays/decimal/utils.rs b/vortex-array/src/arrays/decimal/utils.rs index 00bd6a86fb3..1b148b130a9 100644 --- a/vortex-array/src/arrays/decimal/utils.rs +++ b/vortex-array/src/arrays/decimal/utils.rs @@ -3,9 +3,9 @@ use itertools::Itertools; use itertools::MinMaxResult; +use vortex_dtype::DecimalType; +use vortex_dtype::i256; use vortex_error::VortexExpect; -use vortex_scalar::DecimalType; -use vortex_scalar::i256; use crate::arrays::DecimalArray; use crate::vtable::ValidityHelper; diff --git a/vortex-array/src/arrays/decimal/vtable/array.rs b/vortex-array/src/arrays/decimal/vtable/array.rs index 23013e066ec..f5997dd0b95 100644 --- a/vortex-array/src/arrays/decimal/vtable/array.rs +++ b/vortex-array/src/arrays/decimal/vtable/array.rs @@ -4,7 +4,7 @@ use std::hash::Hash; use vortex_dtype::DType; -use vortex_scalar::DecimalType; +use vortex_dtype::DecimalType; use crate::Precision; use crate::arrays::DecimalArray; diff --git a/vortex-array/src/arrays/decimal/vtable/mod.rs b/vortex-array/src/arrays/decimal/vtable/mod.rs index 96683f051bd..70ac0c3fdad 100644 --- a/vortex-array/src/arrays/decimal/vtable/mod.rs +++ b/vortex-array/src/arrays/decimal/vtable/mod.rs @@ -4,13 +4,13 @@ use kernel::PARENT_KERNELS; use vortex_buffer::Alignment; use vortex_dtype::DType; +use vortex_dtype::DecimalType; use vortex_dtype::NativeDecimalType; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; -use vortex_scalar::DecimalType; use vortex_session::VortexSession; use crate::ArrayRef; diff --git a/vortex-array/src/arrays/dict/compute/fill_null.rs b/vortex-array/src/arrays/dict/compute/fill_null.rs index 4a7ab763d60..b676ddceda8 100644 --- a/vortex-array/src/arrays/dict/compute/fill_null.rs +++ b/vortex-array/src/arrays/dict/compute/fill_null.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_dtype::match_each_integer_ptype; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -21,8 +22,8 @@ use crate::register_kernel; impl FillNullKernel for DictVTable { fn fill_null(&self, array: &DictArray, fill_value: &Scalar) -> VortexResult { - // If the fill value exists in the dictionary, we can simply rewrite the null codes to - // point to the value. + // If the fill value already exists in the dictionary, we can simply rewrite the null codes + // to point to the value. 
let found_fill_values = compare( array.values(), ConstantArray::new(fill_value.clone(), array.values().len()).as_ref(), @@ -30,7 +31,10 @@ impl FillNullKernel for DictVTable { )? .to_bool(); - let Some(first_fill_value) = found_fill_values.to_bit_buffer().set_indices().next() else { + // We found the fill value already in the values at this given index. + let Some(existing_fill_value_index) = + found_fill_values.to_bit_buffer().set_indices().next() + else { // No fill values found, so we must canonicalize and fill_null. // TODO(ngates): compute kernels should all return Option to support this // fall back. @@ -38,20 +42,28 @@ impl FillNullKernel for DictVTable { }; // Now we rewrite the nullable codes to point at the fill value. + let codes = array.codes(); + + // Cast the index to the correct unsigned integer type matching the codes' ptype. + let codes_ptype = codes.dtype().as_ptype(); + + #[expect( + clippy::cast_possible_truncation, + reason = "The existing index must be representable by the existing ptype" + )] + let fill_scalar_value = match_each_integer_ptype!(codes_ptype, |P| { + ScalarValue::from(existing_fill_value_index as P) + }); + + // Fill nulls in both the codes and the values. Note that the precondition of this function + // states that the fill value is non-null, so we do not have to worry about the nullability. let codes = fill_null( - array.codes(), - &Scalar::new( - array - .codes() - .dtype() - .with_nullability(fill_value.dtype().nullability()), - ScalarValue::from(first_fill_value), - ), + codes, + &Scalar::try_new(codes.dtype().as_nonnullable(), Some(fill_scalar_value))?, )?; - // And fill nulls in the values let values = fill_null(array.values(), fill_value)?; - // SAFETY: invariants are still satisfied after patching nulls + // SAFETY: invariants are still satisfied after patching nulls. 
unsafe { Ok(DictArray::new_unchecked(codes, values) .set_all_values_referenced(array.has_all_values_referenced()) diff --git a/vortex-array/src/arrays/dict/compute/min_max.rs b/vortex-array/src/arrays/dict/compute/min_max.rs index 11851dbef0f..32a041784fc 100644 --- a/vortex-array/src/arrays/dict/compute/min_max.rs +++ b/vortex-array/src/arrays/dict/compute/min_max.rs @@ -49,16 +49,16 @@ mod tests { fn assert_min_max(array: &dyn Array, expected: Option<(i32, i32)>) { match (min_max(array).unwrap(), expected) { (Some(result), Some((expected_min, expected_max))) => { - assert_eq!(i32::try_from(result.min).unwrap(), expected_min); - assert_eq!(i32::try_from(result.max).unwrap(), expected_max); + assert_eq!(i32::try_from(&result.min).unwrap(), expected_min); + assert_eq!(i32::try_from(&result.max).unwrap(), expected_max); } (None, None) => {} (got, expected) => panic!( "min_max mismatch: expected {:?}, got {:?}", expected, got.as_ref().map(|r| ( - i32::try_from(r.min.clone()).ok(), - i32::try_from(r.max.clone()).ok() + i32::try_from(&r.min.clone()).ok(), + i32::try_from(&r.max.clone()).ok() )) ), } diff --git a/vortex-array/src/arrays/dict/take.rs b/vortex-array/src/arrays/dict/take.rs index 30367f0f4b5..46e67a1b2ae 100644 --- a/vortex-array/src/arrays/dict/take.rs +++ b/vortex-array/src/arrays/dict/take.rs @@ -156,7 +156,8 @@ pub(crate) fn propagate_take_stats( source .statistics() .get(stat) - .map(|v| (stat, v.map(|s| s.into_value()).into_inexact())) + .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose()) + .map(|sv| (stat, sv)) }) .collect::>(); st.combine_sets( diff --git a/vortex-array/src/arrays/masked/compute/take.rs b/vortex-array/src/arrays/masked/compute/take.rs index 162d6e5699d..c9026c43ea9 100644 --- a/vortex-array/src/arrays/masked/compute/take.rs +++ b/vortex-array/src/arrays/masked/compute/take.rs @@ -21,12 +21,14 @@ impl TakeExecute for MaskedVTable { _ctx: &mut ExecutionCtx, ) -> VortexResult> { let taken_child = if !indices.all_valid()? { - // This is safe because we'll mask out these positions in the validity - let filled_take = fill_null( - indices, - &Scalar::default_value(indices.dtype().clone().as_nonnullable()), - )?; - array.child.take(filled_take)?.to_canonical()?.into_array() + // This is safe because we'll mask out these positions in the validity. + let fill_scalar = Scalar::zero_value(indices.dtype()); + let filled_take_indices = fill_null(indices, &fill_scalar)?; + array + .child + .take(filled_take_indices)? + .to_canonical()? 
+ .into_array() } else { array .child diff --git a/vortex-array/src/arrays/null/compute/cast.rs b/vortex-array/src/arrays/null/compute/cast.rs index ceb35504e63..32bc1db0107 100644 --- a/vortex-array/src/arrays/null/compute/cast.rs +++ b/vortex-array/src/arrays/null/compute/cast.rs @@ -5,7 +5,6 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_scalar::Scalar; -use vortex_scalar::ScalarValue; use crate::ArrayRef; use crate::IntoArray; @@ -25,7 +24,7 @@ impl CastKernel for NullVTable { return Ok(Some(array.to_array())); } - let scalar = Scalar::new(dtype.clone(), ScalarValue::null()); + let scalar = Scalar::null(dtype.clone()); Ok(Some(ConstantArray::new(scalar, array.len()).into_array())) } } diff --git a/vortex-array/src/arrays/primitive/array/cast.rs b/vortex-array/src/arrays/primitive/array/cast.rs index 74051542c79..9594381b350 100644 --- a/vortex-array/src/arrays/primitive/array/cast.rs +++ b/vortex-array/src/arrays/primitive/array/cast.rs @@ -75,10 +75,18 @@ impl PrimitiveArray { // If we can't cast to i64, then leave the array as its original type. // It's too big to downcast anyway. - let Ok(min) = min_max.min.cast(&PType::I64.into()).and_then(i64::try_from) else { + let Ok(min) = min_max + .min + .cast(&PType::I64.into()) + .and_then(|s| i64::try_from(&s)) + else { return Ok(self.clone()); }; - let Ok(max) = min_max.max.cast(&PType::I64.into()).and_then(i64::try_from) else { + let Ok(max) = min_max + .max + .cast(&PType::I64.into()) + .and_then(|s| i64::try_from(&s)) + else { return Ok(self.clone()); }; diff --git a/vortex-array/src/arrays/primitive/compute/between.rs b/vortex-array/src/arrays/primitive/compute/between.rs index 844a41a5296..60417906058 100644 --- a/vortex-array/src/arrays/primitive/compute/between.rs +++ b/vortex-array/src/arrays/primitive/compute/between.rs @@ -41,8 +41,8 @@ impl BetweenKernel for PrimitiveVTable { Ok(Some(match_each_native_ptype!(arr.ptype(), |P| { between_impl::
<P>
( arr, - P::try_from(lower)?, - P::try_from(upper)?, + P::try_from(&lower)?, + P::try_from(&upper)?, nullability, options, ) diff --git a/vortex-array/src/arrays/primitive/compute/min_max.rs b/vortex-array/src/arrays/primitive/compute/min_max.rs index 1c3445c983a..7ac1cb584ca 100644 --- a/vortex-array/src/arrays/primitive/compute/min_max.rs +++ b/vortex-array/src/arrays/primitive/compute/min_max.rs @@ -87,8 +87,8 @@ mod tests { Validity::NonNullable, ); let min_max = min_max(array.as_ref()).unwrap().unwrap(); - assert_eq!(f32::try_from(min_max.min).unwrap(), -1.0); - assert_eq!(f32::try_from(min_max.max).unwrap(), 1.0); + assert_eq!(f32::try_from(&min_max.min).unwrap(), -1.0); + assert_eq!(f32::try_from(&min_max.max).unwrap(), 1.0); } #[test] @@ -98,7 +98,7 @@ mod tests { Validity::NonNullable, ); let min_max = min_max(array.as_ref()).unwrap().unwrap(); - assert_eq!(f32::try_from(min_max.min).unwrap(), f32::NEG_INFINITY); - assert_eq!(f32::try_from(min_max.max).unwrap(), f32::INFINITY); + assert_eq!(f32::try_from(&min_max.min).unwrap(), f32::NEG_INFINITY); + assert_eq!(f32::try_from(&min_max.max).unwrap(), f32::INFINITY); } } diff --git a/vortex-array/src/arrays/primitive/compute/sum.rs b/vortex-array/src/arrays/primitive/compute/sum.rs index bd3d861d4eb..2d39578a99b 100644 --- a/vortex-array/src/arrays/primitive/compute/sum.rs +++ b/vortex-array/src/arrays/primitive/compute/sum.rs @@ -7,6 +7,7 @@ use num_traits::Float; use num_traits::ToPrimitive; use vortex_buffer::BitBuffer; use vortex_dtype::NativePType; +use vortex_dtype::Nullability; use vortex_dtype::match_each_native_ptype; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -26,9 +27,27 @@ impl SumKernel for PrimitiveVTable { // All-valid match_each_native_ptype!( array.ptype(), - unsigned: |T| { sum_integer::<_, u64>(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() }, - signed: |T| { sum_integer::<_, i64>(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() }, - floating: |T| { Some(sum_float(array.as_slice::(), accumulator.as_primitive().as_::().vortex_expect("cannot be null"))).into() } + unsigned: |T| { + Scalar::from(sum_integer::<_, u64>( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) + }, + signed: |T| { + Scalar::from(sum_integer::<_, i64>( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) + }, + floating: |T| { + Scalar::primitive( + sum_float( + array.as_slice::(), + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + ), + Nullability::Nullable, + ) + } ) } AllOr::None => { @@ -40,13 +59,28 @@ impl SumKernel for PrimitiveVTable { match_each_native_ptype!( array.ptype(), unsigned: |T| { - sum_integer_with_validity::<_, u64>(array.as_slice::(), validity_mask, accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() + Scalar::from(sum_integer_with_validity::<_, u64>( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) }, signed: |T| { - sum_integer_with_validity::<_, i64>(array.as_slice::(), validity_mask, accumulator.as_primitive().as_::().vortex_expect("cannot be null")).into() + Scalar::from(sum_integer_with_validity::<_, i64>( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + )) }, floating: |T| { - Some(sum_float_with_validity(array.as_slice::(), validity_mask, 
accumulator.as_primitive().as_::().vortex_expect("cannot be null"))).into() + Scalar::primitive( + sum_float_with_validity( + array.as_slice::(), + validity_mask, + accumulator.as_primitive().as_::().vortex_expect("cannot be null"), + ), + Nullability::Nullable, + ) } ) } diff --git a/vortex-array/src/arrays/primitive/compute/take/mod.rs b/vortex-array/src/arrays/primitive/compute/take/mod.rs index 26457aa1af6..5584773804a 100644 --- a/vortex-array/src/arrays/primitive/compute/take/mod.rs +++ b/vortex-array/src/arrays/primitive/compute/take/mod.rs @@ -153,12 +153,12 @@ mod test { // position 3 is null assert_eq!( actual.scalar_at(1).vortex_expect("no fail"), - Scalar::null_typed::() + Scalar::null_native::() ); // the third index is null assert_eq!( actual.scalar_at(2).vortex_expect("no fail"), - Scalar::null_typed::() + Scalar::null_native::() ); } diff --git a/vortex-array/src/arrays/struct_/compute/take.rs b/vortex-array/src/arrays/struct_/compute/take.rs index f7e7a72867c..68141211416 100644 --- a/vortex-array/src/arrays/struct_/compute/take.rs +++ b/vortex-array/src/arrays/struct_/compute/take.rs @@ -1,7 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_dtype::Nullability; use vortex_error::VortexResult; use vortex_scalar::Scalar; @@ -13,6 +12,7 @@ use crate::arrays::StructArray; use crate::arrays::StructVTable; use crate::arrays::TakeExecute; use crate::compute; +use crate::validity::Validity; use crate::vtable::ValidityHelper; impl TakeExecute for StructVTable { @@ -21,11 +21,27 @@ impl TakeExecute for StructVTable { indices: &dyn Array, _ctx: &mut ExecutionCtx, ) -> VortexResult> { - // The validity is applied to the struct validity, - let inner_indices = &compute::fill_null( - indices, - &Scalar::default_value(indices.dtype().with_nullability(Nullability::NonNullable)), - )?; + // If the struct array is empty then the indices must be all null, otherwise it will access + // an out of bounds element. + if array.is_empty() { + return StructArray::try_new_with_dtype( + array.unmasked_fields().clone(), + array.struct_fields().clone(), + indices.len(), + Validity::AllInvalid, + ) + .map(StructArray::into_array) + .map(Some); + } + + // TODO(connor): This could be bad for cache locality... + + // Fill null indices with zero so they point at a valid row. + // Note that we strip nullability so that `Take::return_dtype` doesn't union nullable into + // each field's dtype (the struct-level validity already captures which rows are null). 
+ let fill_scalar = Scalar::zero_value(&indices.dtype().as_nonnullable()); + let inner_indices = &compute::fill_null(indices, &fill_scalar)?; + StructArray::try_new_with_dtype( array .unmasked_fields() diff --git a/vortex-array/src/arrays/varbin/array.rs b/vortex-array/src/arrays/varbin/array.rs index bb4d2250ecb..c28e59db115 100644 --- a/vortex-array/src/arrays/varbin/array.rs +++ b/vortex-array/src/arrays/varbin/array.rs @@ -355,10 +355,10 @@ impl VarBinArray { self.len() ); - self.offsets() + (&self + .offsets() .scalar_at(index) - .vortex_expect("offsets must support scalar_at") - .as_ref() + .vortex_expect("offsets must support scalar_at")) .try_into() .vortex_expect("Failed to convert offset to usize") } diff --git a/vortex-array/src/arrays/varbin/compute/min_max.rs b/vortex-array/src/arrays/varbin/compute/min_max.rs index 5cb603f380f..0b89c4f817a 100644 --- a/vortex-array/src/arrays/varbin/compute/min_max.rs +++ b/vortex-array/src/arrays/varbin/compute/min_max.rs @@ -88,17 +88,19 @@ mod tests { assert_eq!( min, - Scalar::new( + Scalar::try_new( Utf8(NonNullable), - BufferString::from("hello world".to_string()).into(), + Some(BufferString::from("hello world".to_string()).into()), ) + .unwrap() ); assert_eq!( max, - Scalar::new( + Scalar::try_new( Utf8(NonNullable), - BufferString::from("hello world this is a long string".to_string()).into() + Some(BufferString::from("hello world this is a long string".to_string()).into()), ) + .unwrap() ); } diff --git a/vortex-array/src/arrow/convert.rs b/vortex-array/src/arrow/convert.rs index f57bf91def8..12849414ddc 100644 --- a/vortex-array/src/arrow/convert.rs +++ b/vortex-array/src/arrow/convert.rs @@ -68,11 +68,11 @@ use vortex_dtype::IntegerPType; use vortex_dtype::NativePType; use vortex_dtype::PType; use vortex_dtype::datetime::TimeUnit; +use vortex_dtype::i256; use vortex_error::VortexExpect as _; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_panic; -use vortex_scalar::i256; use crate::ArrayRef; use crate::IntoArray; diff --git a/vortex-array/src/arrow/executor/decimal.rs b/vortex-array/src/arrow/executor/decimal.rs index f88362813e8..9c06292911f 100644 --- a/vortex-array/src/arrow/executor/decimal.rs +++ b/vortex-array/src/arrow/executor/decimal.rs @@ -67,7 +67,7 @@ fn to_arrow_decimal32(array: DecimalArray) -> VortexResult { }) .process_results(|iter| Buffer::from_trusted_len_iter(iter))?, DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i32() @@ -106,7 +106,7 @@ fn to_arrow_decimal64(array: DecimalArray) -> VortexResult { }) .process_results(|iter| Buffer::from_trusted_len_iter(iter))?, DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i64() @@ -140,7 +140,7 @@ fn to_arrow_decimal128(array: DecimalArray) -> VortexResult { } DecimalType::I128 => array.buffer::(), DecimalType::I256 => array - .buffer::() + .buffer::() .into_iter() .map(|x| { x.to_i128() @@ -176,7 +176,7 @@ fn to_arrow_decimal256(array: DecimalArray) -> VortexResult { array .buffer::() .into_iter() - .map(|x| vortex_scalar::i256::from_i128(x).into()), + .map(|x| vortex_dtype::i256::from_i128(x).into()), ), DecimalType::I256 => { Buffer::::from_byte_buffer(array.buffer_handle().clone().into_host_sync()) @@ -241,7 +241,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal128( #[case] _decimal_type: T, ) -> VortexResult<()> { @@ -268,7 +268,7 @@ mod tests { 
#[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal32(#[case] _decimal_type: T) -> VortexResult<()> { use arrow_array::Decimal32Array; @@ -295,7 +295,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal64(#[case] _decimal_type: T) -> VortexResult<()> { use arrow_array::Decimal64Array; @@ -322,7 +322,7 @@ mod tests { #[case(0i32)] #[case(0i64)] #[case(0i128)] - #[case(vortex_scalar::i256::ZERO)] + #[case(vortex_dtype::i256::ZERO)] fn test_to_arrow_decimal256( #[case] _decimal_type: T, ) -> VortexResult<()> { diff --git a/vortex-array/src/builders/bool.rs b/vortex-array/src/builders/bool.rs index 0af165eaa79..3298a6d147f 100644 --- a/vortex-array/src/builders/bool.rs +++ b/vortex-array/src/builders/bool.rs @@ -11,7 +11,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_mask::Mask; -use vortex_scalar::BoolScalar; use vortex_scalar::Scalar; use crate::Array; @@ -105,8 +104,7 @@ impl ArrayBuilder for BoolBuilder { scalar.dtype() ); - let bool_scalar = BoolScalar::try_from(scalar)?; - match bool_scalar.value() { + match scalar.as_bool().value() { Some(value) => self.append_value(value), None => self.append_null(), } diff --git a/vortex-array/src/builders/decimal.rs b/vortex-array/src/builders/decimal.rs index 24767753538..0c9331ca211 100644 --- a/vortex-array/src/builders/decimal.rs +++ b/vortex-array/src/builders/decimal.rs @@ -9,6 +9,7 @@ use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::NativeDecimalType; use vortex_dtype::Nullability; +use vortex_dtype::i256; use vortex_dtype::match_each_decimal_value; use vortex_dtype::match_each_decimal_value_type; use vortex_error::VortexExpect; @@ -19,7 +20,6 @@ use vortex_error::vortex_panic; use vortex_mask::Mask; use vortex_scalar::DecimalValue; use vortex_scalar::Scalar; -use vortex_scalar::i256; use crate::Array; use crate::ArrayRef; diff --git a/vortex-array/src/builders/extension.rs b/vortex-array/src/builders/extension.rs index 25c090392c3..4bb20998bb2 100644 --- a/vortex-array/src/builders/extension.rs +++ b/vortex-array/src/builders/extension.rs @@ -95,8 +95,7 @@ impl ArrayBuilder for ExtensionBuilder { scalar.dtype() ); - let ext_scalar = ExtScalar::try_from(scalar)?; - self.append_value(ext_scalar) + self.append_value(scalar.as_extension()) } unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) { diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index 462d3ef0ba4..8ec20ce4d25 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -210,8 +210,7 @@ impl ArrayBuilder for ListBuilder { scalar.dtype() ); - let list_scalar = ListScalar::try_from(scalar)?; - self.append_value(list_scalar) + self.append_value(scalar.as_list()) } unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) { diff --git a/vortex-array/src/builders/primitive.rs b/vortex-array/src/builders/primitive.rs index d942491555f..dd2017fb6de 100644 --- a/vortex-array/src/builders/primitive.rs +++ b/vortex-array/src/builders/primitive.rs @@ -12,7 +12,6 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_mask::Mask; -use vortex_scalar::PrimitiveScalar; use vortex_scalar::Scalar; use crate::Array; @@ -156,8 +155,7 @@ impl ArrayBuilder for PrimitiveBuilder { 
scalar.dtype() ); - let primitive_scalar = PrimitiveScalar::try_from(scalar)?; - match primitive_scalar.pvalue() { + match scalar.as_primitive().pvalue() { Some(pv) => self.append_value(pv.cast::()), None => self.append_null(), } diff --git a/vortex-array/src/builders/struct_.rs b/vortex-array/src/builders/struct_.rs index ff4d5f527cf..0df28002c20 100644 --- a/vortex-array/src/builders/struct_.rs +++ b/vortex-array/src/builders/struct_.rs @@ -73,7 +73,7 @@ impl StructBuilder { ); } - if let Some(fields) = struct_scalar.fields() { + if let Some(fields) = struct_scalar.fields_iter() { for (builder, field) in self.builders.iter_mut().zip_eq(fields) { builder.append_scalar(&field)?; } @@ -162,8 +162,7 @@ impl ArrayBuilder for StructBuilder { scalar.dtype() ); - let struct_scalar = StructScalar::try_from(scalar)?; - self.append_value(struct_scalar) + self.append_value(scalar.as_struct()) } unsafe fn extend_from_array_unchecked(&mut self, array: &dyn Array) { diff --git a/vortex-array/src/builders/tests.rs b/vortex-array/src/builders/tests.rs index c6cb8b30e91..31b785f575b 100644 --- a/vortex-array/src/builders/tests.rs +++ b/vortex-array/src/builders/tests.rs @@ -78,7 +78,7 @@ fn test_append_zeros_matches_default_value(#[case] dtype: DType) { // Builder 2: Manually append default values. let mut builder_manual = builder_with_capacity(&dtype, num_elements); - let default_scalar = Scalar::default_value(dtype.clone()); + let default_scalar = Scalar::zero_value(&dtype); for _ in 0..num_elements { builder_manual.append_scalar(&default_scalar).unwrap(); } @@ -198,7 +198,7 @@ fn test_append_defaults_behavior(#[case] dtype: DType, #[case] should_be_null: b i ); // For non-nullable, it should match the default value. - let expected = Scalar::default_value(dtype.clone()); + let expected = Scalar::default_value(&dtype); // Skip list comparison due to known bug. 
if !matches!(dtype, DType::List(..)) { assert_eq!( @@ -359,7 +359,7 @@ fn test_to_canonical_struct() { ); compare_to_canonical_methods(&dtype, |builder| { for _ in 0..3 { - let value = Scalar::default_value(dtype.clone()); + let value = Scalar::default_value(&dtype); builder.append_scalar(&value).unwrap(); } }); @@ -395,7 +395,7 @@ fn test_to_canonical_decimal() { let dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable); compare_to_canonical_methods(&dtype, |builder| { for _ in 0..5 { - let value = Scalar::default_value(dtype.clone()); + let value = Scalar::default_value(&dtype); builder.append_scalar(&value).unwrap(); } }); @@ -592,7 +592,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { Scalar::primitive((i + j) as f64, *n) } DType::Utf8(n) => Scalar::utf8(format!("field_{}", i + j), *n), - _ => Scalar::default_value(field_dtype), + _ => Scalar::default_value(&field_dtype), } }) .collect(); @@ -605,7 +605,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { DType::Primitive(PType::I32, n) => { Scalar::primitive(j.min(i32::MAX as usize) as i32, *n) } - _ => Scalar::default_value(element_dtype.as_ref().clone()), + _ => Scalar::default_value(element_dtype.as_ref()), }) .collect(); Scalar::list(element_dtype.clone(), elements, *n) @@ -617,7 +617,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { DType::Primitive(PType::I32, n) => { Scalar::primitive((i as i32).saturating_add(j as i32), *n) } - _ => Scalar::default_value(element_dtype.as_ref().clone()), + _ => Scalar::default_value(element_dtype.as_ref()), }) .collect(); Scalar::fixed_size_list(element_dtype.clone(), elements, *n) @@ -626,7 +626,7 @@ fn create_test_scalars_for_dtype(dtype: &DType, count: usize) -> Vec { // Create extension scalars with storage values. 
let storage_scalar = match ext_dtype.storage_dtype() { DType::Primitive(PType::I64, n) => Scalar::primitive(i as i64, *n), - _ => Scalar::default_value(ext_dtype.storage_dtype().clone()), + _ => Scalar::default_value(ext_dtype.storage_dtype()), }; Scalar::extension_ref(ext_dtype.clone(), storage_scalar) } diff --git a/vortex-array/src/builders/varbinview.rs b/vortex-array/src/builders/varbinview.rs index 4525da8ec92..e8317f730f8 100644 --- a/vortex-array/src/builders/varbinview.rs +++ b/vortex-array/src/builders/varbinview.rs @@ -16,9 +16,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_mask::Mask; -use vortex_scalar::BinaryScalar; use vortex_scalar::Scalar; -use vortex_scalar::Utf8Scalar; use vortex_utils::aliases::hash_map::Entry; use vortex_utils::aliases::hash_map::HashMap; use vortex_vector::binaryview::BinaryView; @@ -252,20 +250,14 @@ impl ArrayBuilder for VarBinViewBuilder { ); match self.dtype() { - DType::Utf8(_) => { - let utf8_scalar = Utf8Scalar::try_from(scalar)?; - match utf8_scalar.value() { - Some(value) => self.append_value(value), - None => self.append_null(), - } - } - DType::Binary(_) => { - let binary_scalar = BinaryScalar::try_from(scalar)?; - match binary_scalar.value() { - Some(value) => self.append_value(value), - None => self.append_null(), - } - } + DType::Utf8(_) => match scalar.as_utf8().value() { + Some(value) => self.append_value(value), + None => self.append_null(), + }, + DType::Binary(_) => match scalar.as_binary().value() { + Some(value) => self.append_value(value), + None => self.append_null(), + }, _ => vortex_bail!( "VarBinViewBuilder can only handle Utf8 or Binary scalars, got {:?}", scalar.dtype() @@ -1028,7 +1020,13 @@ mod tests { assert_eq!(array.len(), 1); // Verify the value was stored correctly - let retrieved = array.scalar_at(0).unwrap().as_binary().value().unwrap(); + let retrieved = array + .scalar_at(0) + .unwrap() + .as_binary() + .value() + .cloned() + .unwrap(); assert_eq!(retrieved.len(), 8192); assert_eq!(retrieved.as_slice(), &large_value); } diff --git a/vortex-array/src/compute/is_constant.rs b/vortex-array/src/compute/is_constant.rs index 31bf507c907..5df78d83471 100644 --- a/vortex-array/src/compute/is_constant.rs +++ b/vortex-array/src/compute/is_constant.rs @@ -85,7 +85,8 @@ impl ComputeFnVTable for IsConstant { // We try and rely on some easy-to-get stats if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsConstant) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let value = is_constant_impl(array, options, kernels)?; @@ -105,7 +106,8 @@ impl ComputeFnVTable for IsConstant { .set(Stat::IsConstant, Precision::Exact(value.into())); } - Ok(Scalar::from(value).into()) + let scalar: Scalar = value.into(); + Ok(scalar.into()) } fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult { @@ -227,7 +229,8 @@ impl Kernel for IsConstantKernelAdapter { return Ok(None); }; let is_constant = V::is_constant(&self.0, array, args.options)?; - Ok(Some(Scalar::from(is_constant).into())) + let scalar: Scalar = is_constant.into(); + Ok(Some(scalar.into())) } } diff --git a/vortex-array/src/compute/is_sorted.rs b/vortex-array/src/compute/is_sorted.rs index 2a8818b5d99..9189fa53b96 100644 --- a/vortex-array/src/compute/is_sorted.rs +++ b/vortex-array/src/compute/is_sorted.rs @@ -69,14 +69,16 @@ impl ComputeFnVTable for IsSorted { // We currently don't support sorting struct arrays. 
if array.dtype().is_struct() { - return Ok(Scalar::from(Some(false)).into()); + let scalar: Scalar = Some(false).into(); + return Ok(scalar.into()); } let is_sorted = if strict { if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsStrictSorted) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let is_strict_sorted = is_sorted_impl(array, kernels, true)?; @@ -95,7 +97,8 @@ impl ComputeFnVTable for IsSorted { } else { if let Some(Precision::Exact(value)) = array.statistics().get_as::(Stat::IsSorted) { - return Ok(Scalar::from(Some(value)).into()); + let scalar: Scalar = Some(value).into(); + return Ok(scalar.into()); } let is_sorted = is_sorted_impl(array, kernels, false)?; @@ -113,7 +116,8 @@ impl ComputeFnVTable for IsSorted { is_sorted }; - Ok(Scalar::from(is_sorted).into()) + let scalar: Scalar = is_sorted.into(); + Ok(scalar.into()) } fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult { @@ -198,7 +202,8 @@ impl Kernel for IsSortedKernelAdapter { V::is_sorted(&self.0, array)? }; - Ok(Some(Scalar::from(is_sorted).into())) + let scalar: Scalar = is_sorted.into(); + Ok(Some(scalar.into())) } } diff --git a/vortex-array/src/compute/min_max.rs b/vortex-array/src/compute/min_max.rs index ced0fcfeadf..931c2880cf7 100644 --- a/vortex-array/src/compute/min_max.rs +++ b/vortex-array/src/compute/min_max.rs @@ -102,13 +102,17 @@ impl ComputeFnVTable for MinMax { array.encoding_id() ); - // Update the stats set with the computed min/max - array - .statistics() - .set(Stat::Min, Precision::Exact(min.value().clone())); - array - .statistics() - .set(Stat::Max, Precision::Exact(max.value().clone())); + // Update the stats set with the computed min/max. + if let Some(min_value) = min.value() { + array + .statistics() + .set(Stat::Min, Precision::Exact(min_value.clone())); + } + if let Some(max_value) = max.value() { + array + .statistics() + .set(Stat::Max, Precision::Exact(max_value.clone())); + } // Return the min/max as a struct scalar Ok(Scalar::struct_(return_dtype, vec![min, max]).into()) diff --git a/vortex-array/src/compute/sum.rs b/vortex-array/src/compute/sum.rs index d1a2f3eff22..9d36a009408 100644 --- a/vortex-array/src/compute/sum.rs +++ b/vortex-array/src/compute/sum.rs @@ -66,7 +66,7 @@ pub fn sum(array: &dyn Array) -> VortexResult { let sum_dtype = Stat::Sum .dtype(array.dtype()) .ok_or_else(|| vortex_err!("Sum not supported for dtype: {}", array.dtype()))?; - let zero = Scalar::zero_value(sum_dtype); + let zero = Scalar::zero_value(&sum_dtype); sum_with_accumulator(array, &zero) } @@ -113,16 +113,17 @@ impl ComputeFnVTable for Sum { ); // Short-circuit using array statistics. - if let Some(Precision::Exact(sum)) = array.statistics().get(Stat::Sum) { - // For floats only use stats if accumulator is zero. otherwise we might have numerical stability issues. - match sum_dtype { + if let Some(Precision::Exact(sum_scalar)) = array.statistics().get(Stat::Sum) { + // For floats only use stats if accumulator is zero. otherwise we might have numerical + // stability issues. 
+ match &sum_dtype { DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() { - return Ok(sum.into()); + if p.is_float() && accumulator.is_zero() == Some(true) { + return Ok(sum_scalar.into()); } else if p.is_int() { let sum_from_stat = accumulator .as_primitive() - .checked_add(&sum.as_primitive()) + .checked_add(&sum_scalar.as_primitive()) .map(Scalar::from); return Ok(sum_from_stat .unwrap_or_else(|| Scalar::null(sum_dtype)) @@ -132,7 +133,7 @@ impl ComputeFnVTable for Sum { DType::Decimal(..) => { let sum_from_stat = accumulator .as_decimal() - .checked_binary_numeric(&sum.as_decimal(), NumericOperator::Add) + .checked_binary_numeric(&sum_scalar.as_decimal(), NumericOperator::Add) .map(Scalar::from); return Ok(sum_from_stat .unwrap_or_else(|| Scalar::null(sum_dtype)) @@ -147,30 +148,29 @@ impl ComputeFnVTable for Sum { // Update the statistics with the computed sum. Stored statistic shouldn't include the accumulator. match sum_dtype { DType::Primitive(p, _) => { - if p.is_float() && accumulator.is_zero() { + if p.is_float() + && accumulator.is_zero() == Some(true) + && let Some(sum_value) = sum_scalar.value().cloned() + { array .statistics() - .set(Stat::Sum, Precision::Exact(sum_scalar.value().clone())); + .set(Stat::Sum, Precision::Exact(sum_value)); } else if p.is_int() && let Some(less_accumulator) = sum_scalar .as_primitive() .checked_sub(&accumulator.as_primitive()) + && let Some(val) = Scalar::from(less_accumulator).into_value() { - array.statistics().set( - Stat::Sum, - Precision::Exact(Scalar::from(less_accumulator).value().clone()), - ); + array.statistics().set(Stat::Sum, Precision::Exact(val)); } } DType::Decimal(..) => { if let Some(less_accumulator) = sum_scalar .as_decimal() .checked_binary_numeric(&accumulator.as_decimal(), NumericOperator::Sub) + && let Some(val) = Scalar::from(less_accumulator).into_value() { - array.statistics().set( - Stat::Sum, - Precision::Exact(Scalar::from(less_accumulator).value().clone()), - ) + array.statistics().set(Stat::Sum, Precision::Exact(val)); } } _ => unreachable!("Sum will always be a decimal or a primitive dtype"), diff --git a/vortex-array/src/expr/exprs/dynamic.rs b/vortex-array/src/expr/exprs/dynamic.rs index a1717483a0f..7636e79098e 100644 --- a/vortex-array/src/expr/exprs/dynamic.rs +++ b/vortex-array/src/expr/exprs/dynamic.rs @@ -107,10 +107,11 @@ impl VTable for DynamicComparison { let ret_dtype = DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability()); - Ok( - ConstantArray::new(Scalar::new(ret_dtype, data.default.into()), args.row_count) - .into_array(), + Ok(ConstantArray::new( + Scalar::try_new(ret_dtype, Some(data.default.into()))?, + args.row_count, ) + .into_array()) } fn stat_falsification( @@ -193,7 +194,10 @@ pub struct DynamicComparisonExpr { impl DynamicComparisonExpr { pub fn scalar(&self) -> Option { - (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v)) + (self.rhs.value)().map(|v| { + Scalar::try_new(self.rhs.dtype.clone(), Some(v)) + .vortex_expect("`DynamicComparisonExpr` was invalid") + }) } } @@ -237,7 +241,9 @@ struct Rhs { impl Rhs { pub fn scalar(&self) -> Option { - (self.value)().map(|v| Scalar::new(self.dtype.clone(), v)) + (self.value)().map(|v| { + Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid") + }) } } @@ -283,7 +289,12 @@ impl DynamicExprUpdates { let exprs = visitor.0.into_boxed_slice(); let prev_versions = exprs .iter() - .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v))) + 
.map(|expr| { + (expr.rhs.value)().map(|v| { + Scalar::try_new(expr.rhs.dtype.clone(), Some(v)) + .vortex_expect("`DynamicExprUpdates` was invalid") + }) + }) .collect(); Some(Self { diff --git a/vortex-array/src/expr/exprs/like.rs b/vortex-array/src/expr/exprs/like.rs index 8ac15b9bce4..e4cc85a3314 100644 --- a/vortex-array/src/expr/exprs/like.rs +++ b/vortex-array/src/expr/exprs/like.rs @@ -159,7 +159,7 @@ impl VTable for Like { let src_min = src.stat_min(catalog)?; let src_max = src.stat_max(catalog)?; - match LikeVariant::from_str(&pat_str)? { + match LikeVariant::from_str(pat_str)? { LikeVariant::Exact(text) => { // col LIKE 'exact' ==> col.min > 'exact' || col.max < 'exact' Some(or(gt(src_min, lit(text)), lt(src_max, lit(text)))) diff --git a/vortex-array/src/expr/exprs/literal.rs b/vortex-array/src/expr/exprs/literal.rs index 1465ec23ddd..b610359b699 100644 --- a/vortex-array/src/expr/exprs/literal.rs +++ b/vortex-array/src/expr/exprs/literal.rs @@ -38,7 +38,7 @@ impl VTable for Literal { fn serialize(&self, instance: &Self::Options) -> VortexResult>> { Ok(Some( pb::LiteralOpts { - value: Some(instance.as_ref().into()), + value: Some(instance.into()), } .encode_to_vec(), )) diff --git a/vortex-array/src/expr/stats/precision.rs b/vortex-array/src/expr/stats/precision.rs index 3c2ed9332d3..7f3e13b3c48 100644 --- a/vortex-array/src/expr/stats/precision.rs +++ b/vortex-array/src/expr/stats/precision.rs @@ -6,6 +6,7 @@ use std::fmt::Display; use std::fmt::Formatter; use vortex_dtype::DType; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_scalar::Scalar; use vortex_scalar::ScalarValue; @@ -20,26 +21,14 @@ use crate::expr::stats::precision::Precision::Inexact; /// This is statistic specific, for max this will be an upper bound. Meaning that the actual max /// in an array is guaranteed to be less than or equal to the inexact value, but equal to the exact /// value. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Precision { Exact(T), Inexact(T), } -impl Clone for Precision -where - T: Clone, -{ - fn clone(&self) -> Self { - match self { - Exact(e) => Exact(e.clone()), - Inexact(ie) => Inexact(ie.clone()), - } - } -} - impl Precision> { - /// Transpose the `Option>` into `Option>`. + /// Transpose the `Precision>` into `Option>`. pub fn transpose(self) -> Option> { match self { Exact(Some(x)) => Some(Exact(x)), @@ -167,13 +156,22 @@ impl PartialEq for Precision { } impl Precision { + /// Convert this [`Precision`] into a [`Precision`] with the given + /// [`DType`]. pub fn into_scalar(self, dtype: DType) -> Precision { - self.map(|v| Scalar::new(dtype, v)) + self.map(|v| { + Scalar::try_new(dtype, Some(v)).vortex_expect("`Precision` was invalid") + }) } } impl Precision<&ScalarValue> { + /// Convert this [`Precision<&ScalarValue>`] into a [`Precision`] with the given + /// [`DType`]. 
pub fn into_scalar(self, dtype: DType) -> Precision { - self.map(|v| Scalar::new(dtype, v.clone())) + self.map(|v| { + Scalar::try_new(dtype, Some(v.clone())) + .vortex_expect("`Precision` was invalid") + }) } } diff --git a/vortex-array/src/serde.rs b/vortex-array/src/serde.rs index be064f9652e..373e174360e 100644 --- a/vortex-array/src/serde.rs +++ b/vortex-array/src/serde.rs @@ -22,7 +22,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; use vortex_flatbuffers::FlatBuffer; -use vortex_flatbuffers::ReadFlatBuffer; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; use vortex_flatbuffers::array::Compression; @@ -378,7 +377,7 @@ impl ArrayParts { // Populate statistics from the serialized array. if let Some(stats) = self.flatbuffer().stats() { let decoded_statistics = decoded.statistics(); - StatsSet::read_flatbuffer(&stats)? + StatsSet::from_flatbuffer(&stats, dtype)? .into_iter() .for_each(|(stat, val)| decoded_statistics.set(stat, val)); } diff --git a/vortex-array/src/stats/array.rs b/vortex-array/src/stats/array.rs index 9df19542f15..873d5ee635f 100644 --- a/vortex-array/src/stats/array.rs +++ b/vortex-array/src/stats/array.rs @@ -197,8 +197,10 @@ impl StatsSetRef<'_> { pub fn compute_all(&self, stats: &[Stat]) -> VortexResult { let mut stats_set = StatsSet::default(); for &stat in stats { - if let Some(s) = self.compute_stat(stat)? { - stats_set.set(stat, Precision::exact(s.into_value())) + if let Some(s) = self.compute_stat(stat)? + && let Some(value) = s.into_value() + { + stats_set.set(stat, Precision::exact(value)); } } Ok(stats_set) diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index d7f3b23172b..67f82fcb7c3 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -2,15 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use flatbuffers::FlatBufferBuilder; -use flatbuffers::Follow; use flatbuffers::WIPOffset; use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; -use vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_flatbuffers::ReadFlatBuffer; use vortex_flatbuffers::WriteFlatBuffer; use vortex_flatbuffers::array as fba; use vortex_scalar::ScalarValue; @@ -49,7 +46,11 @@ impl WriteFlatBuffer for StatsSet { } else { fba::Precision::Inexact }, - Some(fbb.create_vector(&min.into_inner().to_protobytes::>())), + Some( + fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some( + &min.into_inner(), + ))), + ), ) }) .unwrap_or_else(|| (fba::Precision::Inexact, None)); @@ -63,7 +64,11 @@ impl WriteFlatBuffer for StatsSet { } else { fba::Precision::Inexact }, - Some(fbb.create_vector(&max.into_inner().to_protobytes::>())), + Some( + fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some( + &max.into_inner(), + ))), + ), ) }) .unwrap_or_else(|| (fba::Precision::Inexact, None)); @@ -71,7 +76,7 @@ impl WriteFlatBuffer for StatsSet { let sum = self .get(Stat::Sum) .and_then(Precision::as_exact) - .map(|sum| fbb.create_vector(&sum.to_protobytes::>())); + .map(|sum| fbb.create_vector(&ScalarValue::to_proto_bytes::>(Some(&sum)))); let stat_args = &fba::ArrayStatsArgs { min, @@ -103,16 +108,17 @@ impl WriteFlatBuffer for StatsSet { } } -impl ReadFlatBuffer for StatsSet { - type Source<'a> = fba::ArrayStats<'a>; - type Error = VortexError; - - fn read_flatbuffer<'buf>( - fb: & as Follow<'buf>>::Inner, - ) -> Result { +impl StatsSet { + /// 
Creates a [`StatsSet`] from a flatbuffers array [`fba::ArrayStats<'a>`]. + pub fn from_flatbuffer<'a>( + fb: &fba::ArrayStats<'a>, + array_dtype: &DType, + ) -> VortexResult { let mut stats_set = StatsSet::default(); for stat in Stat::all() { + let stat_dtype = stat.dtype(array_dtype); + match stat { Stat::IsConstant => { if let Some(is_constant) = fb.is_constant() { @@ -133,8 +139,14 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Max => { - if let Some(max) = fb.max() { - let value = ScalarValue::from_protobytes(max.bytes())?; + if let Some(max) = fb.max() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(max.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + stats_set.set( Stat::Max, match fb.max_precision() { @@ -146,8 +158,14 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Min => { - if let Some(min) = fb.min() { - let value = ScalarValue::from_protobytes(min.bytes())?; + if let Some(min) = fb.min() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(min.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + stats_set.set( Stat::Min, match fb.min_precision() { @@ -172,11 +190,15 @@ impl ReadFlatBuffer for StatsSet { } } Stat::Sum => { - if let Some(sum) = fb.sum() { - stats_set.set( - Stat::Sum, - Precision::Exact(ScalarValue::from_protobytes(sum.bytes())?), - ); + if let Some(sum) = fb.sum() + && let Some(stat_dtype) = stat_dtype + { + let value = ScalarValue::from_proto_bytes(sum.bytes(), &stat_dtype)?; + let Some(value) = value else { + continue; + }; + + stats_set.set(Stat::Sum, Precision::Exact(value)); } } Stat::NaNCount => { diff --git a/vortex-array/src/stats/stats_set.rs b/vortex-array/src/stats/stats_set.rs index af977e6753c..aa0ed782b90 100644 --- a/vortex-array/src/stats/stats_set.rs +++ b/vortex-array/src/stats/stats_set.rs @@ -132,7 +132,11 @@ impl StatsSet { ) -> Option> { self.get(stat).map(|v| { v.map(|v| { - T::try_from(&Scalar::new(dtype.clone(), v)).unwrap_or_else(|err| { + T::try_from( + &Scalar::try_new(dtype.clone(), Some(v)) + .vortex_expect("failed to construct a scalar statistic"), + ) + .unwrap_or_else(|err| { vortex_panic!( err, "Failed to get stat {} as {}", @@ -225,11 +229,12 @@ impl StatsProvider for TypedStatsSetRef<'_, '_> { fn get(&self, stat: Stat) -> Option> { self.values.get(stat).map(|p| { p.map(|sv| { - Scalar::new( + Scalar::try_new( stat.dtype(self.dtype) .vortex_expect("Must have valid dtype if value is present"), - sv, + Some(sv), ) + .vortex_expect("failed to construct a scalar statistic") }) }) } @@ -260,11 +265,12 @@ impl StatsProvider for MutTypedStatsSetRef<'_, '_> { fn get(&self, stat: Stat) -> Option> { self.values.get(stat).map(|p| { p.map(|sv| { - Scalar::new( + Scalar::try_new( stat.dtype(self.dtype) .vortex_expect("Must have valid dtype if value is present"), - sv, + Some(sv), ) + .vortex_expect("failed to construct a scalar statistic") }) }) } @@ -356,10 +362,22 @@ impl MutTypedStatsSetRef<'_, '_> { vortex_err!("{:?} bounds ({m1:?}, {m2:?}) do not overlap", S::STAT) })?; if meet != m1 { - self.set(S::STAT, meet.into_value().map(Scalar::into_value)); + self.set( + S::STAT, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } - (None, Some(m)) => self.set(S::STAT, m.into_value().map(Scalar::into_value)), + (None, Some(m)) => self.set( + S::STAT, + m.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ), 
(Some(_), _) => (), (None, None) => self.clear(S::STAT), } @@ -400,7 +418,13 @@ impl MutTypedStatsSetRef<'_, '_> { (Some(m1), Some(m2)) => { let meet = m1.union(&m2).vortex_expect("can compare scalar"); if meet != m1 { - self.set(Stat::Min, meet.into_value().map(Scalar::into_value)); + self.set( + Stat::Min, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } _ => self.clear(Stat::Min), @@ -415,7 +439,13 @@ impl MutTypedStatsSetRef<'_, '_> { (Some(m1), Some(m2)) => { let meet = m1.union(&m2).vortex_expect("can compare scalar"); if meet != m1 { - self.set(Stat::Max, meet.into_value().map(Scalar::into_value)); + self.set( + Stat::Max, + meet.into_value().map(|s| { + s.into_value() + .vortex_expect("stat scalar value cannot be null") + }), + ); } } _ => self.clear(Stat::Max), @@ -432,19 +462,7 @@ impl MutTypedStatsSetRef<'_, '_> { if let Some(scalar_value) = m1.zip(m2).as_exact().and_then(|(s1, s2)| { s1.as_primitive() .checked_add(&s2.as_primitive()) - .map(|pscalar| { - pscalar - .pvalue() - .map(|pvalue| { - Scalar::primitive_value( - pvalue, - pscalar.ptype(), - pscalar.dtype().nullability(), - ) - .into_value() - }) - .unwrap_or_else(ScalarValue::null) - }) + .and_then(|pscalar| pscalar.pvalue().map(ScalarValue::Primitive)) }) { self.set(Stat::Sum, Precision::Exact(scalar_value)); } @@ -565,17 +583,18 @@ mod test { let first = iter.next().unwrap().clone(); assert_eq!(first.0, Stat::Max); assert_eq!( - first - .1 - .map(|f| i32::try_from(&Scalar::new(PType::I32.into(), f)).unwrap()), + first.1.map( + |f| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(f)).unwrap()).unwrap() + ), Precision::exact(100) ); let snd = iter.next().unwrap().clone(); assert_eq!(snd.0, Stat::Min); assert_eq!( - snd.1 - .map(|s| i32::try_from(&Scalar::new(PType::I32.into(), s)).unwrap()), - 42 + snd.1.map( + |s| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(s)).unwrap()).unwrap() + ), + Precision::exact(42) ); } @@ -592,14 +611,17 @@ mod test { let (stat, first) = set.next().unwrap(); assert_eq!(stat, Stat::Max); assert_eq!( - first.map(|f| i32::try_from(&Scalar::new(PType::I32.into(), f)).unwrap()), + first.map( + |f| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(f)).unwrap()).unwrap() + ), Precision::exact(100) ); let snd = set.next().unwrap(); assert_eq!(snd.0, Stat::Min); assert_eq!( - snd.1 - .map(|s| i32::try_from(&Scalar::new(PType::I32.into(), s)).unwrap()), + snd.1.map( + |s| i32::try_from(&Scalar::try_new(PType::I32.into(), Some(s)).unwrap()).unwrap() + ), Precision::exact(42) ); } @@ -710,7 +732,9 @@ mod test { #[test] fn merge_into_scalar() { - let first = StatsSet::of(Stat::Sum, Precision::exact(42)).merge_ordered( + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). + let first = StatsSet::of(Stat::Sum, Precision::exact(42i64)).merge_ordered( &StatsSet::default(), &DType::Primitive(PType::I32, Nullability::NonNullable), ); @@ -720,8 +744,10 @@ mod test { #[test] fn merge_from_scalar() { + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). 
let first = StatsSet::default().merge_ordered( - &StatsSet::of(Stat::Sum, Precision::exact(42)), + &StatsSet::of(Stat::Sum, Precision::exact(42i64)), &DType::Primitive(PType::I32, Nullability::NonNullable), ); let first_ref = first.as_typed_ref(&DType::Primitive(PType::I32, Nullability::NonNullable)); @@ -730,14 +756,16 @@ mod test { #[test] fn merge_scalars() { - let first = StatsSet::of(Stat::Sum, Precision::exact(37)).merge_ordered( - &StatsSet::of(Stat::Sum, Precision::exact(42)), + // Sum stats for primitive types are always the 64-bit version (i64 for signed, u64 + // for unsigned, f64 for floats). + let first = StatsSet::of(Stat::Sum, Precision::exact(37i64)).merge_ordered( + &StatsSet::of(Stat::Sum, Precision::exact(42i64)), &DType::Primitive(PType::I32, Nullability::NonNullable), ); let first_ref = first.as_typed_ref(&DType::Primitive(PType::I32, Nullability::NonNullable)); assert_eq!( - first_ref.get_as::(Stat::Sum), - Some(Precision::exact(79usize)) + first_ref.get_as::(Stat::Sum), + Some(Precision::exact(79i64)) ); } diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 1ee704ff777..7a3e248dadf 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -122,7 +122,7 @@ impl PrimitiveTyped<'_> { .scalar_at(idx)? .as_primitive() .pvalue() - .unwrap_or_else(|| PValue::zero(self.ptype()))) + .unwrap_or_else(|| PValue::zero(&self.ptype()))) } } diff --git a/vortex-btrblocks/src/compressor/decimal.rs b/vortex-btrblocks/src/compressor/decimal.rs index 5170405d10c..4a3f6e5475a 100644 --- a/vortex-btrblocks/src/compressor/decimal.rs +++ b/vortex-btrblocks/src/compressor/decimal.rs @@ -8,8 +8,8 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::narrowed_decimal; use vortex_array::vtable::ValidityHelper; use vortex_decimal_byte_parts::DecimalBytePartsArray; +use vortex_dtype::DecimalType; use vortex_error::VortexResult; -use vortex_scalar::DecimalType; use crate::BtrBlocksCompressor; use crate::CanonicalCompressor; diff --git a/vortex-buffer/src/string.rs b/vortex-buffer/src/string.rs index 182b2fe6989..bcaf8d7c3d2 100644 --- a/vortex-buffer/src/string.rs +++ b/vortex-buffer/src/string.rs @@ -84,10 +84,29 @@ impl TryFrom for BufferString { let err = simdutf8::compat::from_utf8(value.as_ref()).unwrap_err(); vortex_err!("invalid utf-8: {err}") })?; + Ok(Self(value)) } } +impl TryFrom<&[u8]> for BufferString { + type Error = VortexError; + + fn try_from(value: &[u8]) -> Result { + simdutf8::basic::from_utf8(value).map_err(|_| { + #[expect( + clippy::unwrap_used, + reason = "unwrap is intentional - the error was already detected" + )] + // run validation using `compat` package to get more detailed error message + let err = simdutf8::compat::from_utf8(value).unwrap_err(); + vortex_err!("invalid utf-8: {err}") + })?; + + Ok(Self(ByteBuffer::from(value.to_vec()))) + } +} + impl Deref for BufferString { type Target = str; diff --git a/vortex-datafusion/src/convert/scalars.rs b/vortex-datafusion/src/convert/scalars.rs index 5031614b0b7..7ac9c377f96 100644 --- a/vortex-datafusion/src/convert/scalars.rs +++ b/vortex-datafusion/src/convert/scalars.rs @@ -13,11 +13,12 @@ use vortex::dtype::datetime::AnyTemporal; use vortex::dtype::datetime::TemporalMetadata; use vortex::dtype::datetime::TimeUnit; use vortex::dtype::half::f16; +use vortex::dtype::i256; +use vortex::error::VortexExpect; use vortex::error::VortexResult; use vortex::error::vortex_bail; use vortex::scalar::DecimalValue; use vortex::scalar::Scalar; -use vortex::scalar::i256; 
use crate::convert::FromDataFusion; use crate::convert::TryToDataFusion; @@ -101,12 +102,13 @@ impl TryToDataFusion for Scalar { } } // SAFETY: By construction Utf8 scalar values are utf8 - DType::Utf8(_) => ScalarValue::Utf8(self.as_utf8().value().map(|s| unsafe { + DType::Utf8(_) => ScalarValue::Utf8(self.as_utf8().value().cloned().map(|s| unsafe { String::from_utf8_unchecked(Vec::::from(s.into_inner().into_inner())) })), DType::Binary(_) => ScalarValue::Binary( self.as_binary() .value() + .cloned() .map(|b| Vec::::from(b.into_inner())), ), DType::Struct(..) => todo!("struct scalar conversion"), @@ -217,11 +219,8 @@ impl FromDataFusion for Scalar { | ScalarValue::Time32Second(v) | ScalarValue::Time32Millisecond(v) => { let dtype = DType::from_arrow((&value.data_type(), Nullability::Nullable)); - Scalar::new( - dtype, - v.map(vortex::scalar::ScalarValue::from) - .unwrap_or_else(vortex::scalar::ScalarValue::null), - ) + Scalar::try_new(dtype, v.map(vortex::scalar::ScalarValue::from)) + .vortex_expect("unable to create a time `Scalar`") } ScalarValue::Date64(v) | ScalarValue::Time64Microsecond(v) @@ -231,11 +230,8 @@ impl FromDataFusion for Scalar { | ScalarValue::TimestampMicrosecond(v, _) | ScalarValue::TimestampNanosecond(v, _) => { let dtype = DType::from_arrow((&value.data_type(), Nullability::Nullable)); - Scalar::new( - dtype, - v.map(vortex::scalar::ScalarValue::from) - .unwrap_or_else(vortex::scalar::ScalarValue::null), - ) + Scalar::try_new(dtype, v.map(vortex::scalar::ScalarValue::from)) + .vortex_expect("unable to create a time `Scalar`") } ScalarValue::Decimal32(decimal, precision, scale) => { let decimal_dtype = DecimalDType::new(*precision, *scale); @@ -305,9 +301,9 @@ mod tests { use vortex::dtype::DecimalDType; use vortex::dtype::Nullability; use vortex::dtype::PType; + use vortex::dtype::i256; use vortex::scalar::DecimalValue; use vortex::scalar::Scalar; - use vortex::scalar::i256; use super::*; @@ -684,7 +680,13 @@ mod tests { #[case::fixed_size_binary(ScalarValue::FixedSizeBinary(5, Some(vec![1u8, 2, 3, 4, 5])))] fn test_binary_variants(#[case] variant: ScalarValue) { let result = Scalar::from_df(&variant); - let result_bytes: Vec = result.as_binary().value().unwrap().into_inner().into(); + let result_bytes: Vec = result + .as_binary() + .value() + .cloned() + .unwrap() + .into_inner() + .into(); assert_eq!(result_bytes, vec![1u8, 2, 3, 4, 5]); } } diff --git a/vortex-datafusion/src/persistent/cache.rs b/vortex-datafusion/src/persistent/cache.rs index b28cf972c2b..4bbd1e79f56 100644 --- a/vortex-datafusion/src/persistent/cache.rs +++ b/vortex-datafusion/src/persistent/cache.rs @@ -56,8 +56,9 @@ fn estimate_footer_size(footer: &Footer) -> usize { let segments_size = footer.segment_map().len() * size_of::(); let stats_size = footer .statistics() - .map(|stats| { - stats + .map(|file_statistics| { + file_statistics + .stats_sets() .iter() .map(|s| { s.iter().count() * (size_of::() + size_of::>()) diff --git a/vortex-datafusion/src/persistent/format.rs b/vortex-datafusion/src/persistent/format.rs index e598e2e401c..7a64d0048d4 100644 --- a/vortex-datafusion/src/persistent/format.rs +++ b/vortex-datafusion/src/persistent/format.rs @@ -39,11 +39,9 @@ use futures::FutureExt; use futures::StreamExt as _; use futures::TryStreamExt as _; use futures::stream; -use itertools::Itertools; use object_store::ObjectMeta; use object_store::ObjectStore; use vortex::VortexSessionDefault; -use vortex::array::stats::StatsSet; use vortex::dtype::DType; use vortex::dtype::Nullability; use 
vortex::dtype::PType; @@ -373,81 +371,87 @@ impl FileFormat for VortexFormat { }); }; - let stats = table_schema - .fields() - .iter() - .map(|field| struct_dtype.find(field.name())) - .map(|idx| match idx { - None => StatsSet::default(), - Some(id) => file_stats[id].clone(), - }) - .collect_vec(); - - let total_byte_size = stats - .iter() - .map(|stats_set| { - stats_set - .get_as::(Stat::UncompressedSizeInBytes, &PType::U64.into()) - .unwrap_or_else(|| stats::Precision::inexact(0_usize)) - }) - .fold(stats::Precision::exact(0_usize), |acc, stats_set| { - acc.zip(stats_set).map(|(acc, stats_set)| acc + stats_set) - }); - - // Sum up the total byte size across all the columns. - let total_byte_size = total_byte_size.to_df(); - - let column_statistics = stats - .into_iter() - .zip(table_schema.fields().iter()) - .map(|(stats_set, field)| { - let null_count = stats_set.get_as::(Stat::NullCount, &PType::U64.into()); - let min = stats_set.get(Stat::Min).and_then(|n| { - n.map(|n| { - Scalar::new( + let mut sum_of_column_byte_sizes = stats::Precision::exact(0_usize); + let mut column_statistics = Vec::with_capacity(table_schema.fields().len()); + + for field in table_schema.fields().iter() { + // If the column does not exist, continue. This can happen if the schema has evolved + // but we have not yet updated the Vortex file. + let Some(col_idx) = struct_dtype.find(field.name()) else { + // The default sets all statistics to `Precision`. + column_statistics.push(ColumnStatistics::default()); + continue; + }; + let (stats_set, stats_dtype) = file_stats.get(col_idx); + + // Update the total size in bytes. + let column_size = stats_set + .get_as::(Stat::UncompressedSizeInBytes, &PType::U64.into()) + .unwrap_or_else(|| stats::Precision::inexact(0_usize)); + sum_of_column_byte_sizes = sum_of_column_byte_sizes + .zip(column_size) + .map(|(acc, size)| acc + size); + + // TODO(connor): There's a lot that can go wrong here, should probably handle this + // more gracefully... + // Find the min statistic. + let min = stats_set.get(Stat::Min).and_then(|pstat_val| { + pstat_val + .map(|stat_val| { + // Because of DataFusion's Schema evolution, it is possible that the + // type of the min/max stat has changed. Thus we construct the stat as + // the file datatype first and only then do we cast accordingly. + Scalar::try_new( Stat::Min - .dtype(&DType::from_arrow(field.as_ref())) + .dtype(stats_dtype) .vortex_expect("must have a valid dtype"), - n, + Some(stat_val), ) + .vortex_expect("`Stat::Min` somehow had an incompatible `DType`") + .cast(&DType::from_arrow(field.as_ref())) + .vortex_expect("Unable to cast to target type that DataFusion wants") .try_to_df() .ok() }) .transpose() - }); + }); - let max = stats_set.get(Stat::Max).and_then(|n| { - n.map(|n| { - Scalar::new( + // Find the max statistic. 
+ let max = stats_set.get(Stat::Max).and_then(|pstat_val| { + pstat_val + .map(|stat_val| { + Scalar::try_new( Stat::Max - .dtype(&DType::from_arrow(field.as_ref())) + .dtype(stats_dtype) .vortex_expect("must have a valid dtype"), - n, + Some(stat_val), ) + .vortex_expect("`Stat::Max` somehow had an incompatible `DType`") + .cast(&DType::from_arrow(field.as_ref())) + .vortex_expect("Unable to cast to target type that DataFusion wants") .try_to_df() .ok() }) .transpose() - }); - - ColumnStatistics { - null_count: null_count.to_df(), - max_value: max.to_df(), - min_value: min.to_df(), - sum_value: Precision::Absent, - distinct_count: stats_set - .get_as::( - Stat::IsConstant, - &DType::Bool(Nullability::NonNullable), - ) - .and_then(|is_constant| { - is_constant.as_exact().map(|_| Precision::Exact(1)) - }) - .unwrap_or(Precision::Absent), - byte_size: Precision::Absent, - } + }); + + let null_count = stats_set.get_as::(Stat::NullCount, &PType::U64.into()); + + column_statistics.push(ColumnStatistics { + null_count: null_count.to_df(), + min_value: min.to_df(), + max_value: max.to_df(), + sum_value: Precision::Absent, + distinct_count: stats_set + .get_as::(Stat::IsConstant, &DType::Bool(Nullability::NonNullable)) + .and_then(|is_constant| is_constant.as_exact().map(|_| Precision::Exact(1))) + .unwrap_or(Precision::Absent), + // TODO(connor): Is this correct? + byte_size: column_size.to_df(), }) - .collect::>(); + } + + let total_byte_size = sum_of_column_byte_sizes.to_df(); Ok(Statistics { num_rows: Precision::Exact( diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 07a54a2bcdd..1970788e622 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -252,7 +252,7 @@ impl DType { if let Primitive(ptype, _) = self { *ptype } else { - vortex_panic!("DType is not a primitive type") + vortex_panic!("DType {self} is not a primitive type") } } @@ -492,6 +492,15 @@ impl DType { ext } + /// Get the `ExtDTypeRef` if `self` is an `Extension` type, otherwise `None` + pub fn as_extension_opt(&self) -> Option<&ExtDTypeRef> { + if let Extension(ext) = self { + Some(ext) + } else { + None + } + } + /// Convenience method for creating a [`DType::List`]. pub fn list(dtype: impl Into, nullability: Nullability) -> Self { List(Arc::new(dtype.into()), nullability) diff --git a/vortex-duckdb/src/convert/scalar.rs b/vortex-duckdb/src/convert/scalar.rs index ac82f21ffd4..3f073d56c55 100644 --- a/vortex-duckdb/src/convert/scalar.rs +++ b/vortex-duckdb/src/convert/scalar.rs @@ -261,39 +261,57 @@ impl<'a> TryFrom> for Scalar { ExtractedValue::Blob(b) => Ok(Scalar::binary(b, Nullable)), ExtractedValue::Date(days) => Ok(Scalar::extension::( TimeUnit::Days, - Scalar::new(DType::Primitive(I32, Nullable), ScalarValue::from(days)), + Scalar::try_new( + DType::Primitive(I32, Nullable), + Some(ScalarValue::from(days)), + )?, )), ExtractedValue::Time(micros) => Ok(Scalar::extension::

(other, result_dtype, ptype, op)
-            },
-            floating: |P| {
-                let lhs = self.typed_value::<P>();
-                let rhs = other.typed_value::<P>();
-                let value_or_null = match (lhs, rhs) {
-                    (_, None) | (None, _) => None,
-                    (Some(lhs), Some(rhs)) => match op {
-                        NumericOperator::Add => Some(lhs + rhs),
-                        NumericOperator::Sub => Some(lhs - rhs),
-                        NumericOperator::RSub => Some(rhs - lhs),
-                        NumericOperator::Mul => Some(lhs * rhs),
-                        NumericOperator::Div => Some(lhs / rhs),
-                        NumericOperator::RDiv => Some(rhs / lhs),
-                    }
-                };
-                Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) })
-            }
-        )
-    }
-
-    fn checked_integral_numeric_operator<
-        P: NativePType
-            + TryFrom<PValue>
-            + CheckedSub
-            + CheckedAdd
-            + CheckedMul
-            + CheckedDiv,
-    >(
-        &self,
-        other: &PrimitiveScalar<'a>,
-        result_dtype: &'a DType,
-        ptype: PType,
-        op: NumericOperator,
-    ) -> Option<PrimitiveScalar<'a>>
-    where
-        PValue: From<P>,
-    {
-        let lhs = self.typed_value::<P>();
-        let rhs = other.typed_value::<P>
(); - let value_or_null_or_overflow = match (lhs, rhs) { - (_, None) | (None, _) => Some(None), - (Some(lhs), Some(rhs)) => match op { - NumericOperator::Add => lhs.checked_add(&rhs).map(Some), - NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some), - NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some), - NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some), - NumericOperator::Div => lhs.checked_div(&rhs).map(Some), - NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some), - }, - }; - - value_or_null_or_overflow.map(|value_or_null| Self { - dtype: result_dtype, - ptype, - pvalue: value_or_null.map(PValue::from), - }) - } -} - -#[cfg(test)] -mod tests { - use num_traits::CheckedSub; - use rstest::rstest; - use vortex_dtype::DType; - use vortex_dtype::Nullability; - use vortex_dtype::PType; - use vortex_error::VortexExpect; - - use crate::InnerScalarValue; - use crate::PValue; - use crate::PrimitiveScalar; - use crate::ScalarValue; - - #[test] - fn test_integer_subtract() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); - let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2); - let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::(); - assert_eq!(value_or_null_or_type_error.unwrap(), 1); - - assert_eq!((p_scalar1 - p_scalar2).as_::().unwrap(), 1); - } - - #[test] - #[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")] - fn test_integer_subtract_overflow() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MIN))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))), - ) - .unwrap(); - let _ = p_scalar1 - p_scalar2; - } - - #[test] - fn test_float_subtract() { - let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let p_scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.99f32))), - ) - .unwrap(); - let p_scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(1.0f32))), - ) - .unwrap(); - let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap(); - let value_or_null_or_type_error = pscalar_or_overflow.as_::(); - assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32); - - assert_eq!((p_scalar1 - p_scalar2).as_::().unwrap(), 0.99f32); - } - - #[test] - fn test_primitive_scalar_equality() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - let scalar3 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(43))), - ) - .unwrap(); - - assert_eq!(scalar1, scalar2); - assert_ne!(scalar1, scalar3); - } - - #[test] - fn test_primitive_scalar_partial_ord() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - 
&ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); - - assert!(scalar1 < scalar2); - assert!(scalar2 > scalar1); - assert_eq!( - scalar1.partial_cmp(&scalar1), - Some(std::cmp::Ordering::Equal) - ); - } - - #[test] - fn test_primitive_scalar_null_handling() { - let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let null_scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); - - assert_eq!(null_scalar.pvalue(), None); - assert_eq!(null_scalar.typed_value::(), None); - } - - #[test] - fn test_typed_value_correct_type() { - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))), - ) - .unwrap(); - - assert_eq!(scalar.typed_value::(), Some(3.5)); - } - - #[test] - #[should_panic(expected = "Attempting to read")] - fn test_typed_value_wrong_type() { - let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F64(3.5))), - ) - .unwrap(); - - let _ = scalar.typed_value::(); - } - - #[rstest] - #[case(PType::I8, 127i32, PType::I16, true)] - #[case(PType::I8, 127i32, PType::I32, true)] - #[case(PType::I8, 127i32, PType::I64, true)] - #[case(PType::U8, 255i32, PType::U16, true)] - #[case(PType::U8, 255i32, PType::U32, true)] - #[case(PType::I32, 42i32, PType::F32, true)] - #[case(PType::I32, 42i32, PType::F64, true)] - // Overflow cases - #[case(PType::I32, 300i32, PType::U8, false)] - #[case(PType::I32, -1i32, PType::U32, false)] - #[case(PType::I32, 256i32, PType::I8, false)] - #[case(PType::U16, 65535i32, PType::I8, false)] - fn test_primitive_cast( - #[case] source_type: PType, - #[case] source_value: i32, - #[case] target_type: PType, - #[case] should_succeed: bool, - ) { - let source_pvalue = match source_type { - PType::I8 => PValue::I8(i8::try_from(source_value).vortex_expect("cannot cast")), - PType::U8 => PValue::U8(u8::try_from(source_value).vortex_expect("cannot cast")), - PType::U16 => PValue::U16(u16::try_from(source_value).vortex_expect("cannot cast")), - PType::I32 => PValue::I32(source_value), - _ => unreachable!("Test case uses unexpected source type"), - }; - - let dtype = DType::Primitive(source_type, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(source_pvalue)), - ) - .unwrap(); - - let target_dtype = DType::Primitive(target_type, Nullability::NonNullable); - let result = scalar.cast(&target_dtype); - - if should_succeed { - assert!( - result.is_ok(), - "Cast from {:?} to {:?} should succeed", - source_type, - target_type - ); - } else { - assert!( - result.is_err(), - "Cast from {:?} to {:?} should fail due to overflow", - source_type, - target_type - ); - } - } - - #[test] - fn test_as_conversion_success() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - - assert_eq!(scalar.as_::(), Some(42i64)); - assert_eq!(scalar.as_::(), Some(42.0)); - } - - #[test] - fn test_as_conversion_overflow() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - 
&ScalarValue(InnerScalarValue::Primitive(PValue::I32(-1))), - ) - .unwrap(); - - // Converting -1 to u32 should fail - let result = scalar.as_opt::(); - assert!(result.is_none()); - } - - #[test] - fn test_as_conversion_null() { - let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); - - assert_eq!(scalar.as_::(), None); - assert_eq!(scalar.as_::(), None); - } - - #[test] - fn test_numeric_operator_swap() { - use crate::primitive::NumericOperator; - - assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add); - assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub); - assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub); - assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul); - assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv); - assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div); - } - - #[test] - fn test_checked_binary_numeric_add() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); - - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Add) - .unwrap(); - assert_eq!(result.typed_value::(), Some(30)); - } - - #[test] - fn test_checked_binary_numeric_overflow() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32::MAX))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(1))), - ) - .unwrap(); - - // Add should overflow and return None - let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add); - assert!(result.is_none()); - } - - #[test] - fn test_checked_binary_numeric_with_null() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::Nullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let null_scalar = - PrimitiveScalar::try_new(&dtype, &ScalarValue(InnerScalarValue::Null)).unwrap(); - - // Operation with null should return null - let result = scalar1 - .checked_binary_numeric(&null_scalar, NumericOperator::Add) - .unwrap(); - assert_eq!(result.pvalue(), None); - } - - #[test] - fn test_checked_binary_numeric_mul() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(5))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(6))), - ) - .unwrap(); - - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Mul) - .unwrap(); - assert_eq!(result.typed_value::(), Some(30)); - } - - #[test] - fn test_checked_binary_numeric_div() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - 
&ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); - - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Div) - .unwrap(); - assert_eq!(result.typed_value::(), Some(5)); - } - - #[test] - fn test_checked_binary_numeric_rdiv() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(4))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(20))), - ) - .unwrap(); - - // RDiv means right / left, so 20 / 4 = 5 - let result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::RDiv) - .unwrap(); - assert_eq!(result.typed_value::(), Some(5)); - } - - #[test] - fn test_checked_binary_numeric_div_by_zero() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(0))), - ) - .unwrap(); - - // Division by zero should return None for integers - let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div); - assert!(result.is_none()); - } - - #[test] - fn test_checked_binary_numeric_float_ops() { - use crate::primitive::NumericOperator; - - let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let scalar1 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(2.5))), - ) - .unwrap(); - - // Test all operations with floats - let add_result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Add) - .unwrap(); - assert_eq!(add_result.typed_value::(), Some(12.5)); - - let sub_result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Sub) - .unwrap(); - assert_eq!(sub_result.typed_value::(), Some(7.5)); - - let mul_result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Mul) - .unwrap(); - assert_eq!(mul_result.typed_value::(), Some(25.0)); - - let div_result = scalar1 - .checked_binary_numeric(&scalar2, NumericOperator::Div) - .unwrap(); - assert_eq!(div_result.typed_value::(), Some(4.0)); - } - - #[test] - fn test_from_primitive_or_f16() { - use vortex_dtype::half::f16; - - use crate::primitive::FromPrimitiveOrF16; - - // Test f16 to f32 conversion - let f16_val = f16::from_f32(3.5); - assert!(f32::from_f16(f16_val).is_some()); - - // Test f16 to f64 conversion - assert!(f64::from_f16(f16_val).is_some()); - - // Test PValue::F16(f16) to integer conversion (should fail) - assert!(i32::try_from(PValue::from(f16_val)).is_err()); - assert!(u32::try_from(PValue::from(f16_val)).is_err()); - } - - #[test] - fn test_partial_ord_different_types() { - let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable); - let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable); - - let scalar1 = PrimitiveScalar::try_new( - &dtype1, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(10))), - ) - .unwrap(); - let scalar2 = PrimitiveScalar::try_new( - 
&dtype2, - &ScalarValue(InnerScalarValue::Primitive(PValue::F32(10.0))), - ) - .unwrap(); - - // Different types should not be comparable - assert_eq!(scalar1.partial_cmp(&scalar2), None); - } - - #[test] - fn test_scalar_value_from_usize() { - let value: ScalarValue = 42usize.into(); - assert!(matches!( - value.0, - InnerScalarValue::Primitive(PValue::U64(42)) - )); - } - - #[test] - fn test_getters() { - let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let scalar = PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); - - assert_eq!(scalar.dtype(), &dtype); - assert_eq!(scalar.ptype(), PType::I32); - assert_eq!(scalar.pvalue(), Some(PValue::I32(42))); - } -} diff --git a/vortex-scalar/src/proto.rs b/vortex-scalar/src/proto.rs index b243005141f..1b41b0af85b 100644 --- a/vortex-scalar/src/proto.rs +++ b/vortex-scalar/src/proto.rs @@ -1,16 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use std::sync::Arc; +//! Protobuf serialization and deserialization for scalars. use num_traits::ToBytes; +use num_traits::ToPrimitive; +use prost::Message; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; +use vortex_dtype::PType; use vortex_dtype::half::f16; -use vortex_error::VortexError; +use vortex_dtype::i256; use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_proto::scalar as pb; use vortex_proto::scalar::ListValue; @@ -18,10 +23,13 @@ use vortex_proto::scalar::scalar_value::Kind; use vortex_session::VortexSession; use crate::DecimalValue; -use crate::InnerScalarValue; +use crate::PValue; use crate::Scalar; use crate::ScalarValue; -use crate::pvalue::PValue; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Serialize INTO proto. +//////////////////////////////////////////////////////////////////////////////////////////////////// impl From<&Scalar> for pb::Scalar { fn from(value: &Scalar) -> Self { @@ -31,22 +39,49 @@ impl From<&Scalar> for pb::Scalar { .try_into() .vortex_expect("Failed to convert DType to proto"), ), - value: Some((value.value()).into()), + value: Some(ScalarValue::to_proto(value.value())), + } + } +} + +impl ScalarValue { + /// Ideally, we would not have this function and instead implement this `From` implementation: + /// + /// ```ignore + /// impl From> for pb::ScalarValue { ... } + /// ``` + /// + /// However, we are not allowed to do this because of the Orphan rule (`Option` and + /// `pb::ScalarValue` are not types defined in this crate). So we must make this a method on + /// `vortex_scalar::ScalarValue` directly. + pub fn to_proto(this: Option<&Self>) -> pb::ScalarValue { + match this { + None => pb::ScalarValue { + kind: Some(Kind::NullValue(0)), + }, + Some(this) => pb::ScalarValue::from(this), } } + + /// Serialize an optional [`ScalarValue`] to protobuf bytes (handles null values). 
+ pub fn to_proto_bytes(value: Option<&ScalarValue>) -> B { + let proto = Self::to_proto(value); + let mut buf = B::default(); + proto + .encode(&mut buf) + .vortex_expect("Failed to encode scalar value"); + buf + } } impl From<&ScalarValue> for pb::ScalarValue { fn from(value: &ScalarValue) -> Self { match value { - ScalarValue(InnerScalarValue::Null) => pb::ScalarValue { - kind: Some(Kind::NullValue(0)), - }, - ScalarValue(InnerScalarValue::Bool(v)) => pb::ScalarValue { + ScalarValue::Bool(v) => pb::ScalarValue { kind: Some(Kind::BoolValue(*v)), }, - ScalarValue(InnerScalarValue::Primitive(v)) => v.into(), - ScalarValue(InnerScalarValue::Decimal(v)) => { + ScalarValue::Primitive(v) => pb::ScalarValue::from(v), + ScalarValue::Decimal(v) => { let inner_value = match v { DecimalValue::I8(v) => v.to_le_bytes().to_vec(), DecimalValue::I16(v) => v.to_le_bytes().to_vec(), @@ -60,16 +95,16 @@ impl From<&ScalarValue> for pb::ScalarValue { kind: Some(Kind::BytesValue(inner_value)), } } - ScalarValue(InnerScalarValue::Buffer(v)) => pb::ScalarValue { - kind: Some(Kind::BytesValue(v.as_slice().to_vec())), + ScalarValue::Utf8(v) => pb::ScalarValue { + kind: Some(Kind::StringValue(v.to_string())), }, - ScalarValue(InnerScalarValue::BufferString(v)) => pb::ScalarValue { - kind: Some(Kind::StringValue(v.as_str().to_string())), + ScalarValue::Binary(v) => pb::ScalarValue { + kind: Some(Kind::BytesValue(v.to_vec())), }, - ScalarValue(InnerScalarValue::List(v)) => { + ScalarValue::List(v) => { let mut values = Vec::with_capacity(v.len()); for elem in v.iter() { - values.push(pb::ScalarValue::from(elem)); + values.push(ScalarValue::to_proto(elem.as_ref())); } pb::ScalarValue { kind: Some(Kind::ListValue(ListValue { values })), @@ -119,8 +154,30 @@ impl From<&PValue> for pb::ScalarValue { } } +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Serialize FROM proto. +//////////////////////////////////////////////////////////////////////////////////////////////////// + impl Scalar { - /// Creates a Scalar from its protobuf representation. + /// Creates a [`Scalar`] from a [protobuf `ScalarValue`](pb::ScalarValue) representation. + /// + /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit + /// integers, and serializing _into_ protobuf loses that type information. + /// + /// # Errors + /// + /// Returns an error if type validation fails. + pub fn from_proto_value(value: &pb::ScalarValue, dtype: &DType) -> VortexResult { + let scalar_value = ScalarValue::from_proto(value, dtype)?; + + Scalar::try_new(dtype.clone(), scalar_value) + } + + /// Creates a [`Scalar`] from its [protobuf](pb::Scalar) representation. + /// + /// # Errors + /// + /// Returns an error if the protobuf is missing required fields or if type validation fails. 
pub fn from_proto(value: &pb::Scalar, session: &VortexSession) -> VortexResult { let dtype = DType::from_proto( value @@ -130,82 +187,238 @@ impl Scalar { session, )?; - let value = ScalarValue::try_from( - value - .value - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?, - )?; + let pb_scalar_value: &pb::ScalarValue = value + .value + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?; - Ok(Scalar::new(dtype, value)) + let value: Option = ScalarValue::from_proto(pb_scalar_value, &dtype)?; + + Scalar::try_new(dtype, value) } } -impl TryFrom<&pb::ScalarValue> for ScalarValue { - type Error = VortexError; +impl ScalarValue { + /// Deserialize a [`ScalarValue`] from protobuf bytes. + /// + /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit + /// integers, and serializing _into_ protobuf loses that type information. + /// + /// # Errors + /// + /// Returns an error if decoding or type validation fails. + pub fn from_proto_bytes(bytes: &[u8], dtype: &DType) -> VortexResult> { + let proto = pb::ScalarValue::decode(bytes)?; + Self::from_proto(&proto, dtype) + } - fn try_from(value: &pb::ScalarValue) -> Result { + /// Creates a [`ScalarValue`] from its [protobuf](pb::ScalarValue) representation. + /// + /// Note that we need to provide a [`DType`] since protobuf serialization only supports 64-bit + /// integers, and serializing _into_ protobuf loses that type information. + /// + /// # Errors + /// + /// Returns an error if the protobuf value cannot be converted to the given [`DType`]. + pub fn from_proto(value: &pb::ScalarValue, dtype: &DType) -> VortexResult> { let kind = value .kind .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?; - - match kind { - Kind::NullValue(_) => Ok(ScalarValue(InnerScalarValue::Null)), - Kind::BoolValue(v) => Ok(ScalarValue(InnerScalarValue::Bool(*v))), - Kind::Int64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::I64(*v)))), - Kind::Uint64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::U64(*v)))), - Kind::F16Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F16( - f16::from_bits(u16::try_from(*v).map_err(|_| { - vortex_err!("f16 bitwise representation has more than 16 bits: {}", v) - })?), - )))), - Kind::F32Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F32(*v)))), - Kind::F64Value(v) => Ok(ScalarValue(InnerScalarValue::Primitive(PValue::F64(*v)))), - Kind::StringValue(v) => Ok(ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from(v.clone()), - )))), - Kind::BytesValue(v) => Ok(ScalarValue(InnerScalarValue::Buffer(Arc::new( - ByteBuffer::from(v.clone()), - )))), - Kind::ListValue(v) => { - let mut values = Vec::with_capacity(v.values.len()); - for elem in v.values.iter() { - values.push(elem.try_into()?); - } - Ok(ScalarValue(InnerScalarValue::List(values.into()))) - } - } + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar value missing kind"))?; + + // `DType::Extension` store their serialized values using the storage `DType`. 
+ let dtype = match dtype { + DType::Extension(ext) => ext.storage_dtype(), + _ => dtype, + }; + + Ok(Some(match kind { + Kind::NullValue(_) => return Ok(None), + Kind::BoolValue(v) => bool_from_proto(*v, dtype)?, + Kind::Int64Value(v) => int64_from_proto(*v, dtype)?, + Kind::Uint64Value(v) => uint64_from_proto(*v, dtype)?, + Kind::F16Value(v) => f16_from_proto(*v, dtype)?, + Kind::F32Value(v) => f32_from_proto(*v, dtype)?, + Kind::F64Value(v) => f64_from_proto(*v, dtype)?, + Kind::StringValue(s) => string_from_proto(s, dtype)?, + Kind::BytesValue(b) => bytes_from_proto(b, dtype)?, + Kind::ListValue(v) => list_from_proto(v, dtype)?, + })) + } +} + +/// Deserialize a [`ScalarValue::Bool`] from a protobuf `BoolValue`. +fn bool_from_proto(v: bool, dtype: &DType) -> VortexResult { + vortex_ensure!( + dtype.is_boolean(), + InvalidSerde: "expected Bool dtype for BoolValue, got {dtype}" + ); + + Ok(ScalarValue::Bool(v)) +} + +/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Int64Value`. +/// +/// Protobuf consolidates all signed integers into `i64`, so we narrow back to the original +/// type using the provided [`DType`]. +fn int64_from_proto(v: i64, dtype: &DType) -> VortexResult { + vortex_ensure!( + dtype.is_primitive(), + InvalidSerde: "expected Primitive dtype for Int64Value, got {dtype}" + ); + + let pvalue = match dtype.as_ptype() { + PType::I8 => v.to_i8().map(PValue::I8), + PType::I16 => v.to_i16().map(PValue::I16), + PType::I32 => v.to_i32().map(PValue::I32), + PType::I64 => Some(PValue::I64(v)), + ptype => vortex_bail!( + InvalidSerde: "expected signed integer ptype for Int64Value, got {ptype}" + ), } + .ok_or_else(|| vortex_err!(InvalidSerde: "Int64 value {v} out of range for dtype {dtype}"))?; + + Ok(ScalarValue::Primitive(pvalue)) +} + +/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `Uint64Value`. +/// +/// Protobuf consolidates all unsigned integers into `u64`, so we narrow back to the original +/// type using the provided [`DType`]. Also handles the backwards-compatible case where `f16` +/// values were serialized as `u64` (via `f16::to_bits() as u64`). +fn uint64_from_proto(v: u64, dtype: &DType) -> VortexResult { + vortex_ensure!( + dtype.is_primitive(), + InvalidSerde: "expected Primitive dtype for Uint64Value, got {dtype}" + ); + + let pvalue = match dtype.as_ptype() { + PType::U8 => v.to_u8().map(PValue::U8), + PType::U16 => v.to_u16().map(PValue::U16), + PType::U32 => v.to_u32().map(PValue::U32), + PType::U64 => Some(PValue::U64(v)), + // Backwards compatibility: f16 values were previously serialized as u64. + PType::F16 => v.to_u16().map(f16::from_bits).map(PValue::F16), + ptype => vortex_bail!( + InvalidSerde: "expected unsigned integer ptype for Uint64Value, got {ptype}" + ), + } + .ok_or_else(|| vortex_err!(InvalidSerde: "Uint64 value {v} out of range for dtype {dtype}"))?; + + Ok(ScalarValue::Primitive(pvalue)) +} + +/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F16Value`. +fn f16_from_proto(v: u64, dtype: &DType) -> VortexResult { + vortex_ensure!( + matches!(dtype, DType::Primitive(PType::F16, _)), + InvalidSerde: "expected F16 dtype for F16Value, got {dtype}" + ); + + let bits = u16::try_from(v).map_err( + |_| vortex_err!(InvalidSerde: "f16 bitwise representation has more than 16 bits: {v}"), + )?; + + Ok(ScalarValue::Primitive(PValue::F16(f16::from_bits(bits)))) +} + +/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F32Value`. 
+fn f32_from_proto(v: f32, dtype: &DType) -> VortexResult { + vortex_ensure!( + matches!(dtype, DType::Primitive(PType::F32, _)), + InvalidSerde: "expected F32 dtype for F32Value, got {dtype}" + ); + + Ok(ScalarValue::Primitive(PValue::F32(v))) +} + +/// Deserialize a [`ScalarValue::Primitive`] from a protobuf `F64Value`. +fn f64_from_proto(v: f64, dtype: &DType) -> VortexResult { + vortex_ensure!( + matches!(dtype, DType::Primitive(PType::F64, _)), + InvalidSerde: "expected F64 dtype for F64Value, got {dtype}" + ); + + Ok(ScalarValue::Primitive(PValue::F64(v))) +} + +/// Deserialize a [`ScalarValue::Utf8`] or [`ScalarValue::Binary`] from a protobuf +/// `StringValue`. +fn string_from_proto(s: &str, dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::from(s))), + DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(s.as_bytes()))), + _ => vortex_bail!( + InvalidSerde: "expected Utf8 or Binary dtype for StringValue, got {dtype}" + ), + } +} + +/// Deserialize a [`ScalarValue`] from a protobuf bytes and a `DType`. +/// +/// Handles [`Utf8`](ScalarValue::Utf8), [`Binary`](ScalarValue::Binary), and +/// [`Decimal`](ScalarValue::Decimal) dtypes. +fn bytes_from_proto(bytes: &[u8], dtype: &DType) -> VortexResult { + match dtype { + DType::Utf8(_) => Ok(ScalarValue::Utf8(BufferString::try_from(bytes)?)), + DType::Binary(_) => Ok(ScalarValue::Binary(ByteBuffer::copy_from(bytes))), + // TODO(connor): This is incorrect, we need to verify this matches the `dtype`. + DType::Decimal(..) => Ok(ScalarValue::Decimal(match bytes.len() { + 1 => DecimalValue::I8(bytes[0] as i8), + 2 => DecimalValue::I16(i16::from_le_bytes(bytes.try_into()?)), + 4 => DecimalValue::I32(i32::from_le_bytes(bytes.try_into()?)), + 8 => DecimalValue::I64(i64::from_le_bytes(bytes.try_into()?)), + 16 => DecimalValue::I128(i128::from_le_bytes(bytes.try_into()?)), + 32 => DecimalValue::I256(i256::from_le_bytes(bytes.try_into()?)), + l => vortex_bail!(InvalidSerde: "invalid decimal byte length: {l}"), + })), + _ => vortex_bail!( + InvalidSerde: "expected Utf8, Binary, or Decimal dtype for BytesValue, got {dtype}" + ), + } +} + +/// Deserialize a [`ScalarValue::List`] from a protobuf `ListValue`. 
+fn list_from_proto(v: &ListValue, dtype: &DType) -> VortexResult { + let element_dtype = dtype.as_list_element_opt().ok_or_else( + || vortex_err!(InvalidSerde: "expected List dtype for ListValue, got {dtype}"), + )?; + + let mut values = Vec::with_capacity(v.values.len()); + for elem in v.values.iter() { + values.push(ScalarValue::from_proto(elem, element_dtype.as_ref())?); + } + + Ok(ScalarValue::List(values)) } #[cfg(test)] mod tests { use std::sync::Arc; - use rstest::rstest; use vortex_buffer::BufferString; use vortex_dtype::DType; use vortex_dtype::DecimalDType; - use vortex_dtype::FieldDType; use vortex_dtype::Nullability; use vortex_dtype::PType; - use vortex_dtype::StructFields; use vortex_dtype::half::f16; - use vortex_dtype::i256; use vortex_error::vortex_panic; use vortex_proto::scalar as pb; + use vortex_session::VortexSession; use super::*; - use crate::InnerScalarValue; + use crate::DecimalValue; use crate::Scalar; use crate::ScalarValue; - use crate::tests::SESSION; + + fn session() -> VortexSession { + VortexSession::empty() + } fn round_trip(scalar: Scalar) { assert_eq!( scalar, - Scalar::from_proto(&pb::Scalar::from(&scalar), &SESSION).unwrap(), + Scalar::from_proto(&pb::Scalar::from(&scalar), &session()).unwrap(), ); } @@ -218,7 +431,7 @@ mod tests { fn test_bool() { round_trip(Scalar::new( DType::Bool(Nullability::Nullable), - ScalarValue(InnerScalarValue::Bool(true)), + Some(ScalarValue::Bool(true)), )); } @@ -226,7 +439,7 @@ mod tests { fn test_primitive() { round_trip(Scalar::new( DType::Primitive(PType::I32, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(42i32.into())), + Some(ScalarValue::Primitive(42i32.into())), )); } @@ -234,7 +447,7 @@ mod tests { fn test_buffer() { round_trip(Scalar::new( DType::Binary(Nullability::Nullable), - ScalarValue(InnerScalarValue::Buffer(Arc::new(vec![1, 2, 3].into()))), + Some(ScalarValue::Binary(vec![1, 2, 3].into())), )); } @@ -242,9 +455,7 @@ mod tests { fn test_buffer_string() { round_trip(Scalar::new( DType::Utf8(Nullability::Nullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from("hello".to_string()), - ))), + Some(ScalarValue::Utf8(BufferString::from("hello".to_string()))), )); } @@ -255,13 +466,10 @@ mod tests { Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List( - vec![ - ScalarValue(InnerScalarValue::Primitive(42i32.into())), - ScalarValue(InnerScalarValue::Primitive(43i32.into())), - ] - .into(), - )), + Some(ScalarValue::List(vec![ + Some(ScalarValue::Primitive(42i32.into())), + Some(ScalarValue::Primitive(43i32.into())), + ])), )); } @@ -277,100 +485,118 @@ mod tests { fn test_i8() { round_trip(Scalar::new( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(i8::MIN.into())), + Some(ScalarValue::Primitive(i8::MIN.into())), )); round_trip(Scalar::new( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(0i8.into())), + Some(ScalarValue::Primitive(0i8.into())), )); round_trip(Scalar::new( DType::Primitive(PType::I8, Nullability::Nullable), - ScalarValue(InnerScalarValue::Primitive(i8::MAX.into())), + Some(ScalarValue::Primitive(i8::MAX.into())), )); } - #[rstest] - #[case(Scalar::binary(ByteBuffer::copy_from(b"hello"), Nullability::NonNullable))] - #[case(Scalar::utf8("hello", Nullability::NonNullable))] - #[case(Scalar::primitive(1u8, Nullability::NonNullable))] - #[case(Scalar::primitive( - 
f32::from_bits(u32::from_le_bytes([0xFFu8, 0x8A, 0xF9, 0xFF])), - Nullability::NonNullable - ))] - #[case(Scalar::list(Arc::new(PType::U8.into()), vec![Scalar::primitive(1u8, Nullability::NonNullable)], Nullability::NonNullable - ))] - #[case(Scalar::struct_(DType::Struct( - StructFields::from_iter([ - ("a", FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable))), - ("b", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))), - ]), - Nullability::NonNullable), - vec![ - Scalar::primitive(23592960u32, Nullability::NonNullable), - Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable), - ], - ))] - #[case(Scalar::struct_(DType::Struct( - StructFields::from_iter([ - ("a", FieldDType::from(DType::Primitive(PType::U64, Nullability::NonNullable))), - ("b", FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable))), - ("c", FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable))), - ]), - Nullability::NonNullable), - vec![ - Scalar::primitive(415118687234u64, Nullability::NonNullable), - Scalar::primitive(2.6584664e36f32, Nullability::NonNullable), - Scalar::primitive(f16::from_f32(2.6584664e36f32), Nullability::NonNullable), - ], - ))] - #[case(Scalar::decimal( - DecimalValue::I256(i256::from_i128(12345643673471)), - DecimalDType::new(10, 2), - Nullability::NonNullable - ))] - #[case(Scalar::decimal( - DecimalValue::I16(23412), - DecimalDType::new(3, 2), - Nullability::NonNullable - ))] - fn test_scalar_value_serde_roundtrip(#[case] scalar: Scalar) { - let written = scalar.value().to_protobytes::>(); - let scalar_read_back = ScalarValue::from_protobytes(&written).unwrap(); - assert_eq!( - Scalar::new(scalar.dtype().clone(), scalar_read_back), - scalar - ); + #[test] + fn test_decimal_i32_roundtrip() { + // A typical decimal with moderate precision and scale. + round_trip(Scalar::decimal( + DecimalValue::I32(123_456), + DecimalDType::new(10, 2), + Nullability::NonNullable, + )); + } + + #[test] + fn test_decimal_i128_roundtrip() { + // A large decimal value that requires i128 storage. + round_trip(Scalar::decimal( + DecimalValue::I128(99_999_999_999_999_999_999), + DecimalDType::new(38, 6), + Nullability::Nullable, + )); + } + + #[test] + fn test_decimal_null_roundtrip() { + round_trip(Scalar::null(DType::Decimal( + DecimalDType::new(10, 2), + Nullability::Nullable, + ))); + } + + #[test] + fn test_scalar_value_serde_roundtrip_binary() { + round_trip(Scalar::binary( + ByteBuffer::copy_from(b"hello"), + Nullability::NonNullable, + )); + } + + #[test] + fn test_scalar_value_serde_roundtrip_utf8() { + round_trip(Scalar::utf8("hello", Nullability::NonNullable)); } #[test] fn test_backcompat_f16_serialized_as_u64() { - // Note that this is a backwards compatibility test for poor design in the previous implementation. - // Previously, f16 ScalarValues were serialized as `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`. + // Backwards compatibility test for the legacy f16 serialization format. + // + // Previously, f16 ScalarValues were serialized as `Uint64Value(v.to_bits() as u64)` because + // the proto schema only had 64-bit integer types, and f16's underlying representation is + // u16 which got widened to u64. + // + // The current implementation uses a dedicated `F16Value` proto field, but we must still be + // able to deserialize the old format. This test verifies that: + // + // 1. A `Uint64Value` containing f16 bits can be read as a U64 primitive (the raw bits). + // 2. 
When wrapped in a Scalar with F16 dtype, the value is correctly interpreted as f16. + // + // This ensures data written with the old serialization format remains readable. + + // Simulate the old serialization: f16(0.42) stored as Uint64Value with its bit pattern. + let f16_value = f16::from_f32(0.42); + let f16_bits_as_u64 = f16_value.to_bits() as u64; // 14008 + let pb_scalar_value = pb::ScalarValue { - kind: Some(Kind::Uint64Value(f16::from_f32(0.42).to_bits() as u64)), + kind: Some(Kind::Uint64Value(f16_bits_as_u64)), }; - let scalar_value = ScalarValue::try_from(&pb_scalar_value).unwrap(); + + // Step 1: Verify the normal U64 scalar. + let scalar_value = ScalarValue::from_proto( + &pb_scalar_value, + &DType::Primitive(PType::U64, Nullability::NonNullable), + ) + .unwrap(); assert_eq!( - scalar_value.as_pvalue().unwrap(), - Some(PValue::U64(14008u64)) + scalar_value.as_ref().map(|v| v.as_primitive()), + Some(&PValue::U64(14008u64)), ); + // Step 2: Verify that when we use F16 dtype, the Uint64Value is correctly interpreted. + let scalar_value_f16 = ScalarValue::from_proto( + &pb_scalar_value, + &DType::Primitive(PType::F16, Nullability::Nullable), + ) + .unwrap(); + let scalar = Scalar::new( DType::Primitive(PType::F16, Nullability::Nullable), - scalar_value, + scalar_value_f16, ); assert_eq!( scalar.as_primitive().pvalue().unwrap(), - PValue::F16(f16::from_f32(0.42)) + PValue::F16(f16::from_f32(0.42)), + "Uint64Value should be correctly interpreted as f16 when dtype is F16" ); } #[test] fn test_scalar_value_direct_roundtrip_f16() { - // Test that ScalarValue with f16 roundtrips correctly without going through Scalar + // Test that ScalarValue with f16 roundtrips correctly without going through Scalar. let f16_values = vec![ f16::from_f32(0.0), f16::from_f32(1.0), @@ -384,17 +610,21 @@ mod tests { ]; for f16_val in f16_values { - let scalar_value = ScalarValue(InnerScalarValue::Primitive(PValue::F16(f16_val))); - let written = scalar_value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); - - match (&scalar_value.0, &read_back.0) { + let scalar_value = ScalarValue::Primitive(PValue::F16(f16_val)); + let pb_value = ScalarValue::to_proto(Some(&scalar_value)); + let read_back = ScalarValue::from_proto( + &pb_value, + &DType::Primitive(PType::F16, Nullability::NonNullable), + ) + .unwrap(); + + match (&scalar_value, read_back.as_ref()) { ( - InnerScalarValue::Primitive(PValue::F16(original)), - InnerScalarValue::Primitive(PValue::F16(roundtripped)), + ScalarValue::Primitive(PValue::F16(original)), + Some(ScalarValue::Primitive(PValue::F16(roundtripped))), ) => { if original.is_nan() && roundtripped.is_nan() { - // NaN values are equal for our purposes + // NaN values are equal for our purposes. continue; } assert_eq!( @@ -413,55 +643,57 @@ mod tests { #[test] fn test_scalar_value_direct_roundtrip_preserves_values() { - // Test that ScalarValue roundtripping preserves values (but not necessarily exact types) - // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64) - - // Test cases that should roundtrip exactly - let exact_roundtrip_cases = vec![ - ("null", ScalarValue(InnerScalarValue::Null)), - ("bool_true", ScalarValue(InnerScalarValue::Bool(true))), - ("bool_false", ScalarValue(InnerScalarValue::Bool(false))), + // Test that ScalarValue roundtripping preserves values (but not necessarily exact types). + // Note: Proto encoding consolidates integer types (u8/u16/u32 → u64, i8/i16/i32 → i64). 
+ + // Test cases that should roundtrip exactly. + let exact_roundtrip_cases: Vec<(&str, Option, DType)> = vec![ + ("null", None, DType::Null), + ( + "bool_true", + Some(ScalarValue::Bool(true)), + DType::Bool(Nullability::Nullable), + ), + ( + "bool_false", + Some(ScalarValue::Bool(false)), + DType::Bool(Nullability::Nullable), + ), ( "u64", - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - 18446744073709551615, - ))), + Some(ScalarValue::Primitive(PValue::U64(18446744073709551615))), + DType::Primitive(PType::U64, Nullability::Nullable), ), ( "i64", - ScalarValue(InnerScalarValue::Primitive(PValue::I64( - -9223372036854775808, - ))), + Some(ScalarValue::Primitive(PValue::I64(-9223372036854775808))), + DType::Primitive(PType::I64, Nullability::Nullable), ), ( "f32", - ScalarValue(InnerScalarValue::Primitive(PValue::F32( - std::f32::consts::E, - ))), + Some(ScalarValue::Primitive(PValue::F32(std::f32::consts::E))), + DType::Primitive(PType::F32, Nullability::Nullable), ), ( "f64", - ScalarValue(InnerScalarValue::Primitive(PValue::F64( - std::f64::consts::PI, - ))), + Some(ScalarValue::Primitive(PValue::F64(std::f64::consts::PI))), + DType::Primitive(PType::F64, Nullability::Nullable), ), ( "string", - ScalarValue(InnerScalarValue::BufferString(Arc::new( - BufferString::from("test"), - ))), + Some(ScalarValue::Utf8(BufferString::from("test"))), + DType::Utf8(Nullability::Nullable), ), ( "bytes", - ScalarValue(InnerScalarValue::Buffer(Arc::new( - vec![1, 2, 3, 4, 5].into(), - ))), + Some(ScalarValue::Binary(vec![1, 2, 3, 4, 5].into())), + DType::Binary(Nullability::Nullable), ), ]; - for (name, value) in exact_roundtrip_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); + for (name, value, dtype) in exact_roundtrip_cases { + let pb_value = ScalarValue::to_proto(value.as_ref()); + let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); let original_debug = format!("{value:?}"); let roundtrip_debug = format!("{read_back:?}"); @@ -471,34 +703,44 @@ mod tests { ); } - // Test cases where type changes but value is preserved - // Unsigned integers consolidate to U64 + // Test cases where type changes but value is preserved. + // Unsigned integers consolidate to U64. 
let unsigned_cases = vec![ ( "u8", - ScalarValue(InnerScalarValue::Primitive(PValue::U8(255))), + ScalarValue::Primitive(PValue::U8(255)), + DType::Primitive(PType::U8, Nullability::Nullable), 255u64, ), ( "u16", - ScalarValue(InnerScalarValue::Primitive(PValue::U16(65535))), + ScalarValue::Primitive(PValue::U16(65535)), + DType::Primitive(PType::U16, Nullability::Nullable), 65535u64, ), ( "u32", - ScalarValue(InnerScalarValue::Primitive(PValue::U32(4294967295))), + ScalarValue::Primitive(PValue::U32(4294967295)), + DType::Primitive(PType::U32, Nullability::Nullable), 4294967295u64, ), ]; - for (name, value, expected) in unsigned_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); - - match &read_back.0 { - InnerScalarValue::Primitive(PValue::U64(v)) => { + for (name, value, dtype, expected) in unsigned_cases { + let pb_value = ScalarValue::to_proto(Some(&value)); + let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); + + match read_back.as_ref() { + Some(ScalarValue::Primitive(pv)) => { + let v = match pv { + PValue::U8(v) => *v as u64, + PValue::U16(v) => *v as u64, + PValue::U32(v) => *v as u64, + PValue::U64(v) => *v, + _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"), + }; assert_eq!( - *v, expected, + v, expected, "ScalarValue {name} value not preserved: expected {expected}, got {v}" ); } @@ -506,33 +748,43 @@ mod tests { } } - // Signed integers consolidate to I64 + // Signed integers consolidate to I64. let signed_cases = vec![ ( "i8", - ScalarValue(InnerScalarValue::Primitive(PValue::I8(-128))), + ScalarValue::Primitive(PValue::I8(-128)), + DType::Primitive(PType::I8, Nullability::Nullable), -128i64, ), ( "i16", - ScalarValue(InnerScalarValue::Primitive(PValue::I16(-32768))), + ScalarValue::Primitive(PValue::I16(-32768)), + DType::Primitive(PType::I16, Nullability::Nullable), -32768i64, ), ( "i32", - ScalarValue(InnerScalarValue::Primitive(PValue::I32(-2147483648))), + ScalarValue::Primitive(PValue::I32(-2147483648)), + DType::Primitive(PType::I32, Nullability::Nullable), -2147483648i64, ), ]; - for (name, value, expected) in signed_cases { - let written = value.to_protobytes::>(); - let read_back = ScalarValue::from_protobytes(&written).unwrap(); - - match &read_back.0 { - InnerScalarValue::Primitive(PValue::I64(v)) => { + for (name, value, dtype, expected) in signed_cases { + let pb_value = ScalarValue::to_proto(Some(&value)); + let read_back = ScalarValue::from_proto(&pb_value, &dtype).unwrap(); + + match read_back.as_ref() { + Some(ScalarValue::Primitive(pv)) => { + let v = match pv { + PValue::I8(v) => *v as i64, + PValue::I16(v) => *v as i64, + PValue::I32(v) => *v as i64, + PValue::I64(v) => *v, + _ => vortex_panic!("Unexpected primitive type for {name}: {pv:?}"), + }; assert_eq!( - *v, expected, + v, expected, "ScalarValue {name} value not preserved: expected {expected}, got {v}" ); } diff --git a/vortex-scalar/src/scalar.rs b/vortex-scalar/src/scalar.rs index ec9d0ca4580..72c1857673c 100644 --- a/vortex-scalar/src/scalar.rs +++ b/vortex-scalar/src/scalar.rs @@ -1,552 +1,359 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Core [`Scalar`] type definition. 
+ use std::cmp::Ordering; use std::hash::Hash; -use std::sync::Arc; +use std::hash::Hasher; -use vortex_buffer::Buffer; use vortex_dtype::DType; use vortex_dtype::NativeDType; -use vortex_dtype::NativeDecimalType; -use vortex_dtype::Nullability; -use vortex_dtype::i256; -use vortex_error::VortexError; -use vortex_error::VortexExpect; +use vortex_dtype::PType; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_panic; -use super::*; +use crate::PValue; +use crate::ScalarValue; -/// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`]. -/// -/// A [`ScalarValue`] is opaque, and should be accessed via one of the type-specific scalar wrappers -/// for example [`BoolScalar`], [`PrimitiveScalar`], etc. +/// A typed scalar value. /// -/// Note that [`PartialOrd`] is implemented only for an exact match of the scalar's dtype, -/// including nullability. When the DType does match, ordering is nulls first (lowest), then the -/// natural ordering of the scalar value. -#[derive(Debug, Clone)] +/// Scalars represent a single value with an associated [`DType`]. The value can be null, in which +/// case the [`value`][Scalar::value] method returns `None`. +#[derive(Clone, Debug, Eq)] pub struct Scalar { /// The type of the scalar. dtype: DType, - /// The value of the scalar. + /// The value of the scalar. This is [`None`] if the value is null, otherwise it is [`Some`]. /// - /// Invariant: If the `dtype` is non-nullable, then this value _cannot_ be equal to - /// [`ScalarValue::null()`](ScalarValue::null). - value: ScalarValue, + /// Invariant: If the [`DType`] is non-nullable, then this value _cannot_ be [`None`]. + value: Option, } -impl Scalar { - /// Creates a new scalar with the given data type and value. - pub fn new(dtype: DType, value: ScalarValue) -> Self { - if !dtype.is_nullable() { - assert!( - !value.is_null(), - "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}", - ); - } - - Self { dtype, value } - } - - /// Returns a reference to the scalar's data type. - #[inline] - pub fn dtype(&self) -> &DType { - &self.dtype - } - - /// Returns a reference to the scalar's underlying value. - #[inline] - pub fn value(&self) -> &ScalarValue { - &self.value - } - - /// Consumes the scalar and returns its data type and value as a tuple. - #[inline] - pub fn into_parts(self) -> (DType, ScalarValue) { - (self.dtype, self.value) - } - - /// Consumes the scalar and returns its underlying [`DType`]. - #[inline] - pub fn into_dtype(self) -> DType { - self.dtype - } - - /// Consumes the scalar and returns its underlying [`ScalarValue`]. - #[inline] - pub fn into_value(self) -> ScalarValue { - self.value +/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars. +/// Two scalars with the same value but different nullability should be considered equal. +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value } +} - /// Returns true if the scalar is not null. - pub fn is_valid(&self) -> bool { - !self.value.is_null() +/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability +/// in equality comparisons, we must also ignore it when hashing to maintain the invariant that +/// equal values have equal hashes. 
+impl Hash for Scalar { + fn hash(&self, state: &mut H) { + self.dtype.as_nonnullable().hash(state); + self.value.hash(state); } +} - /// Returns true if the scalar is null. - pub fn is_null(&self) -> bool { - self.value.is_null() - } +impl Scalar { + // Constructors for null scalars. - /// Creates a null scalar with the given nullable data type. + /// Creates a new null [`Scalar`] with the given [`DType`]. /// /// # Panics /// - /// Panics if the data type is not nullable. + /// Panics if the given [`DType`] is non-nullable. pub fn null(dtype: DType) -> Self { assert!( dtype.is_nullable(), - "Tried to construct a null scalar when the `DType` is non-nullable: {dtype}" + "Cannot create null scalar with non-nullable dtype {dtype}" ); - Self { - dtype, - value: ScalarValue(InnerScalarValue::Null), - } + Self { dtype, value: None } } - /// Creates a null scalar for the given scalar type. + // TODO(connor): This method arguably shouldn't exist... + /// Creates a new null [`Scalar`] for the given scalar type. /// /// The resulting scalar will have a nullable version of the type's data type. - pub fn null_typed() -> Self { + pub fn null_native() -> Self { Self { dtype: T::dtype().as_nullable(), - value: ScalarValue(InnerScalarValue::Null), + value: None, } } - /// Casts the scalar to the target data type. + // Constructors for potentially null scalars. + + /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]. /// - /// Returns an error if the cast is not supported or if the value cannot be represented - /// in the target type. - pub fn cast(&self, target: &DType) -> VortexResult { - if let DType::Extension(ext_dtype) = target { - let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?; - Ok(Scalar::extension_ref(ext_dtype.clone(), storage_scalar)) - } else { - self.cast_to_non_extension(target) - } + /// This is just a helper function for tests. + /// + /// # Panics + /// + /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible. + #[cfg(test)] + pub fn new(dtype: DType, value: Option) -> Self { + use vortex_error::VortexExpect; + + Self::try_new(dtype, value).vortex_expect("Failed to create Scalar") } - fn cast_to_non_extension(&self, target: &DType) -> VortexResult { - assert!( - !matches!(target, DType::Extension(..)), - "cast_to_non_extension must not be called with an Extension dtype (got {target})", + /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null + /// [`ScalarValue`]. + /// + /// # Errors + /// + /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible. + pub fn try_new(dtype: DType, value: Option) -> VortexResult { + vortex_ensure!( + Self::is_compatible(&dtype, value.as_ref()), + "Incompatible dtype {dtype} with value {}", + value.map(|v| format!("{}", v)).unwrap_or_default() ); - if self.is_null() { - if target.is_nullable() { - return Ok(Scalar::new(target.clone(), self.value.clone())); - } - - vortex_bail!("Cannot cast null to {target}: target type is non-nullable") - } - - match &self.dtype { - DType::Null => unreachable!(), // Handled by `if self.is_null()` case. - DType::Bool(_) => self.as_bool().cast(target), - DType::Primitive(..) => self.as_primitive().cast(target), - DType::Decimal(..) => self.as_decimal().cast(target), - DType::Utf8(_) => self.as_utf8().cast(target), - DType::Binary(_) => self.as_binary().cast(target), - DType::Struct(..) => self.as_struct().cast(target), - DType::List(..) | DType::FixedSizeList(..) 
=> self.as_list().cast(target), - DType::Extension(..) => self.as_extension().cast(target), - } + Ok(Self { dtype, value }) } - /// Converts the scalar to have a nullable version of its data type. - pub fn into_nullable(self) -> Self { - Self { - dtype: self.dtype.as_nullable(), - value: self.value, - } - } + /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`] + /// without checking compatibility. + /// + /// # Safety + /// + /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the + /// rules defined in [`Self::is_compatible`]. + pub unsafe fn new_unchecked(dtype: DType, value: Option) -> Self { + debug_assert!( + Self::is_compatible(&dtype, value.as_ref()), + "Incompatible dtype {dtype} with value {}", + value.map(|v| format!("{}", v)).unwrap_or_default() + ); - /// Returns the size of the scalar in bytes, uncompressed. - pub fn nbytes(&self) -> usize { - match self.dtype() { - DType::Null => 0, - DType::Bool(_) => 1, - DType::Primitive(ptype, _) => ptype.byte_width(), - DType::Decimal(dt, _) => { - if dt.precision() <= i128::MAX_PRECISION { - size_of::() - } else { - size_of::() - } - } - DType::Binary(_) | DType::Utf8(_) => self - .value() - .as_buffer() - .ok() - .flatten() - .map_or(0, |s| s.len()), - DType::Struct(..) => self - .as_struct() - .fields() - .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) - .unwrap_or_default(), - DType::List(..) | DType::FixedSizeList(..) => self - .as_list() - .elements() - .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) - .unwrap_or_default(), - DType::Extension(_) => self.as_extension().storage().nbytes(), - } + Self { dtype, value } } - /// Creates a "zero"-value scalar value for the given data type. - /// - /// For nullable types the zero value is the underlying `DType`'s zero value. + /// Returns a default value for the given [`DType`]. /// - /// # Zero Values + /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this + /// returns the zero value for the type. /// - /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable): - /// - `Bool`: `false` - /// - `Primitive`: `0` - /// - `Decimal`: `0` - /// - `Utf8`: `""` - /// - `Binary`: An empty buffer - /// - `List`: An empty list - /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the - /// element [`DType`] - /// - `Struct`: A struct where each field has a zero value, which is determined by the field - /// [`DType`] - /// - `Extension`: The zero value of the storage [`DType`] + /// For non-nullable and nested types that may need null values in their children (as of right + /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default + /// children. /// - /// This is similar to `default_value` except in its handling of nullability. 
- pub fn zero_value(dtype: DType) -> Self { - match dtype { - DType::Null => Self::null(dtype), - DType::Bool(nullability) => Self::bool(false, nullability), - DType::Primitive(pt, nullability) => { - Self::primitive_value(PValue::zero(pt), pt, nullability) - } - DType::Decimal(dt, nullability) => { - Self::decimal(DecimalValue::from(0i8), dt, nullability) - } - DType::Utf8(nullability) => Self::utf8("", nullability), - DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability), - DType::List(edt, nullability) => Self::list(edt, vec![], nullability), - DType::FixedSizeList(edt, size, nullability) => { - let elements = (0..size) - .map(|_| Scalar::zero_value(edt.as_ref().clone())) - .collect(); - Self::fixed_size_list(edt, elements, nullability) - } - DType::Struct(sf, nullability) => { - let fields: Vec<_> = sf.fields().map(Scalar::zero_value).collect(); - Self::struct_(DType::Struct(sf, nullability), fields) - } - DType::Extension(dt) => { - let scalar = Self::zero_value(dt.storage_dtype().clone()); - Self::extension_ref(dt, scalar) - } - } + /// See [`ScalarValue::zero_value`] for more details about "zero" values. + pub fn default_value(dtype: &DType) -> Self { + let value = ScalarValue::default_value(dtype); + // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`. + unsafe { Self::new_unchecked(dtype.clone(), value) } } - /// Returns true if the scalar is a zero value i.e., equal to a scalar returned from the ` zero_value ` method. - pub fn is_zero(&self) -> bool { - match self.dtype() { - DType::Null => true, - DType::Bool(_) => self.as_bool().value() == Some(false), - DType::Primitive(pt, _) => self.as_primitive().pvalue() == Some(PValue::zero(*pt)), - DType::Decimal(..) => { - self.as_decimal().decimal_value() == Some(DecimalValue::from(0i8)) - } - DType::Utf8(_) => self - .as_utf8() - .value() - .map(|v| v.is_empty()) - .unwrap_or(false), - DType::Binary(_) => self - .as_binary() - .value() - .map(|v| v.is_empty()) - .unwrap_or(false), - DType::Struct(..) => self - .as_struct() - .fields() - .map(|mut sf| sf.all(|f| f.is_zero())) - .unwrap_or(false), - DType::List(..) => self - .as_list() - .elements() - .map(|vals| vals.is_empty()) - .unwrap_or(false), - DType::FixedSizeList(..) => self - .as_list() - .elements() - .map(|vals| vals.iter().all(|f| f.is_zero())) - .unwrap_or(false), - DType::Extension(..) => self.as_extension().storage().is_zero(), - } + /// Returns a non-null zero / identity value for the given [`DType`]. + /// + /// See [`ScalarValue::zero_value`] for more details about "zero" values. + pub fn zero_value(dtype: &DType) -> Self { + let value = ScalarValue::zero_value(dtype); + // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`. + unsafe { Self::new_unchecked(dtype.clone(), Some(value)) } } - /// Creates a "default" scalar value for the given data type. - /// - /// For nullable types, returns null. For non-nullable types, returns an appropriate zero/empty - /// value. 
- /// - /// # Default Values - /// - /// Here is the list of default values for each [`DType`] (when the [`DType`] is non-nullable): - /// - /// - `Null`: `null` - /// - `Bool`: `false` - /// - `Primitive`: `0` - /// - `Decimal`: `0` - /// - `Utf8`: `""` - /// - `Binary`: An empty buffer - /// - `List`: An empty list - /// - `FixedSizeList`: A list (with correct size) of default values, which is determined by the - /// element [`DType`] - /// - `Struct`: A struct where each field has a default value, which is determined by the field - /// [`DType`] - /// - `Extension`: The default value of the storage [`DType`] - pub fn default_value(dtype: DType) -> Self { - if dtype.is_nullable() { - return Self::null(dtype); - } + // Other methods. + + /// Check if the given [`ScalarValue`] is compatible with the given [`DType`]. + pub fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool { + let Some(value) = value else { + return dtype.is_nullable(); + }; + // From here on, we know that the value is not null. match dtype { - DType::Null => Self::null(dtype), - DType::Bool(nullability) => Self::bool(false, nullability), - DType::Primitive(pt, nullability) => { - Self::primitive_value(PValue::zero(pt), pt, nullability) + DType::Null => false, + DType::Bool(_) => matches!(value, ScalarValue::Bool(_)), + DType::Primitive(ptype, _) => { + if let ScalarValue::Primitive(pvalue) = value { + // Note that this is a backwards compatibility check for poor design in the + // previous implementation. `f16` `ScalarValue`s used to be serialized as + // `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`, so we need to ensure that + // we can still represent them as such. + let f16_backcompat_still_works = + matches!(ptype, &PType::F16) && matches!(pvalue, PValue::U64(_)); + + f16_backcompat_still_works || pvalue.ptype() == *ptype + } else { + false + } + } + DType::Decimal(dec_dtype, _) => { + if let ScalarValue::Decimal(dvalue) = value { + dvalue.fits_in_precision(*dec_dtype) + } else { + false + } } - DType::Decimal(dt, nullability) => { - Self::decimal(DecimalValue::from(0i8), dt, nullability) + DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)), + DType::Binary(_) => matches!(value, ScalarValue::Binary(_)), + DType::List(elem_dtype, _) => { + if let ScalarValue::List(elements) = value { + elements + .iter() + .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false + } } - DType::Utf8(nullability) => Self::utf8("", nullability), - DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability), - DType::List(edt, nullability) => Self::list(edt, vec![], nullability), - DType::FixedSizeList(edt, size, nullability) => { - let elements = (0..size) - .map(|_| Scalar::default_value(edt.as_ref().clone())) - .collect(); - Self::fixed_size_list(edt, elements, nullability) + DType::FixedSizeList(elem_dtype, size, _) => { + if let ScalarValue::List(elements) = value { + if elements.len() != *size as usize { + return false; + } + elements + .iter() + .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref())) + } else { + false + } } - DType::Struct(sf, nullability) => { - let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect(); - Self::struct_(DType::Struct(sf, nullability), fields) + DType::Struct(fields, _) => { + if let ScalarValue::List(values) = value { + if values.len() != fields.nfields() { + return false; + } + for (field, field_value) in fields.fields().zip(values.iter()) { + if !Self::is_compatible(&field, 
field_value.as_ref()) { + return false; + } + } + true + } else { + false + } } - DType::Extension(dt) => { - let scalar = Self::default_value(dt.storage_dtype().clone()); - Self::extension_ref(dt, scalar) + DType::Extension(ext_dtype) => { + // TODO(connor): Fix this when adding the correct extension scalars! + Self::is_compatible(ext_dtype.storage_dtype(), Some(value)) } } } -} - -/// This implementation block contains only `TryFrom` and `From` wrappers (`as_something`). -impl Scalar { - /// Returns a view of the scalar as a boolean scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a boolean type. - pub fn as_bool(&self) -> BoolScalar<'_> { - BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool") - } - - /// Returns a view of the scalar as a boolean scalar if it has a boolean type. - pub fn as_bool_opt(&self) -> Option> { - matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool()) - } - /// Returns a view of the scalar as a primitive scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a primitive type. - pub fn as_primitive(&self) -> PrimitiveScalar<'_> { - PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive") + /// Check if two scalars are equal, ignoring nullability of the [`DType`]. + pub fn eq_ignore_nullability(&self, other: &Self) -> bool { + self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value } - /// Returns a view of the scalar as a primitive scalar if it has a primitive type. - pub fn as_primitive_opt(&self) -> Option> { - matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive()) + /// Returns the parts of the [`Scalar`]. + pub fn into_parts(self) -> (DType, Option) { + (self.dtype, self.value) } - /// Returns a view of the scalar as a decimal scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a decimal type. - pub fn as_decimal(&self) -> DecimalScalar<'_> { - DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal") + /// Returns the [`DType`] of the [`Scalar`]. + pub fn dtype(&self) -> &DType { + &self.dtype } - /// Returns a view of the scalar as a decimal scalar if it has a decimal type. - pub fn as_decimal_opt(&self) -> Option> { - matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal()) + /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null. + pub fn value(&self) -> Option<&ScalarValue> { + self.value.as_ref() } - /// Returns a view of the scalar as a UTF-8 string scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a UTF-8 type. - pub fn as_utf8(&self) -> Utf8Scalar<'_> { - Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8") - } - - /// Returns a view of the scalar as a UTF-8 string scalar if it has a UTF-8 type. - pub fn as_utf8_opt(&self) -> Option> { - matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8()) + /// Returns the internal optional [`ScalarValue`], where `None` means the value is null, + /// consuming the [`Scalar`]. + pub fn into_value(self) -> Option { + self.value } - /// Returns a view of the scalar as a binary scalar. - /// - /// # Panics - /// - /// Panics if the scalar is not a binary type. - pub fn as_binary(&self) -> BinaryScalar<'_> { - BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary") + /// Returns `true` if the [`Scalar`] has a non-null value. 
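A short sketch of the compatibility rules implemented above: a missing value is only valid for a nullable dtype, and a present value's variant (and ptype, for primitives) has to line up with the dtype:

```rust
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::{PValue, Scalar, ScalarValue};

fn main() {
    // Null values are only compatible with nullable dtypes.
    assert!(Scalar::is_compatible(&DType::Bool(Nullability::Nullable), None));
    assert!(!Scalar::is_compatible(&DType::Bool(Nullability::NonNullable), None));

    // Present values must match the dtype's expected variant.
    let value = ScalarValue::Primitive(PValue::U16(7));
    assert!(Scalar::is_compatible(
        &DType::Primitive(PType::U16, Nullability::NonNullable),
        Some(&value),
    ));
    assert!(!Scalar::is_compatible(
        &DType::Utf8(Nullability::NonNullable),
        Some(&value),
    ));
}
```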
+ pub fn is_valid(&self) -> bool { + self.value.is_some() } - /// Returns a view of the scalar as a binary scalar if it has a binary type. - pub fn as_binary_opt(&self) -> Option> { - matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary()) + /// Returns `true` if the [`Scalar`] is null. + pub fn is_null(&self) -> bool { + self.value.is_none() } - /// Returns a view of the scalar as a struct scalar. + /// Returns `true` if the [`Scalar`] has a non-null zero value. /// - /// # Panics - /// - /// Panics if the scalar is not a struct type. - pub fn as_struct(&self) -> StructScalar<'_> { - StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct") - } + /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero + /// and `Some(false)` otherwise. + pub fn is_zero(&self) -> Option { + let value = self.value()?; - /// Returns a view of the scalar as a struct scalar if it has a struct type. - pub fn as_struct_opt(&self) -> Option> { - matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct()) - } + let is_zero = match self.dtype() { + DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"), + DType::Bool(_) => !value.as_bool(), + DType::Primitive(..) => value.as_primitive().is_zero(), + DType::Decimal(..) => value.as_decimal().is_zero(), + DType::Utf8(_) => value.as_utf8().is_empty(), + DType::Binary(_) => value.as_binary().is_empty(), + DType::List(..) => value.as_list().is_empty(), + DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize, + DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(), + DType::Extension(_) => self.as_extension().storage().is_zero()?, + }; - /// Returns a view of the scalar as a list scalar. - /// - /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and - /// [`DType::FixedSizeList`]. - /// - /// # Panics - /// - /// Panics if the scalar is not a list type. - pub fn as_list(&self) -> ListScalar<'_> { - ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list") + Some(is_zero) } - /// Returns a view of the scalar as a list scalar if it has a list type. + /// Reinterprets the bytes of this scalar as a different primitive type. /// - /// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and - /// [`DType::FixedSizeList`]. - pub fn as_list_opt(&self) -> Option> { - matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list()) - } - - /// Returns a view of the scalar as an extension scalar. + /// # Errors /// - /// # Panics - /// - /// Panics if the scalar is not an extension type. - pub fn as_extension(&self) -> ExtScalar<'_> { - ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension") - } - - /// Returns a view of the scalar as an extension scalar if it has an extension type. - pub fn as_extension_opt(&self) -> Option> { - matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension()) - } -} - -/// It is common to represent a nullable type `T` as an `Option`, so we implement a blanket -/// implementation for all `Option` to simply be a nullable `T`. -impl From> for Scalar -where - T: NativeDType, - Scalar: From, -{ - /// A blanket implementation for all `Option`. 
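Because `is_zero` now distinguishes null from non-zero, callers have to compare against `Some(true)` rather than treating the result as a plain `bool`. A minimal sketch:

```rust
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

fn main() {
    // Non-null scalars report whether they equal the zero value for their dtype...
    assert_eq!(Scalar::primitive(0i32, Nullability::NonNullable).is_zero(), Some(true));
    assert_eq!(Scalar::primitive(7i32, Nullability::NonNullable).is_zero(), Some(false));

    // ...while a null scalar yields `None`, leaving it to the caller to decide how nulls count.
    assert_eq!(Scalar::null_native::<i32>().is_zero(), None);
}
```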
- fn from(value: Option) -> Self { - value - .map(Scalar::from) - .map(|x| x.into_nullable()) - .unwrap_or_else(|| Scalar { - dtype: T::dtype().as_nullable(), - value: ScalarValue(InnerScalarValue::Null), - }) - } -} - -impl From> for Scalar -where - T: NativeDType, - Scalar: From, -{ - /// Converts a vector into a `Scalar` (where the value is a `ListScalar`). - fn from(vec: Vec) -> Self { - Scalar { - dtype: DType::List(Arc::from(T::dtype()), Nullability::NonNullable), - value: ScalarValue::from(vec), + /// Panics if the scalar is not a primitive type or if the types have different byte widths. + pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult { + let primitive = self.as_primitive(); + if primitive.ptype() == ptype { + return Ok(self.clone()); } - } -} -impl TryFrom for Vec -where - T: for<'b> TryFrom<&'b Scalar, Error = VortexError>, -{ - type Error = VortexError; - - fn try_from(value: Scalar) -> Result { - Vec::try_from(&value) - } -} + vortex_ensure_eq!( + primitive.ptype().byte_width(), + ptype.byte_width(), + "can't reinterpret cast between integers of two different widths" + ); -impl<'a, T> TryFrom<&'a Scalar> for Vec -where - T: for<'b> TryFrom<&'b Scalar, Error = VortexError>, -{ - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - ListScalar::try_from(value)? - .elements() - .ok_or_else(|| vortex_err!("Expected non-null list"))? - .into_iter() - .map(|e| T::try_from(&e)) - .collect::>>() + Scalar::try_new( + DType::Primitive(ptype, self.dtype().nullability()), + primitive + .pvalue() + .map(|p| p.reinterpret_cast(ptype)) + .map(ScalarValue::Primitive), + ) } -} -impl PartialEq for Scalar { - fn eq(&self, other: &Self) -> bool { - if !self.dtype.eq_ignore_nullability(&other.dtype) { - return false; - } + /// Returns the size of the scalar in bytes, uncompressed. + #[cfg(test)] + pub fn nbytes(&self) -> usize { + use vortex_dtype::NativeDecimalType; + use vortex_dtype::i256; match self.dtype() { - DType::Null => true, - DType::Bool(_) => self.as_bool() == other.as_bool(), - DType::Primitive(..) => self.as_primitive() == other.as_primitive(), - DType::Decimal(..) => self.as_decimal() == other.as_decimal(), - DType::Utf8(_) => self.as_utf8() == other.as_utf8(), - DType::Binary(_) => self.as_binary() == other.as_binary(), - DType::Struct(..) => self.as_struct() == other.as_struct(), - DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(), - DType::Extension(_) => self.as_extension() == other.as_extension(), + DType::Null => 0, + DType::Bool(_) => 1, + DType::Primitive(ptype, _) => ptype.byte_width(), + DType::Decimal(dt, _) => { + if dt.precision() <= i128::MAX_PRECISION { + size_of::() + } else { + size_of::() + } + } + DType::Utf8(_) => self + .value() + .map_or_else(|| 0, |value| value.as_utf8().len()), + DType::Binary(_) => self + .value() + .map_or_else(|| 0, |value| value.as_binary().len()), + DType::Struct(..) => self + .as_struct() + .fields_iter() + .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) + .unwrap_or_default(), + DType::List(..) | DType::FixedSizeList(..) => self + .as_list() + .elements() + .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::()) + .unwrap_or_default(), + DType::Extension(_) => self.as_extension().storage().nbytes(), } } } -impl Eq for Scalar {} - impl PartialOrd for Scalar { /// Compares two scalar values for ordering. 
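A sketch of the reinterpret cast above: a same-width reinterpretation succeeds and only swaps the dtype, while a width mismatch is rejected. The bit-level behaviour itself is delegated to `PValue::reinterpret_cast`, which this assumes preserves the raw bits:

```rust
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;

fn main() {
    let scalar = Scalar::primitive(7i32, Nullability::NonNullable);

    // i32 -> u32: same byte width, so the cast succeeds and the nullability is kept.
    let reinterpreted = scalar.primitive_reinterpret_cast(PType::U32).unwrap();
    assert_eq!(
        reinterpreted.dtype(),
        &DType::Primitive(PType::U32, Nullability::NonNullable)
    );

    // i32 -> u64: different widths are an error.
    assert!(scalar.primitive_reinterpret_cast(PType::U64).is_err());
}
```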
/// @@ -580,62 +387,6 @@ impl PartialOrd for Scalar { if !self.dtype().eq_ignore_nullability(other.dtype()) { return None; } - match self.dtype() { - DType::Null => Some(Ordering::Equal), - DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()), - DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()), - DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()), - DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()), - DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()), - DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()), - DType::List(..) | DType::FixedSizeList(..) => { - self.as_list().partial_cmp(&other.as_list()) - } - DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()), - } - } -} - -impl Hash for Scalar { - fn hash(&self, state: &mut H) { - match self.dtype() { - DType::Null => self.dtype().hash(state), // Hash the dtype instead of the value - DType::Bool(_) => self.as_bool().hash(state), - DType::Primitive(..) => self.as_primitive().hash(state), - DType::Decimal(..) => self.as_decimal().hash(state), - DType::Utf8(_) => self.as_utf8().hash(state), - DType::Binary(_) => self.as_binary().hash(state), - DType::Struct(..) => self.as_struct().hash(state), - DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state), - DType::Extension(_) => self.as_extension().hash(state), - } - } -} - -impl AsRef for Scalar { - fn as_ref(&self) -> &Self { - self - } -} - -impl From> for Scalar { - fn from(pscalar: PrimitiveScalar<'_>) -> Self { - let dtype = pscalar.dtype().clone(); - let value = pscalar - .pvalue() - .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)); - Self::new(dtype, value) - } -} - -impl From> for Scalar { - fn from(decimal_scalar: DecimalScalar<'_>) -> Self { - let dtype = decimal_scalar.dtype().clone(); - let value = decimal_scalar - .decimal_value() - .map(|value| ScalarValue(InnerScalarValue::Decimal(value))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)); - Self::new(dtype, value) + self.value().partial_cmp(&other.value()) } } diff --git a/vortex-scalar/src/scalar_value.rs b/vortex-scalar/src/scalar_value.rs index ffa75fd80d7..b11bfa329ab 100644 --- a/vortex-scalar/src/scalar_value.rs +++ b/vortex-scalar/src/scalar_value.rs @@ -1,288 +1,270 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Core [`ScalarValue`] type definition. + +use std::cmp::Ordering; use std::fmt::Display; -use std::sync::Arc; +use std::fmt::Formatter; -use bytes::BufMut; use itertools::Itertools; -use prost::Message; use vortex_buffer::BufferString; use vortex_buffer::ByteBuffer; -use vortex_dtype::NativeDType; -use vortex_dtype::i256; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_proto::scalar as pb; +use vortex_dtype::DType; +use vortex_error::vortex_panic; -use crate::Scalar; -use crate::decimal::DecimalValue; -use crate::pvalue::PValue; +use crate::DecimalValue; +// use crate::ExtScalarRef; +use crate::PValue; -/// Represents the internal data of a scalar value. Must be interpreted by wrapping up with a -/// [`vortex_dtype::DType`] to make a [`super::Scalar`]. +/// The value stored in a [`Scalar`][crate::Scalar]. /// -/// Note that these values can be deserialized from JSON or other formats. 
So a [`PValue`] may not -/// have the correct width for what the [`vortex_dtype::DType`] expects. Primitive values should therefore be -/// read using [`super::PrimitiveScalar`] which will handle the conversion. -#[derive(Debug, Clone)] -pub struct ScalarValue(pub(crate) InnerScalarValue); - -/// It is common to represent a nullable type `T` as an `Option`, so we implement a blanket -/// implementation for all `Option` to simply be a nullable `T`. -impl From> for ScalarValue -where - T: NativeDType, - ScalarValue: From, -{ - fn from(value: Option) -> Self { - value - .map(ScalarValue::from) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)) - } -} - -impl From> for ScalarValue -where - T: NativeDType, - Scalar: From, -{ - /// Converts a vector into a `ScalarValue` (specifically a `ListScalar`). - fn from(value: Vec) -> Self { - ScalarValue(InnerScalarValue::List( - value - .into_iter() - .map(|x| { - let scalar: Scalar = T::into(x); - scalar.into_value() - }) - .collect::>(), - )) - } -} - -#[derive(Debug, Clone)] -pub(crate) enum InnerScalarValue { - Null, +/// This enum represents the possible non-null values that can be stored in a scalar. When the +/// scalar is null, the value is represented as `None` in the `Option` field. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScalarValue { + /// A boolean value. Bool(bool), + /// A primitive numeric value. Primitive(PValue), + /// A decimal value. Decimal(DecimalValue), - Buffer(Arc), - BufferString(Arc), - List(Arc<[ScalarValue]>), + /// A UTF-8 encoded string value. + Utf8(BufferString), + /// A binary (byte array) value. + Binary(ByteBuffer), + /// A list of potentially null scalar values. + List(Vec>), + // Extension(ExtScalarRef), } +// TODO(connor): Docs can be improved (in combination with the associated `Scalar` methods). impl ScalarValue { - /// Serializes the scalar value to Protocol Buffers format. - pub fn to_protobytes(&self) -> B { - let pb_scalar = pb::ScalarValue::from(self); - - let mut buf = B::default(); - pb_scalar - .encode(&mut buf) - .vortex_expect("protobuf encoding should succeed"); - buf - } - - /// Deserializes a scalar value from Protocol Buffers format. - pub fn from_protobytes(buf: &[u8]) -> VortexResult { - ScalarValue::try_from(&pb::ScalarValue::decode(buf)?) - } -} - -fn to_hex(slice: &[u8]) -> String { - slice - .iter() - .format_with("", |f, b| b(&format_args!("{f:02x}"))) - .to_string() -} - -impl Display for ScalarValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl Display for InnerScalarValue { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Bool(b) => write!(f, "{b}"), - Self::Primitive(pvalue) => write!(f, "{pvalue}"), - Self::Decimal(value) => write!(f, "{value}"), - Self::Buffer(buf) => { - if buf.len() > 10 { - write!( - f, - "{}..{}", - to_hex(&buf[0..5]), - to_hex(&buf[buf.len() - 5..buf.len()]), - ) - } else { - write!(f, "{}", to_hex(buf)) - } + // TODO(connor): There is an inconsistency here w.r.t. `FixedSizeList` and `Struct` types, since + // we say that the zero value for those are **not** empty lists. But here we say that a list is + // a "zero" value if it is empty. So depending on the dtype this might just be incorrect! + // /// Returns true if the scalar represents the zero / identity value for its [`DType`]. + // /// + // /// Returns false if the scalar is null. + // /// + // /// See [`Scalar::zero_value()`] for more details about "zero" values. 
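With null now modelled as the absence of a value rather than a `Null` variant, a scalar is built from a `DType` plus an `Option<ScalarValue>`. A brief sketch of what that looks like at a call site, assuming `Scalar::new` accepts the pair as the updated tests below do:

```rust
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::{PValue, Scalar, ScalarValue};

fn main() {
    // A null scalar is simply a dtype with no value.
    let null = Scalar::new(DType::Primitive(PType::I32, Nullability::Nullable), None);
    assert!(null.is_null());

    // A present scalar carries one of the `ScalarValue` variants.
    let present = Scalar::new(
        DType::Primitive(PType::I32, Nullability::Nullable),
        Some(ScalarValue::Primitive(PValue::I32(3))),
    );
    match present.value() {
        Some(ScalarValue::Primitive(p)) => assert_eq!(p, &PValue::I32(3)),
        _ => unreachable!("constructed as a present primitive"),
    }
}
```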
+ // /// + // /// [`Scalar::zero_value()`]: crate::Scalar::zero_value + // pub fn is_zero(&self) -> bool { + // // TODO(connor): Is it better to just do == Self::zero_value()? + // match self { + // ScalarValue::Bool(b) => !*b, + // ScalarValue::Primitive(p) => p.is_zero(), + // ScalarValue::Decimal(d) => d.is_zero(), + // ScalarValue::Utf8(s) => s.is_empty(), + // ScalarValue::Binary(b) => b.is_empty(), + // ScalarValue::List(elems) => elems.is_empty(), + // } + // } + + /// Returns the zero / identity value for the given [`DType`]. + /// + /// # Zero Values + /// + /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable): + /// + /// - `Null`: Does not have a "zero" value + /// - `Bool`: `false` + /// - `Primitive`: `0` + /// - `Decimal`: `0` + /// - `Utf8`: `""` + /// - `Binary`: An empty buffer + /// - `List`: An empty list + /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the + /// element [`DType`] + /// - `Struct`: A struct where each field has a zero value, which is determined by the field + /// [`DType`] + /// + /// - `Extension`: TODO(connor): Is this right? + /// The zero value of the storage [`DType`] + pub fn zero_value(dtype: &DType) -> Self { + match dtype { + DType::Null => vortex_panic!("Null dtype has no zero value"), + DType::Bool(_) => Self::Bool(false), + DType::Primitive(ptype, _) => Self::Primitive(PValue::zero(ptype)), + DType::Decimal(dt, ..) => Self::Decimal(DecimalValue::zero(dt)), + DType::Utf8(_) => Self::Utf8(BufferString::empty()), + DType::Binary(_) => Self::Binary(ByteBuffer::empty()), + DType::List(..) => Self::List(vec![]), + DType::FixedSizeList(edt, size, _) => { + let elements = (0..*size).map(|_| Some(Self::zero_value(edt))).collect(); + Self::List(elements) } - Self::BufferString(bufstr) => { - let bufstr = bufstr.as_str(); - let str_len = bufstr.chars().count(); + DType::Struct(fields, _) => { + let field_values = fields + .fields() + .map(|f| Some(Self::zero_value(&f))) + .collect(); + Self::List(field_values) + } + DType::Extension(ext_dtype) => Self::zero_value(ext_dtype.storage_dtype()), // TODO(connor): Fix this! + } + } - if str_len > 10 { - let prefix = String::from_iter(bufstr.chars().take(5)); - let suffix = String::from_iter(bufstr.chars().skip(str_len - 5)); + /// A similar function to [`ScalarValue::zero_value`], but for nullable [`DType`]s, this returns + /// `None` instead. + /// + /// For non-nullable and nested types that may need null values in their children (as of right + /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide `None` as the + /// default child values (whereas [`ScalarValue::zero_value`] would provide `Some(_)`). + pub fn default_value(dtype: &DType) -> Option { + if dtype.is_nullable() { + return None; + } - write!(f, "\"{prefix}..{suffix}\"") - } else { - write!(f, "\"{bufstr}\"") - } + Some(match dtype { + DType::Null => vortex_panic!("Null dtype has no zero value"), + DType::Bool(_) => Self::Bool(false), + DType::Primitive(ptype, _) => Self::Primitive(PValue::zero(ptype)), + DType::Decimal(dt, ..) => Self::Decimal(DecimalValue::zero(dt)), + DType::Utf8(_) => Self::Utf8(BufferString::empty()), + DType::Binary(_) => Self::Binary(ByteBuffer::empty()), + DType::List(..) 
=> Self::List(vec![]), + DType::FixedSizeList(edt, size, _) => { + let elements = (0..*size).map(|_| Self::default_value(edt)).collect(); + Self::List(elements) } - Self::List(elems) => { - write!(f, "[{}]", elems.iter().format(",")) + DType::Struct(fields, _) => { + let field_values = fields.fields().map(|f| Self::default_value(&f)).collect(); + Self::List(field_values) } - Self::Null => write!(f, "null"), - } + DType::Extension(ext_dtype) => Self::default_value(ext_dtype.storage_dtype())?, // TODO(connor): Fix this! + }) } } impl ScalarValue { - /// Creates a null scalar value. - pub const fn null() -> Self { - ScalarValue(InnerScalarValue::Null) - } - - /// Returns true if this is a null value. - #[inline] - pub fn is_null(&self) -> bool { - self.0.is_null() - } - - /// Returns scalar as a null value - #[inline] - pub(crate) fn as_null(&self) -> VortexResult<()> { - self.0.as_null() - } - - /// Returns scalar as a boolean value - #[inline] - pub(crate) fn as_bool(&self) -> VortexResult> { - self.0.as_bool() - } - - /// Return scalar as a primitive value. PValues don't match dtypes but will be castable to the scalars dtype - #[inline] - pub(crate) fn as_pvalue(&self) -> VortexResult> { - self.0.as_pvalue() - } - - /// Returns scalar as a decimal value - #[inline] - pub(crate) fn as_decimal(&self) -> VortexResult> { - self.0.as_decimal() - } - - /// Returns scalar as a binary buffer - #[inline] - pub(crate) fn as_buffer(&self) -> VortexResult>> { - self.0.as_buffer() - } - - /// Returns scalar as a string buffer - #[inline] - pub(crate) fn as_buffer_string(&self) -> VortexResult>> { - self.0.as_buffer_string() + /// Returns the boolean value, panicking if the value is not a [`Bool`][ScalarValue::Bool]. + pub fn as_bool(&self) -> bool { + match self { + ScalarValue::Bool(b) => *b, + _ => vortex_panic!("ScalarValue is not a Bool"), + } } - /// Returns scalar as a list value - #[inline] - pub(crate) fn as_list(&self) -> VortexResult>> { - self.0.as_list() - } -} - -impl InnerScalarValue { - #[inline] - pub(crate) fn is_null(&self) -> bool { - matches!(self, InnerScalarValue::Null) + /// Returns the primitive value, panicking if the value is not a + /// [`Primitive`][ScalarValue::Primitive]. + pub fn as_primitive(&self) -> &PValue { + match self { + ScalarValue::Primitive(p) => p, + _ => vortex_panic!("ScalarValue is not a Primitive"), + } } - #[inline] - pub(crate) fn as_null(&self) -> VortexResult<()> { - if matches!(self, InnerScalarValue::Null) { - Ok(()) - } else { - Err(vortex_err!("Expected a Null scalar, found {self}")) + /// Returns the decimal value, panicking if the value is not a + /// [`Decimal`][ScalarValue::Decimal]. + pub fn as_decimal(&self) -> &DecimalValue { + match self { + ScalarValue::Decimal(d) => d, + _ => vortex_panic!("ScalarValue is not a Decimal"), } } - #[inline] - pub(crate) fn as_bool(&self) -> VortexResult> { + /// Returns the UTF-8 string value, panicking if the value is not a [`Utf8`][ScalarValue::Utf8]. + pub fn as_utf8(&self) -> &BufferString { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Bool(b) => Ok(Some(*b)), - other => Err(vortex_err!("Expected a bool scalar, found {other}",)), + ScalarValue::Utf8(s) => s, + _ => vortex_panic!("ScalarValue is not a Utf8"), } } - /// FIXME(ngates): PValues are such a footgun... we should probably remove this. - /// But the other accessors can sometimes be useful? e.g. as_buffer. But maybe we just force - /// the user to switch over Utf8 and Binary and use the correct Scalar wrapper? 
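To make the nested behaviour of the two constructors concrete, a sketch contrasting them for a non-nullable fixed-size list whose elements are nullable, so `zero_value` produces non-null zero children while `default_value` produces null children:

```rust
use std::sync::Arc;
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::ScalarValue;

fn main() {
    let element = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable));
    let fsl = DType::FixedSizeList(element, 2, Nullability::NonNullable);

    // `zero_value` fills every child slot with a non-null zero...
    let zero = ScalarValue::zero_value(&fsl);
    assert!(matches!(&zero, ScalarValue::List(elems) if elems.iter().all(Option::is_some)));

    // ...whereas `default_value` leaves the nullable children as null.
    let default = ScalarValue::default_value(&fsl).expect("non-nullable dtype has a default");
    assert!(matches!(&default, ScalarValue::List(elems) if elems.iter().all(Option::is_none)));
}
```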
- #[inline] - pub(crate) fn as_pvalue(&self) -> VortexResult> { + /// Returns the binary value, panicking if the value is not a [`Binary`][ScalarValue::Binary]. + pub fn as_binary(&self) -> &ByteBuffer { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Primitive(pvalue) => Ok(Some(*pvalue)), - other => Err(vortex_err!("Expected a primitive scalar, found {other}")), + ScalarValue::Binary(b) => b, + _ => vortex_panic!("ScalarValue is not a Binary"), } } - #[inline] - pub(crate) fn as_decimal(&self) -> VortexResult> { + /// Returns the list elements, panicking if the value is not a [`List`][ScalarValue::List]. + pub fn as_list(&self) -> &[Option] { match self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Decimal(v) => Ok(Some(*v)), - InnerScalarValue::Buffer(b) => Ok(Some(match b.len() { - 1 => DecimalValue::I8(b[0] as i8), - 2 => DecimalValue::I16(i16::from_le_bytes(b.as_slice().try_into()?)), - 4 => DecimalValue::I32(i32::from_le_bytes(b.as_slice().try_into()?)), - 8 => DecimalValue::I64(i64::from_le_bytes(b.as_slice().try_into()?)), - 16 => DecimalValue::I128(i128::from_le_bytes(b.as_slice().try_into()?)), - 32 => DecimalValue::I256(i256::from_le_bytes(b.as_slice().try_into()?)), - l => vortex_bail!("Buffer is not a decimal value length {l}"), - })), - _ => vortex_bail!("Expected a decimal scalar, found {:?}", self), + ScalarValue::List(elements) => elements, + _ => vortex_panic!("ScalarValue is not a List"), } } - #[inline] - pub(crate) fn as_buffer(&self) -> VortexResult>> { - match &self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Buffer(b) => Ok(Some(b.clone())), - InnerScalarValue::BufferString(b) => { - Ok(Some(Arc::new(b.as_ref().clone().into_inner()))) - } - _ => Err(vortex_err!("Expected a binary scalar, found {:?}", self)), + // pub fn as_extension(&self) -> &ExtScalarRef { + // match self { + // ScalarValue::Extension(e) => e, + // _ => vortex_panic!("ScalarValue is not an Extension"), + // } + // } +} + +impl PartialOrd for ScalarValue { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (ScalarValue::Bool(a), ScalarValue::Bool(b)) => a.partial_cmp(b), + (ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a.partial_cmp(b), + (ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a.partial_cmp(b), + (ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a.partial_cmp(b), + (ScalarValue::Binary(a), ScalarValue::Binary(b)) => a.partial_cmp(b), + (ScalarValue::List(a), ScalarValue::List(b)) => a.partial_cmp(b), + // (ScalarValue::Extension(a), ScalarValue::Extension(b)) => a.partial_cmp(b), + _ => None, } } +} + +impl Display for ScalarValue { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + ScalarValue::Bool(b) => write!(f, "{}", b), + ScalarValue::Primitive(p) => write!(f, "{}", p), + ScalarValue::Decimal(d) => write!(f, "{}", d), + ScalarValue::Utf8(s) => { + let bufstr = s.as_str(); + let str_len = bufstr.chars().count(); - #[inline] - pub(crate) fn as_buffer_string(&self) -> VortexResult>> { - match &self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::Buffer(b) => { - Ok(Some(Arc::new(BufferString::try_from(b.as_ref().clone())?))) + if str_len > 10 { + let prefix = String::from_iter(bufstr.chars().take(5)); + let suffix = String::from_iter(bufstr.chars().skip(str_len - 5)); + + write!(f, "\"{prefix}..{suffix}\"") + } else { + write!(f, "\"{bufstr}\"") + } + } + ScalarValue::Binary(b) => { + if b.len() > 10 { + write!( + f, + "{}..{}", + to_hex(&b[0..5]), + 
to_hex(&b[b.len() - 5..b.len()]), + ) + } else { + write!(f, "{}", to_hex(b)) + } } - InnerScalarValue::BufferString(b) => Ok(Some(b.clone())), - _ => Err(vortex_err!("Expected a string scalar, found {:?}", self)), + ScalarValue::List(elements) => { + write!(f, "[")?; + for (i, element) in elements.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + match element { + None => write!(f, "null")?, + Some(e) => write!(f, "{}", e)?, + } + } + write!(f, "]") + } // + // ScalarValue::Extension(e) => write!(f, "{}", e), } } +} - #[inline] - pub(crate) fn as_list(&self) -> VortexResult>> { - match &self { - InnerScalarValue::Null => Ok(None), - InnerScalarValue::List(l) => Ok(Some(l)), - _ => Err(vortex_err!("Expected a list scalar, found {:?}", self)), - } - } +/// Formats a byte slice as a hexadecimal string. +fn to_hex(slice: &[u8]) -> String { + slice + .iter() + .format_with("", |f, b| b(&format_args!("{f:02x}"))) + .to_string() } diff --git a/vortex-scalar/src/tests/casting.rs b/vortex-scalar/src/tests/casting.rs index 3d38699f426..910e52186b1 100644 --- a/vortex-scalar/src/tests/casting.rs +++ b/vortex-scalar/src/tests/casting.rs @@ -19,13 +19,13 @@ mod tests { use vortex_error::VortexExpect; use vortex_error::VortexResult; - use crate::InnerScalarValue; use crate::PValue; use crate::Scalar; use crate::ScalarValue; #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] struct Apples; + impl ExtDTypeVTable for Apples { type Metadata = usize; @@ -52,11 +52,16 @@ mod tests { #[test] fn cast_to_from_extension_types() { let apples = Apples::new(); + let ext_dtype = DType::Extension(apples.clone().erased()); - let ext_scalar = Scalar::new(ext_dtype.clone(), ScalarValue(InnerScalarValue::Bool(true))); + let ext_scalar = Scalar::new( + ext_dtype.clone(), + Some(ScalarValue::Primitive(PValue::U16(1000))), + ); + let storage_scalar = Scalar::new( DType::clone(apples.storage_dtype()), - ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))), + Some(ScalarValue::Primitive(PValue::U16(1000))), ); // to self @@ -89,15 +94,6 @@ mod tests { let actual = storage_scalar.cast(expected_dtype).unwrap(); assert_eq!(actual.dtype(), expected_dtype); - // cast from *compatible* storage type to extension - let storage_scalar_u64 = Scalar::new( - DType::clone(apples.storage_dtype()), - ScalarValue(InnerScalarValue::Primitive(PValue::U64(1000))), - ); - let expected_dtype = &ext_dtype; - let actual = storage_scalar_u64.cast(expected_dtype).unwrap(); - assert_eq!(actual.dtype(), expected_dtype); - // cast from *incompatible* storage type to extension let apples_u8 = ExtDType::::try_new(0, DType::Primitive(PType::U8, Nullability::NonNullable)) @@ -112,86 +108,6 @@ mod tests { ); } - #[test] - fn test_f16_coercion_from_u64() { - let f16_value = f16::from_f32(5.722046e-6); - let u64_bits = f16_value.to_bits() as u64; - - let scalar = Scalar::new( - DType::Primitive(PType::F16, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U64(u64_bits))), - ); - - assert_eq!( - scalar.as_primitive().pvalue().unwrap(), - PValue::F16(f16_value) - ); - } - - #[test] - fn test_f16_coercion_from_u32() { - let f16_value = f16::from_f32(0.42); - let u32_bits = f16_value.to_bits() as u32; - - let scalar = Scalar::new( - DType::Primitive(PType::F16, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U32(u32_bits))), - ); - - assert_eq!( - scalar.as_primitive().pvalue().unwrap(), - PValue::F16(f16_value) - ); - } - - #[test] - fn test_f16_coercion_from_u16() { - let f16_value = 
f16::from_f32(1.5); - let u16_bits = f16_value.to_bits(); - - let scalar = Scalar::new( - DType::Primitive(PType::F16, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U16(u16_bits))), - ); - - assert_eq!( - scalar.as_primitive().pvalue().unwrap(), - PValue::F16(f16_value) - ); - } - - #[test] - fn test_f32_coercion_from_u32() { - let f32_value = std::f32::consts::PI; - let u32_bits = f32_value.to_bits(); - - let scalar = Scalar::new( - DType::Primitive(PType::F32, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U32(u32_bits))), - ); - - assert_eq!( - scalar.as_primitive().pvalue().unwrap(), - PValue::F32(f32_value) - ); - } - - #[test] - fn test_f64_coercion_from_u64() { - let f64_value = std::f64::consts::E; - let u64_bits = f64_value.to_bits(); - - let scalar = Scalar::new( - DType::Primitive(PType::F64, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U64(u64_bits))), - ); - - assert_eq!( - scalar.as_primitive().pvalue().unwrap(), - PValue::F64(f64_value) - ); - } - #[test] fn test_struct_field_coercion() { let f16_value = f16::from_f32(0.42); @@ -216,20 +132,19 @@ mod tests { ); let field_values = vec![ - ScalarValue(InnerScalarValue::Primitive(PValue::U32(42))), - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - f16_value.to_bits() as u64, + Some(ScalarValue::Primitive(PValue::U32(42))), + Some(ScalarValue::Primitive(PValue::U64( + f16_value.to_bits() as u64 ))), - ScalarValue(InnerScalarValue::Primitive(PValue::F32(f32_value))), + Some(ScalarValue::Primitive(PValue::F32(f32_value))), ]; - let scalar = Scalar::new( - struct_dtype, - ScalarValue(InnerScalarValue::List(field_values.into())), - ); + let scalar = Scalar::new(struct_dtype, Some(ScalarValue::List(field_values))); let struct_scalar = scalar.as_struct(); - let fields = struct_scalar.fields().unwrap().collect::>(); + let fields: Vec<_> = (0..3) + .map(|i| struct_scalar.field_by_idx(i).unwrap()) + .collect(); // Check first field (no coercion needed) assert_eq!(fields[0].as_primitive().pvalue().unwrap(), PValue::U32(42)); @@ -249,11 +164,11 @@ mod tests { #[test] fn test_fake_coercion_for_matching_type() { - // Test that when types already match, no coercion happens + // Test that when types already match, no coercion happens. let i32_value = 42i32; let scalar = Scalar::new( DType::Primitive(PType::I32, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::I32(i32_value))), + Some(ScalarValue::Primitive(PValue::I32(i32_value))), ); assert_eq!( @@ -273,18 +188,15 @@ mod tests { ); let elements = vec![ - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - f16_value1.to_bits() as u64, + Some(ScalarValue::Primitive(PValue::U64( + f16_value1.to_bits() as u64 ))), - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - f16_value2.to_bits() as u64, + Some(ScalarValue::Primitive(PValue::U64( + f16_value2.to_bits() as u64 ))), ]; - let scalar = Scalar::new( - list_dtype, - ScalarValue(InnerScalarValue::List(elements.into())), - ); + let scalar = Scalar::new(list_dtype, Some(ScalarValue::List(elements))); let list_scalar = scalar.as_list(); let elements = list_scalar.elements().unwrap(); @@ -300,13 +212,13 @@ mod tests { #[test] #[should_panic] fn test_coercion_with_overflow_protection() { - // Test that values too large for target type are not coerced + // Test that values too large for target type are not coerced. 
let large_u64 = u64::MAX; - // This should NOT be coerced to F16 because it's too large + // This should NOT be coerced to F16 because it's too large. let scalar = Scalar::new( DType::Primitive(PType::F16, Nullability::NonNullable), - ScalarValue(InnerScalarValue::Primitive(PValue::U64(large_u64))), + Some(ScalarValue::Primitive(PValue::U64(large_u64))), ); let _ = scalar.as_primitive(); // Should panic @@ -342,7 +254,7 @@ mod tests { let scalar = Scalar::new( DType::Extension(ext_dtype.erased()), - ScalarValue(InnerScalarValue::Primitive(PValue::U64(u64_bits))), + Some(ScalarValue::Primitive(PValue::U64(u64_bits))), ); // Verify the value was coerced to f16 @@ -396,15 +308,15 @@ mod tests { // Create struct value with f16 stored as u64 let f16_value = f16::from_f32(1.5); let field_values = vec![ - ScalarValue(InnerScalarValue::Primitive(PValue::U32(123))), - ScalarValue(InnerScalarValue::Primitive(PValue::U64( - f16_value.to_bits() as u64, + Some(ScalarValue::Primitive(PValue::U32(123))), + Some(ScalarValue::Primitive(PValue::U64( + f16_value.to_bits() as u64 ))), ]; let scalar = Scalar::new( DType::Extension(ext_dtype.erased()), - ScalarValue(InnerScalarValue::List(field_values.into())), + Some(ScalarValue::List(field_values)), ); // Verify the struct field was coerced @@ -412,7 +324,7 @@ mod tests { .as_extension() .storage() .as_struct() - .fields() + .fields_iter() .vortex_expect("non null") .collect::>(); assert_eq!( diff --git a/vortex-scalar/src/tests/consistency.rs b/vortex-scalar/src/tests/consistency.rs index 391d585b55b..af52ceb2f49 100644 --- a/vortex-scalar/src/tests/consistency.rs +++ b/vortex-scalar/src/tests/consistency.rs @@ -8,19 +8,17 @@ mod tests { use vortex_dtype::Nullability; - use crate::BoolScalar; - use crate::PrimitiveScalar; use crate::Scalar; // Demonstrates inconsistent null comparison behavior #[test] fn test_null_comparison_inconsistency() { // Test with primitive scalars - let null_i32 = Scalar::null_typed::(); - let null_i64 = Scalar::null_typed::(); + let null_i32 = Scalar::null_native::(); + let null_i64 = Scalar::null_native::(); - let prim_i32 = PrimitiveScalar::try_from(&null_i32).unwrap(); - let prim_i64 = PrimitiveScalar::try_from(&null_i64).unwrap(); + let prim_i32 = null_i32.as_primitive(); + let prim_i64 = null_i64.as_primitive(); // Primitive scalars check dtype compatibility first assert_eq!(prim_i32.partial_cmp(&prim_i64), None); // Different types => None @@ -30,30 +28,14 @@ mod tests { let bool_nullable = Scalar::bool(true, Nullability::Nullable); let bool_non_nullable = Scalar::bool(true, Nullability::NonNullable); - let bool1 = BoolScalar::try_from(&bool_nullable).unwrap(); - let bool2 = BoolScalar::try_from(&bool_non_nullable).unwrap(); + let bool1 = bool_nullable.as_bool(); + let bool2 = bool_non_nullable.as_bool(); // Bool scalars should now check dtype compatibility but ignore nullability // So they should still compare as they have the same base type assert!(bool1.partial_cmp(&bool2).is_some()); // Same base type, different nullability -> Some } - // Test that demonstrates potential issues with typed null conversions - #[test] - fn test_typed_null_unit_conversion_surprising() { - // This behavior is documented but potentially surprising - let typed_null = Scalar::null_typed::(); - - // A typed null (i32 null) successfully converts to unit type - let unit_result = <()>::try_from(&typed_null); - assert!(unit_result.is_ok()); // This might be unexpected! 
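The coercion tests removed above exercised the legacy path where `f16` scalars were stored as the `u64` of their bit pattern; the backwards-compatibility carve-out in `Scalar::is_compatible` keeps exactly that pairing valid, while other width mismatches are rejected. A small sketch (0x3E00 is the `f16` bit pattern for 1.5):

```rust
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::{PValue, Scalar, ScalarValue};

fn main() {
    let f16_dtype = DType::Primitive(PType::F16, Nullability::NonNullable);

    // Legacy encoding: the f16 bits widened into a u64 are still accepted.
    let legacy = ScalarValue::Primitive(PValue::U64(0x3E00));
    assert!(Scalar::is_compatible(&f16_dtype, Some(&legacy)));

    // Any other mismatched primitive, such as the same bits in a u32, is not.
    let wrong_width = ScalarValue::Primitive(PValue::U32(0x3E00));
    assert!(!Scalar::is_compatible(&f16_dtype, Some(&wrong_width)));
}
```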
- - // But a non-null value correctly fails - let non_null = Scalar::primitive(42i32, Nullability::NonNullable); - let unit_result = <()>::try_from(&non_null); - assert!(unit_result.is_err()); // Expected - } - // Demonstrates that equality checking doesn't always consider nullability #[test] fn test_nullability_in_equality() { diff --git a/vortex-scalar/src/tests/nested.rs b/vortex-scalar/src/tests/nested.rs index c3d12b12486..5585602e77d 100644 --- a/vortex-scalar/src/tests/nested.rs +++ b/vortex-scalar/src/tests/nested.rs @@ -11,12 +11,9 @@ mod tests { use vortex_dtype::Nullability; use vortex_dtype::PType; - use crate::InnerScalarValue; - use crate::ListScalar; use crate::PValue; use crate::Scalar; use crate::ScalarValue; - use crate::StructScalar; #[test] fn test_fixed_size_list_of_fixed_size_list() { @@ -217,7 +214,7 @@ mod tests { // Access struct fields through the list. let first_struct = list.element(0).unwrap(); - let first = StructScalar::try_from(&first_struct).unwrap(); + let first = first_struct.as_struct(); assert_eq!( first .field("a") @@ -227,7 +224,13 @@ mod tests { Some(100) ); assert_eq!( - first.field("b").unwrap().as_utf8().value().unwrap(), + first + .field("b") + .unwrap() + .as_utf8() + .value() + .cloned() + .unwrap(), "first".into() ); } @@ -309,10 +312,10 @@ mod tests { Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List(Arc::from([ - ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))), - ScalarValue(InnerScalarValue::Primitive(PValue::U16(100))), - ]))), + Some(ScalarValue::List(vec![ + Some(ScalarValue::Primitive(PValue::U16(6))), + Some(ScalarValue::Primitive(PValue::U16(100))), + ])), ); // Cast U16 -> U32. @@ -348,11 +351,11 @@ mod tests { Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List(Arc::from([ - ScalarValue(InnerScalarValue::Primitive(PValue::U16(100))), - ScalarValue(InnerScalarValue::Primitive(PValue::U16(256))), // Too large for U8 - ScalarValue(InnerScalarValue::Primitive(PValue::U16(1000))), // Too large for U8 - ]))), + Some(ScalarValue::List(vec![ + Some(ScalarValue::Primitive(PValue::U16(100))), + Some(ScalarValue::Primitive(PValue::U16(256))), // Too large for U8 + Some(ScalarValue::Primitive(PValue::U16(1000))), // Too large for U8 + ])), ); let target_u8 = DType::List( @@ -901,7 +904,7 @@ mod tests { assert!(matches!(fixed_list.dtype(), DType::FixedSizeList(_, 3, _))); - let list = ListScalar::try_from(&fixed_list).unwrap(); + let list = fixed_list.as_list(); assert_eq!(list.len(), 3); assert!(!list.is_null()); @@ -932,7 +935,7 @@ mod tests { DType::FixedSizeList(_, 0, _) )); - let list = ListScalar::try_from(&empty_fixed_list).unwrap(); + let list = empty_fixed_list.as_list(); assert_eq!(list.len(), 0); assert!(list.is_empty()); assert!(!list.is_null()); diff --git a/vortex-scalar/src/tests/nullability.rs b/vortex-scalar/src/tests/nullability.rs index 385b14d54bf..8bb18ee6435 100644 --- a/vortex-scalar/src/tests/nullability.rs +++ b/vortex-scalar/src/tests/nullability.rs @@ -15,7 +15,6 @@ mod tests { use vortex_dtype::datetime::TimeUnit; use vortex_dtype::datetime::Timestamp; - use crate::InnerScalarValue; use crate::PValue; use crate::Scalar; use crate::ScalarValue; @@ -55,9 +54,9 @@ mod tests { Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List(Arc::from([ScalarValue( - 
InnerScalarValue::Primitive(PValue::U16(6)), - )]))), + Some(ScalarValue::List(vec![Some(ScalarValue::Primitive( + PValue::U16(6), + ))])), ); // Change element nullability from Nullable to NonNullable. @@ -98,11 +97,11 @@ mod tests { Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List(Arc::from([ - ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))), - ScalarValue(InnerScalarValue::Null), - ScalarValue(InnerScalarValue::Primitive(PValue::U16(10))), - ]))), + Some(ScalarValue::List(vec![ + Some(ScalarValue::Primitive(PValue::U16(6))), + None, + Some(ScalarValue::Primitive(PValue::U16(10))), + ])), ); // Cast to different element type with nullable elements - should succeed. @@ -201,10 +200,10 @@ mod tests { Arc::from(DType::Primitive(PType::U16, Nullability::Nullable)), Nullability::Nullable, ), - ScalarValue(InnerScalarValue::List(Arc::from([ - ScalarValue(InnerScalarValue::Primitive(PValue::U16(6))), - ScalarValue(InnerScalarValue::Null), - ]))), + Some(ScalarValue::List(vec![ + Some(ScalarValue::Primitive(PValue::U16(6))), + None, + ])), ); // Casting to non-nullable element type should fail. @@ -427,7 +426,6 @@ mod tests { let result = fixed_list_with_nulls.cast(&target_nonnull_elems); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("non-nullable")); // Null FixedSizeList can't cast to non-nullable container. let null_fixed_list = Scalar::null(DType::FixedSizeList( @@ -528,7 +526,7 @@ mod tests { 3, Nullability::Nullable, ); - let default_nullable_list = Scalar::default_value(nullable_fixed_list_dtype.clone()); + let default_nullable_list = Scalar::default_value(&nullable_fixed_list_dtype); assert!(default_nullable_list.is_null()); assert_eq!(default_nullable_list.dtype(), &nullable_fixed_list_dtype); @@ -538,7 +536,7 @@ mod tests { 2, Nullability::NonNullable, ); - let default_nonnull_list = Scalar::default_value(nonnull_fixed_list_dtype); + let default_nonnull_list = Scalar::default_value(&nonnull_fixed_list_dtype); assert!(!default_nonnull_list.is_null()); assert_eq!(default_nonnull_list.as_list().len(), 2); // Elements should be default values (0 for I32). 
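Several of these tests rebuild a scalar from a dtype plus an optional value; with the new API that round trip can also go through `into_parts` and `try_new`. A sketch, assuming `try_new` accepts exactly what `into_parts` produced:

```rust
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

fn main() {
    let scalar = Scalar::primitive(42i32, Nullability::NonNullable);

    // `into_parts` hands back the (DType, Option<ScalarValue>) pair that `try_new` consumes.
    let (dtype, value) = scalar.clone().into_parts();
    let rebuilt = Scalar::try_new(dtype, value).unwrap();

    assert_eq!(scalar, rebuilt);
}
```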
@@ -567,7 +565,7 @@ mod tests { ], Nullability::NonNullable, ); - let default_struct = Scalar::default_value(struct_dtype); + let default_struct = Scalar::default_value(&struct_dtype); let struct_view = default_struct.as_struct(); assert_eq!( struct_view diff --git a/vortex-scalar/src/tests/primitives.rs b/vortex-scalar/src/tests/primitives.rs index 382327dd027..4ef99246286 100644 --- a/vortex-scalar/src/tests/primitives.rs +++ b/vortex-scalar/src/tests/primitives.rs @@ -9,6 +9,7 @@ mod tests { use vortex_buffer::ByteBuffer; use vortex_dtype::DType; + use vortex_dtype::DecimalDType; use vortex_dtype::NativeDecimalType; use vortex_dtype::Nullability; use vortex_dtype::PType; @@ -16,8 +17,10 @@ mod tests { use vortex_dtype::datetime::TimeUnit; use vortex_utils::aliases::hash_set::HashSet; - use crate::InnerScalarValue; + use crate::DecimalScalar; + use crate::DecimalValue; use crate::PValue; + use crate::PrimitiveScalar; use crate::Scalar; use crate::ScalarValue; @@ -38,7 +41,7 @@ mod tests { Nullability::NonNullable, ); - let scalar = Scalar::default_value(struct_dtype.clone()); + let scalar = Scalar::default_value(&struct_dtype); assert_eq!(scalar.dtype(), &struct_dtype); let scalar = scalar.as_struct(); @@ -134,10 +137,6 @@ mod tests { #[test] fn test_decimal_nbytes() { - use vortex_dtype::DecimalDType; - - use crate::decimal::DecimalValue; - // Test decimal with precision <= 38 (should use i128 = 16 bytes) let decimal_low_precision = Scalar::decimal( DecimalValue::I128(123456789), @@ -284,7 +283,7 @@ mod tests { DType::Primitive(PType::I32, Nullability::NonNullable) ); match value { - ScalarValue(InnerScalarValue::Primitive(PValue::I32(v))) => { + Some(ScalarValue::Primitive(PValue::I32(v))) => { assert_eq!(v, 42); } _ => panic!("Expected I32 primitive value"), @@ -297,7 +296,7 @@ mod tests { let value = scalar.into_value(); match value { - ScalarValue(InnerScalarValue::Primitive(PValue::I32(v))) => { + Some(ScalarValue::Primitive(PValue::I32(v))) => { assert_eq!(v, 42); } _ => panic!("Expected I32 primitive value"), @@ -315,13 +314,6 @@ mod tests { assert!(null_scalar.is_null()); } - #[test] - fn test_scalar_as_ref() { - let scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let scalar_ref: &Scalar = scalar.as_ref(); - assert_eq!(scalar_ref, &scalar); - } - #[test] fn test_scalar_from_option() { // Test Some value @@ -346,11 +338,9 @@ mod tests { #[test] fn test_scalar_from_primitive_scalar() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let pscalar = crate::PrimitiveScalar::try_new( - &dtype, - &ScalarValue(InnerScalarValue::Primitive(PValue::I32(42))), - ) - .unwrap(); + let pscalar = + PrimitiveScalar::try_new(&dtype, Some(&ScalarValue::Primitive(PValue::I32(42)))) + .unwrap(); let scalar = Scalar::from(pscalar); assert_eq!(scalar.dtype(), &dtype); @@ -359,16 +349,11 @@ mod tests { #[test] fn test_scalar_from_decimal_scalar() { - use vortex_dtype::DecimalDType; - - use crate::decimal::DecimalScalar; - use crate::decimal::DecimalValue; - let decimal_dtype = DecimalDType::new(10, 2); let dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable); let dscalar = DecimalScalar::try_new( &dtype, - &ScalarValue(InnerScalarValue::Decimal(DecimalValue::I32(12345))), + Some(&ScalarValue::Decimal(DecimalValue::I32(12345))), ) .unwrap(); diff --git a/vortex-scalar/src/tests/round_trip.rs b/vortex-scalar/src/tests/round_trip.rs index 07c2fe5d260..69c28ba0220 100644 --- a/vortex-scalar/src/tests/round_trip.rs +++ b/vortex-scalar/src/tests/round_trip.rs @@ 
-19,14 +19,8 @@ mod tests { use vortex_dtype::i256; use vortex_proto::scalar as pb; - use crate::BinaryScalar; - use crate::BoolScalar; - use crate::DecimalScalar; use crate::DecimalValue; - use crate::ListScalar; - use crate::PrimitiveScalar; use crate::Scalar; - use crate::Utf8Scalar; use crate::tests::SESSION; // Test that primitive scalars round-trip through ScalarValue @@ -46,7 +40,7 @@ mod tests { ]; for scalar in values { - let value = scalar.value().clone(); + let value = scalar.value().cloned(); let dtype = scalar.dtype().clone(); let reconstructed = Scalar::new(dtype, value); assert_eq!(scalar, reconstructed); @@ -57,24 +51,24 @@ mod tests { #[test] fn test_null_scalar_type_preservation() { let null_scalars = vec![ - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), - Scalar::null_typed::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), + Scalar::null_native::(), ]; for scalar in null_scalars { assert!(scalar.is_null()); let dtype = scalar.dtype().clone(); - let value = scalar.value().clone(); + let value = scalar.value().cloned(); let reconstructed = Scalar::new(dtype.clone(), value); assert_eq!(scalar, reconstructed); assert_eq!(scalar.dtype(), reconstructed.dtype()); @@ -86,24 +80,24 @@ mod tests { fn test_specialized_scalar_conversions() { // Test PrimitiveScalar let int_scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let primitive_scalar = PrimitiveScalar::try_from(&int_scalar).unwrap(); + let primitive_scalar = int_scalar.as_primitive(); assert_eq!(primitive_scalar.typed_value::().unwrap(), 42); let reconstructed = Scalar::from(primitive_scalar); assert_eq!(int_scalar, reconstructed); // Test BoolScalar let bool_scalar = Scalar::bool(true, Nullability::NonNullable); - let bool_specialized = BoolScalar::try_from(&bool_scalar).unwrap(); + let bool_specialized = bool_scalar.as_bool(); assert!(bool_specialized.value().unwrap()); // Test Utf8Scalar let utf8_scalar = Scalar::utf8("hello".to_string(), Nullability::NonNullable); - let utf8_specialized = Utf8Scalar::try_from(&utf8_scalar).unwrap(); + let utf8_specialized = utf8_scalar.as_utf8(); assert_eq!(utf8_specialized.value().unwrap().as_str(), "hello"); // Test BinaryScalar let binary_scalar = Scalar::binary(vec![1, 2, 3, 4], Nullability::NonNullable); - let binary_specialized = BinaryScalar::try_from(&binary_scalar).unwrap(); + let binary_specialized = binary_scalar.as_binary(); assert_eq!( binary_specialized.value().unwrap().as_slice(), &[1, 2, 3, 4] @@ -163,7 +157,7 @@ mod tests { let list_scalar = Scalar::list(element_dtype, children.clone(), Nullability::NonNullable); // Extract as ListScalar - let list_specialized = ListScalar::try_from(&list_scalar).unwrap(); + let list_specialized = list_scalar.as_list(); assert_eq!(list_specialized.len(), 3); // Extract as Vec @@ -182,19 +176,19 @@ mod tests { fn test_decimal_scalar_round_trip() { let decimal_dtype = DecimalDType::new(10, 2); - // Test various decimal value types + // Test various decimal value types. 
let decimal_values = vec![ DecimalValue::I8(100), DecimalValue::I16(10000), DecimalValue::I32(1000000), - DecimalValue::I64(100000000000), - DecimalValue::I128(123456789012345678901234567890i128), - DecimalValue::I256(i256::from_i128(987654321098765432109876543210i128)), + DecimalValue::I64(10000000), + DecimalValue::I128(100000000), + DecimalValue::I256(i256::from_i128(1000000000)), ]; for value in decimal_values { let scalar = Scalar::decimal(value, decimal_dtype, Nullability::NonNullable); - let decimal_specialized = DecimalScalar::try_from(&scalar).unwrap(); + let decimal_specialized = scalar.as_decimal(); match decimal_specialized.decimal_value() { Some(extracted) => assert_eq!(extracted, value), @@ -202,7 +196,7 @@ mod tests { } // Test round-trip through ScalarValue - let scalar_value = scalar.value().clone(); + let scalar_value = scalar.value().cloned(); let dtype = scalar.dtype().clone(); let reconstructed = Scalar::new(dtype, scalar_value); assert_eq!(scalar, reconstructed); @@ -290,14 +284,12 @@ mod tests { let result: Result = i32::try_from(&string_scalar); assert!(result.is_err()); - // Try to convert an integer scalar to a list + // Try to convert an integer scalar to a list. let int_scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let result = ListScalar::try_from(&int_scalar); - assert!(result.is_err()); + assert!(int_scalar.as_list_opt().is_none()); - // Try to convert a boolean to a decimal + // Try to convert a boolean to a decimal. let bool_scalar = Scalar::bool(true, Nullability::NonNullable); - let result = DecimalScalar::try_from(&bool_scalar); - assert!(result.is_err()); + assert!(bool_scalar.as_decimal_opt().is_none()); } } diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/typed_view/binary.rs similarity index 54% rename from vortex-scalar/src/binary.rs rename to vortex-scalar/src/typed_view/binary.rs index 79d8b352ca7..b44a0621808 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/typed_view/binary.rs @@ -1,21 +1,18 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`BinaryScalar`] typed view implementation. + use std::fmt::Display; use std::fmt::Formatter; -use std::sync::Arc; use itertools::Itertools; use vortex_buffer::ByteBuffer; use vortex_dtype::DType; -use vortex_dtype::Nullability; -use vortex_error::VortexError; -use vortex_error::VortexExpect as _; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -25,8 +22,10 @@ use crate::ScalarValue; /// a valid byte buffer or null. #[derive(Debug, Clone, Hash)] pub struct BinaryScalar<'a> { + /// The data type of this scalar. dtype: &'a DType, - value: Option>, + /// The binary value, or [`None`] if null. + value: Option<&'a ByteBuffer>, } impl Display for BinaryScalar<'_> { @@ -68,13 +67,14 @@ impl<'a> BinaryScalar<'a> { /// # Errors /// /// Returns an error if the data type is not a binary type. 
- pub fn from_scalar_value(dtype: &'a DType, value: ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Binary(..)) { vortex_bail!("Can only construct binary scalar from binary dtype, found {dtype}") } + Ok(Self { dtype, - value: value.as_buffer()?, + value: value.map(|value| value.as_binary()), }) } @@ -84,83 +84,59 @@ impl<'a> BinaryScalar<'a> { self.dtype } - /// Returns the binary value as a byte buffer, or None if null. - pub fn value(&self) -> Option { - self.value.as_ref().map(|v| v.as_ref().clone()) - } - /// Returns a reference to the binary value, or None if null. /// This avoids cloning the underlying ByteBuffer. - pub fn value_ref(&self) -> Option<&ByteBuffer> { - self.value.as_ref().map(|v| v.as_ref()) + pub fn value(&self) -> Option<&'a ByteBuffer> { + self.value } - /// Constructs the next scalar at most `max_length` bytes that's lexicographically greater than - /// this. + /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater + /// than this. /// - /// Returns None if constructing a greater value would overflow. - pub fn upper_bound(self, max_length: usize) -> Option { - if let Some(value) = self.value { - if value.len() > max_length { - let sliced = value.slice(0..max_length); - drop(value); - let mut sliced_mut = sliced.into_mut(); - for b in sliced_mut.iter_mut().rev() { - let (incr, overflow) = b.overflowing_add(1); - *b = incr; - if !overflow { - return Some(Self { - dtype: self.dtype, - value: Some(Arc::new(sliced_mut.freeze())), - }); - } - } - None - } else { - Some(Self { - dtype: self.dtype, - value: Some(value), - }) + /// Returns `None` if the value is null or if constructing a greater value would overflow. + pub fn upper_bound(&self, max_length: usize) -> Option { + let value = self.value()?; + let sliced = value.slice(0..max_length); + let mut sliced_mut = sliced.into_mut(); + for b in sliced_mut.iter_mut().rev() { + let (incr, overflow) = b.overflowing_add(1); + *b = incr; + if !overflow { + return Some(Scalar::binary( + sliced_mut.freeze(), + self.dtype().nullability(), + )); } - } else { - Some(self) } + None } - /// Construct a value at most `max_length` in size that's less than ourselves. - pub fn lower_bound(self, max_length: usize) -> Self { - if let Some(value) = self.value { - if value.len() > max_length { - Self { - dtype: self.dtype, - value: Some(Arc::new(value.slice(0..max_length))), - } - } else { - Self { - dtype: self.dtype, - value: Some(value), - } - } - } else { - self + /// Construct a [`Scalar`] at most `max_length` in size that's less than or equal to + /// ourselves. + /// + /// Returns a null [`Scalar`] if the value is null. + pub fn lower_bound(&self, max_length: usize) -> Scalar { + match self.value() { + Some(value) => Scalar::binary(value.slice(0..max_length), self.dtype().nullability()), + None => Scalar::null(self.dtype().clone()), } } + /// Casts this scalar to the given `dtype`. 
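// [Editorial aside, not part of the patch] A sketch of how the reworked bound helpers above
// behave, mirroring the `lower_bound`/`upper_bound` tests further down in this file
// (`buffer!` is the `vortex_buffer` macro those tests use):
//
//     let binary = Scalar::binary(buffer![0u8, 5, 255, 234, 23], Nullability::NonNullable);
//
//     // upper_bound truncates to `max_length` bytes and increments with carry:
//     // [0, 5, 255] -> [0, 6, 0].
//     assert_eq!(
//         binary.as_binary().upper_bound(3).unwrap(),
//         Scalar::binary(buffer![0u8, 6, 0], Nullability::NonNullable),
//     );
//
//     // lower_bound simply truncates: [0, 5, 255, 234, 23] -> [0, 5].
//     assert_eq!(
//         binary.as_binary().lower_bound(2),
//         Scalar::binary(buffer![0u8, 5], Nullability::NonNullable),
//     );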
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Binary(..)) { vortex_bail!( "Cannot cast binary to {dtype}: binary scalars can only be cast to binary types with different nullability" ) } - Ok(Scalar::new( + Scalar::try_new( dtype.clone(), - ScalarValue(InnerScalarValue::Buffer( - self.value - .as_ref() - .vortex_expect("nullness handled in Scalar::cast") - .clone(), + Some(ScalarValue::Binary( + self.value() + .cloned() + .vortex_expect("nullness handled in Scalar::cast"), )), - )) + ) } /// Length of the scalar value or None if value is null @@ -174,104 +150,6 @@ impl<'a> BinaryScalar<'a> { } } -impl Scalar { - /// Creates a new binary scalar from a byte buffer. - pub fn binary(buffer: impl Into, nullability: Nullability) -> Self { - Self::new( - DType::Binary(nullability), - ScalarValue(InnerScalarValue::Buffer(Arc::new(buffer.into()))), - ) - } -} - -impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::Binary(_)) { - vortex_bail!("Expected binary scalar, found {}", value.dtype()) - } - Ok(Self { - dtype: value.dtype(), - value: value.value().as_buffer()?, - }) - } -} - -impl<'a> TryFrom<&'a Scalar> for ByteBuffer { - type Error = VortexError; - - fn try_from(scalar: &'a Scalar) -> VortexResult { - let binary = scalar - .as_binary_opt() - .ok_or_else(|| vortex_err!("Cannot extract buffer from non-buffer scalar"))?; - - binary - .value() - .ok_or_else(|| vortex_err!("Cannot extract present value from null scalar")) - } -} - -impl<'a> TryFrom<&'a Scalar> for Option { - type Error = VortexError; - - fn try_from(scalar: &'a Scalar) -> VortexResult { - Ok(scalar - .as_binary_opt() - .ok_or_else(|| vortex_err!("Cannot extract buffer from non-buffer scalar"))? 
- .value()) - } -} - -impl TryFrom for ByteBuffer { - type Error = VortexError; - - fn try_from(scalar: Scalar) -> VortexResult { - Self::try_from(&scalar) - } -} - -impl TryFrom for Option { - type Error = VortexError; - - fn try_from(scalar: Scalar) -> VortexResult { - Self::try_from(&scalar) - } -} - -impl From<&[u8]> for Scalar { - fn from(value: &[u8]) -> Self { - Scalar::from(ByteBuffer::from(value.to_vec())) - } -} - -impl From for Scalar { - fn from(value: ByteBuffer) -> Self { - Self::new(DType::Binary(Nullability::NonNullable), value.into()) - } -} - -impl From> for Scalar { - fn from(value: Arc) -> Self { - Self::new( - DType::Binary(Nullability::NonNullable), - ScalarValue(InnerScalarValue::Buffer(value)), - ) - } -} - -impl From<&[u8]> for ScalarValue { - fn from(value: &[u8]) -> Self { - ScalarValue::from(ByteBuffer::from(value.to_vec())) - } -} - -impl From for ScalarValue { - fn from(value: ByteBuffer) -> Self { - ScalarValue(InnerScalarValue::Buffer(Arc::new(value))) - } -} - #[cfg(test)] mod tests { use std::cmp::Ordering; @@ -279,47 +157,30 @@ mod tests { use rstest::rstest; use vortex_buffer::buffer; use vortex_dtype::Nullability; - use vortex_error::VortexExpect; use crate::BinaryScalar; + use crate::PValue; use crate::Scalar; + use crate::ScalarValue; #[test] fn lower_bound() { let binary = Scalar::binary(buffer![0u8, 5, 47, 33, 129], Nullability::NonNullable); let expected = Scalar::binary(buffer![0u8, 5], Nullability::NonNullable); - assert_eq!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .lower_bound(2), - BinaryScalar::try_from(&expected) - .vortex_expect("binary scalar conversion should succeed") - ); + assert_eq!(binary.as_binary().lower_bound(2), expected,); } #[test] fn upper_bound() { let binary = Scalar::binary(buffer![0u8, 5, 255, 234, 23], Nullability::NonNullable); let expected = Scalar::binary(buffer![0u8, 6, 0], Nullability::NonNullable); - assert_eq!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .upper_bound(3) - .vortex_expect("must have upper bound"), - BinaryScalar::try_from(&expected) - .vortex_expect("binary scalar conversion should succeed") - ); + assert_eq!(binary.as_binary().upper_bound(3).unwrap(), expected,); } #[test] fn upper_bound_overflow() { let binary = Scalar::binary(buffer![255u8, 255, 255], Nullability::NonNullable); - assert!( - BinaryScalar::try_from(&binary) - .vortex_expect("binary scalar conversion should succeed") - .upper_bound(2) - .is_none() - ); + assert!(binary.as_binary().upper_bound(2).is_none()); } #[rstest] @@ -335,8 +196,8 @@ mod tests { let binary1 = Scalar::binary(data1.to_vec(), Nullability::NonNullable); let binary2 = Scalar::binary(data2.to_vec(), Nullability::NonNullable); - let scalar1 = BinaryScalar::try_from(&binary1).unwrap(); - let scalar2 = BinaryScalar::try_from(&binary2).unwrap(); + let scalar1 = binary1.as_binary(); + let scalar2 = binary2.as_binary(); assert_eq!(scalar1 == scalar2, expected); } @@ -355,8 +216,8 @@ mod tests { let binary1 = Scalar::binary(data1.to_vec(), Nullability::NonNullable); let binary2 = Scalar::binary(data2.to_vec(), Nullability::NonNullable); - let scalar1 = BinaryScalar::try_from(&binary1).unwrap(); - let scalar2 = BinaryScalar::try_from(&binary2).unwrap(); + let scalar1 = binary1.as_binary(); + let scalar2 = binary2.as_binary(); assert_eq!(scalar1.partial_cmp(&scalar2), Some(expected)); } @@ -364,10 +225,10 @@ mod tests { #[test] fn test_binary_null_value() { let 
null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = BinaryScalar::try_from(&null_binary).unwrap(); + let scalar = null_binary.as_binary(); assert!(scalar.value().is_none()); - assert!(scalar.value_ref().is_none()); + assert!(scalar.value().is_none()); assert!(scalar.len().is_none()); assert!(scalar.is_empty().is_none()); } @@ -379,11 +240,11 @@ mod tests { let empty = Scalar::binary(ByteBuffer::empty(), Nullability::NonNullable); let non_empty = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let empty_scalar = BinaryScalar::try_from(&empty).unwrap(); + let empty_scalar = empty.as_binary(); assert_eq!(empty_scalar.len(), Some(0)); assert_eq!(empty_scalar.is_empty(), Some(true)); - let non_empty_scalar = BinaryScalar::try_from(&non_empty).unwrap(); + let non_empty_scalar = non_empty.as_binary(); assert_eq!(non_empty_scalar.len(), Some(3)); assert_eq!(non_empty_scalar.is_empty(), Some(false)); } @@ -394,13 +255,13 @@ mod tests { let data = vec![1u8, 2, 3, 4, 5]; let binary = Scalar::binary(ByteBuffer::from(data.clone()), Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); + let scalar = binary.as_binary(); // value_ref should not clone - let value_ref = scalar.value_ref().unwrap(); + let value_ref = scalar.value().unwrap(); assert_eq!(value_ref.as_slice(), &data); - // value should clone + // to_value should clone let value = scalar.value().unwrap(); assert_eq!(value.as_slice(), &data); } @@ -411,13 +272,13 @@ mod tests { use vortex_dtype::Nullability; let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); + let scalar = binary.as_binary(); // Cast to nullable binary let result = scalar.cast(&DType::Binary(Nullability::Nullable)).unwrap(); assert_eq!(result.dtype(), &DType::Binary(Nullability::Nullable)); - let casted = BinaryScalar::try_from(&result).unwrap(); + let casted = result.as_binary(); assert_eq!(casted.value().unwrap().as_slice(), &[1, 2, 3]); } @@ -428,7 +289,7 @@ mod tests { use vortex_dtype::PType; let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); + let scalar = binary.as_binary(); let result = scalar.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); assert!(result.is_err()); @@ -441,9 +302,9 @@ mod tests { use vortex_dtype::PType; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = crate::ScalarValue(crate::InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = BinaryScalar::from_scalar_value(&dtype, value); + let result = BinaryScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } @@ -452,47 +313,21 @@ mod tests { use vortex_dtype::Nullability; let scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let result = BinaryScalar::try_from(&scalar); - assert!(result.is_err()); + assert!(scalar.as_binary_opt().is_none()); } #[test] fn test_upper_bound_null() { let null_binary = Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = BinaryScalar::try_from(&null_binary).unwrap(); - - let result = scalar.upper_bound(10); - assert!(result.is_some()); - assert!(result.unwrap().value().is_none()); + let scalar = null_binary.as_binary(); + assert!(scalar.upper_bound(10).is_none()); } #[test] fn test_lower_bound_null() { let null_binary = 
Scalar::null(vortex_dtype::DType::Binary(Nullability::Nullable)); - let scalar = BinaryScalar::try_from(&null_binary).unwrap(); - - let result = scalar.lower_bound(10); - assert!(result.value().is_none()); - } - - #[test] - fn test_upper_bound_exact_length() { - let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); - - let result = scalar.upper_bound(3); - assert!(result.is_some()); - let upper = result.unwrap(); - assert_eq!(upper.value().unwrap().as_slice(), &[1, 2, 3]); - } - - #[test] - fn test_lower_bound_exact_length() { - let binary = Scalar::binary(buffer![1u8, 2, 3], Nullability::NonNullable); - let scalar = BinaryScalar::try_from(&binary).unwrap(); - - let result = scalar.lower_bound(3); - assert_eq!(result.value().unwrap().as_slice(), &[1, 2, 3]); + let scalar = null_binary.as_binary(); + assert!(scalar.lower_bound(10).is_null()); } #[test] @@ -504,7 +339,7 @@ mod tests { scalar.dtype(), &vortex_dtype::DType::Binary(Nullability::NonNullable) ); - let binary = BinaryScalar::try_from(&scalar).unwrap(); + let binary = scalar.as_binary(); assert_eq!(binary.value().unwrap().as_slice(), data); } @@ -558,29 +393,30 @@ mod tests { #[test] fn test_from_arc_bytebuffer() { - use std::sync::Arc; - use vortex_buffer::ByteBuffer; let data = vec![10u8, 20, 30]; - let buffer = Arc::new(ByteBuffer::from(data.clone())); + let buffer = ByteBuffer::from(data.clone()); let scalar: Scalar = buffer.into(); assert_eq!( scalar.dtype(), &vortex_dtype::DType::Binary(Nullability::NonNullable) ); - let binary = BinaryScalar::try_from(&scalar).unwrap(); + let binary = scalar.as_binary(); assert_eq!(binary.value().unwrap().as_slice(), &data); } #[test] fn test_scalar_value_from_slice() { let data: &[u8] = &[100u8, 200]; - let value: crate::ScalarValue = data.into(); + let value: ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Binary(Nullability::NonNullable), value); - let binary = BinaryScalar::try_from(&scalar).unwrap(); + let scalar = Scalar::new( + vortex_dtype::DType::Binary(Nullability::NonNullable), + Some(value), + ); + let binary = scalar.as_binary(); assert_eq!(binary.value().unwrap().as_slice(), data); } @@ -590,10 +426,13 @@ mod tests { let data = vec![111u8, 222]; let buffer = ByteBuffer::from(data.clone()); - let value: crate::ScalarValue = buffer.into(); + let value: ScalarValue = buffer.into(); - let scalar = Scalar::new(vortex_dtype::DType::Binary(Nullability::NonNullable), value); - let binary = BinaryScalar::try_from(&scalar).unwrap(); + let scalar = Scalar::new( + vortex_dtype::DType::Binary(Nullability::NonNullable), + Some(value), + ); + let binary = scalar.as_binary(); assert_eq!(binary.value().unwrap().as_slice(), &data); } } diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/typed_view/bool.rs similarity index 67% rename from vortex-scalar/src/bool.rs rename to vortex-scalar/src/typed_view/bool.rs index 0ef1acf2e9f..94201ae21ec 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/typed_view/bool.rs @@ -1,20 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`BoolScalar`] typed view implementation. 
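// [Editorial aside, not part of the patch] Typical use of the `BoolScalar` view defined in this
// module goes through the `as_bool` / `as_bool_opt` accessors this diff introduces on `Scalar`
// (see the updated tests below); roughly:
//
//     let scalar = Scalar::bool(true, Nullability::NonNullable);
//     assert_eq!(scalar.as_bool().value(), Some(true));
//
//     // Non-boolean scalars yield `None` from the fallible accessor.
//     let int_scalar = Scalar::primitive(42i32, Nullability::NonNullable);
//     assert!(int_scalar.as_bool_opt().is_none());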
+ use std::cmp::Ordering; use std::fmt::Display; use std::fmt::Formatter; use vortex_dtype::DType; -use vortex_dtype::Nullability; -use vortex_dtype::Nullability::NonNullable; -use vortex_error::VortexError; -use vortex_error::VortexExpect as _; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -24,7 +21,9 @@ use crate::ScalarValue; /// true, false, or null. #[derive(Debug, Clone, Hash, Eq)] pub struct BoolScalar<'a> { + /// The data type of this scalar. dtype: &'a DType, + /// The boolean value, or [`None`] if null. value: Option, } @@ -56,6 +55,21 @@ impl Ord for BoolScalar<'_> { } impl<'a> BoolScalar<'a> { + /// Attempts to create a new [`BoolScalar`] from a [`DType`] and optional [`ScalarValue`]. + /// + /// # Errors + /// + /// Returns an error if the data type is not a [`DType::Bool`]. + pub fn try_new(dtype: &'a DType, value: Option<&ScalarValue>) -> VortexResult { + if !matches!(dtype, DType::Bool(_)) { + vortex_bail!("Expected bool scalar, found {}", dtype) + } + Ok(Self { + dtype, + value: value.map(|v| v.as_bool()), + }) + } + /// Returns the data type of this boolean scalar. #[inline] pub fn dtype(&self) -> &'a DType { @@ -67,6 +81,7 @@ impl<'a> BoolScalar<'a> { self.value } + /// Casts this scalar to the given `dtype`. pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Bool(..)) { vortex_bail!( @@ -91,81 +106,8 @@ impl<'a> BoolScalar<'a> { /// Converts this boolean scalar into a general scalar. pub fn into_scalar(self) -> Scalar { - Scalar::new( - self.dtype.clone(), - self.value - .map(|x| ScalarValue(InnerScalarValue::Bool(x))) - .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)), - ) - } -} - -impl Scalar { - /// Creates a new boolean scalar with the given value and nullability. - pub fn bool(value: bool, nullability: Nullability) -> Self { - Self::new( - DType::Bool(nullability), - ScalarValue(InnerScalarValue::Bool(value)), - ) - } -} - -impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::Bool(_)) { - vortex_bail!("Expected bool scalar, found {}", value.dtype()) - } - Ok(Self { - dtype: value.dtype(), - value: value.value().as_bool()?, - }) - } -} - -impl TryFrom<&Scalar> for bool { - type Error = VortexError; - - fn try_from(value: &Scalar) -> VortexResult { - >::try_from(value)? - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) - } -} - -impl TryFrom<&Scalar> for Option { - type Error = VortexError; - - fn try_from(value: &Scalar) -> VortexResult { - Ok(BoolScalar::try_from(value)?.value()) - } -} - -impl TryFrom for bool { - type Error = VortexError; - - fn try_from(value: Scalar) -> VortexResult { - Self::try_from(&value) - } -} - -impl TryFrom for Option { - type Error = VortexError; - - fn try_from(value: Scalar) -> VortexResult { - Self::try_from(&value) - } -} - -impl From for Scalar { - fn from(value: bool) -> Self { - Self::new(DType::Bool(NonNullable), value.into()) - } -} - -impl From for ScalarValue { - fn from(value: bool) -> Self { - ScalarValue(InnerScalarValue::Bool(value)) + // SAFETY: `BoolScalar` is already a valid `Scalar`. 
+ unsafe { Scalar::new_unchecked(self.dtype.clone(), self.value.map(ScalarValue::Bool)) } } } @@ -177,29 +119,19 @@ mod test { #[test] fn into_from() { - let scalar: Scalar = false.into(); + let scalar: Scalar = Some(false).into(); assert!(!bool::try_from(&scalar).unwrap()); } - #[test] - fn equality() { - assert_eq!(&Scalar::bool(true, Nullable), &Scalar::bool(true, Nullable)); - // Equality ignores nullability - assert_eq!( - &Scalar::bool(true, Nullable), - &Scalar::bool(true, NonNullable) - ); - } - #[test] fn test_bool_scalar_ordering() { let false_scalar = Scalar::bool(false, NonNullable); let true_scalar = Scalar::bool(true, NonNullable); let null_scalar = Scalar::null(DType::Bool(Nullable)); - let false_bool = BoolScalar::try_from(&false_scalar).unwrap(); - let true_bool = BoolScalar::try_from(&true_scalar).unwrap(); - let null_bool = BoolScalar::try_from(&null_scalar).unwrap(); + let false_bool = false_scalar.as_bool(); + let true_bool = true_scalar.as_bool(); + let null_bool = null_scalar.as_bool(); // false < true assert!(false_bool < true_bool); @@ -218,9 +150,9 @@ mod test { let false_scalar = Scalar::bool(false, NonNullable); let null_scalar = Scalar::null(DType::Bool(Nullable)); - let true_bool = BoolScalar::try_from(&true_scalar).unwrap(); - let false_bool = BoolScalar::try_from(&false_scalar).unwrap(); - let null_bool = BoolScalar::try_from(&null_scalar).unwrap(); + let true_bool = true_scalar.as_bool(); + let false_bool = false_scalar.as_bool(); + let null_bool = null_scalar.as_bool(); // Invert true -> false let inverted_true = true_bool.invert(); @@ -259,7 +191,7 @@ mod test { #[test] fn test_bool_cast_to_bool() { let bool_scalar = Scalar::bool(true, NonNullable); - let bool = BoolScalar::try_from(&bool_scalar).unwrap(); + let bool = bool_scalar.as_bool(); // Cast to nullable bool let result = bool.cast(&DType::Bool(Nullable)).unwrap(); @@ -277,7 +209,7 @@ mod test { use vortex_dtype::PType; let bool_scalar = Scalar::bool(true, NonNullable); - let bool = BoolScalar::try_from(&bool_scalar).unwrap(); + let bool = bool_scalar.as_bool(); let result = bool.cast(&DType::Primitive(PType::I32, NonNullable)); assert!(result.is_err()); @@ -286,8 +218,7 @@ mod test { #[test] fn test_try_from_non_bool_scalar() { let int_scalar = Scalar::primitive(42i32, NonNullable); - let result = BoolScalar::try_from(&int_scalar); - assert!(result.is_err()); + assert!(int_scalar.as_bool_opt().is_none()); } #[test] @@ -328,11 +259,11 @@ mod test { #[test] fn test_scalar_value_from_bool() { let value: ScalarValue = true.into(); - let scalar = Scalar::new(DType::Bool(NonNullable), value); + let scalar = Scalar::new(DType::Bool(NonNullable), Some(value)); assert!(bool::try_from(&scalar).unwrap()); let value: ScalarValue = false.into(); - let scalar = Scalar::new(DType::Bool(NonNullable), value); + let scalar = Scalar::new(DType::Bool(NonNullable), Some(value)); assert!(!bool::try_from(&scalar).unwrap()); } @@ -341,8 +272,8 @@ mod test { let true_scalar = Scalar::bool(true, NonNullable); let false_scalar = Scalar::bool(false, NonNullable); - let true_bool = BoolScalar::try_from(&true_scalar).unwrap(); - let false_bool = BoolScalar::try_from(&false_scalar).unwrap(); + let true_bool = true_scalar.as_bool(); + let false_bool = false_scalar.as_bool(); assert_ne!(true_bool, false_bool); } @@ -353,9 +284,9 @@ mod test { let null_scalar2 = Scalar::null(DType::Bool(Nullable)); let non_null_scalar = Scalar::bool(true, Nullable); - let null_bool1 = BoolScalar::try_from(&null_scalar1).unwrap(); - let null_bool2 
= BoolScalar::try_from(&null_scalar2).unwrap(); - let non_null_bool = BoolScalar::try_from(&non_null_scalar).unwrap(); + let null_bool1 = null_scalar1.as_bool(); + let null_bool2 = null_scalar2.as_bool(); + let non_null_bool = non_null_scalar.as_bool(); // Two nulls are equal assert_eq!(null_bool1, null_bool2); @@ -370,9 +301,9 @@ mod test { let false_scalar = Scalar::bool(false, NonNullable); let null_scalar = Scalar::null(DType::Bool(Nullable)); - let true_bool = BoolScalar::try_from(&true_scalar).unwrap(); - let false_bool = BoolScalar::try_from(&false_scalar).unwrap(); - let null_bool = BoolScalar::try_from(&null_scalar).unwrap(); + let true_bool = true_scalar.as_bool(); + let false_bool = false_scalar.as_bool(); + let null_bool = null_scalar.as_bool(); assert_eq!(true_bool.value(), Some(true)); assert_eq!(false_bool.value(), Some(false)); @@ -384,8 +315,8 @@ mod test { let nullable_scalar = Scalar::bool(true, Nullable); let non_nullable_scalar = Scalar::bool(false, NonNullable); - let nullable_bool = BoolScalar::try_from(&nullable_scalar).unwrap(); - let non_nullable_bool = BoolScalar::try_from(&non_nullable_scalar).unwrap(); + let nullable_bool = nullable_scalar.as_bool(); + let non_nullable_bool = non_nullable_scalar.as_bool(); assert_eq!(nullable_bool.dtype(), &DType::Bool(Nullable)); assert_eq!(non_nullable_bool.dtype(), &DType::Bool(NonNullable)); @@ -396,8 +327,8 @@ mod test { let false_scalar = Scalar::bool(false, NonNullable); let true_scalar = Scalar::bool(true, NonNullable); - let false_bool = BoolScalar::try_from(&false_scalar).unwrap(); - let true_bool = BoolScalar::try_from(&true_scalar).unwrap(); + let false_bool = false_scalar.as_bool(); + let true_bool = true_scalar.as_bool(); assert_eq!(false_bool.partial_cmp(&false_bool), Some(Ordering::Equal)); assert_eq!(false_bool.partial_cmp(&true_bool), Some(Ordering::Less)); diff --git a/vortex-scalar/src/typed_view/decimal/dvalue.rs b/vortex-scalar/src/typed_view/decimal/dvalue.rs new file mode 100644 index 00000000000..14e60516e5f --- /dev/null +++ b/vortex-scalar/src/typed_view/decimal/dvalue.rs @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! [`DecimalValue`] type representing a typed decimal value. + +use std::cmp::Ordering; +use std::fmt; +use std::hash::Hash; + +use num_traits::CheckedAdd; +use num_traits::CheckedDiv; +use num_traits::CheckedMul; +use num_traits::CheckedSub; +use vortex_dtype::DecimalDType; +use vortex_dtype::DecimalType; +use vortex_dtype::NativeDecimalType; +use vortex_dtype::ToI256; +use vortex_dtype::i256; +use vortex_dtype::match_each_decimal_value; +use vortex_error::VortexExpect; + +/// A decimal value that can be stored in various integer widths. +/// +/// This enum represents decimal values with different storage sizes, +/// from 8-bit to 256-bit integers. +#[derive(Debug, Clone, Copy)] +pub enum DecimalValue { + /// 8-bit signed decimal value. + I8(i8), + /// 16-bit signed decimal value. + I16(i16), + /// 32-bit signed decimal value. + I32(i32), + /// 64-bit signed decimal value. + I64(i64), + /// 128-bit signed decimal value. + I128(i128), + /// 256-bit signed decimal value. + I256(i256), +} + +impl DecimalValue { + /// Cast `self` to T using the respective `ToPrimitive` method. + /// If the value cannot be represented by `T`, `None` is returned. 
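// [Editorial aside, not part of the patch] The target type parameter of `cast` is elided in this
// rendering of the diff; assuming an ordinary primitive such as `i32` is a valid target type,
// the documented behaviour looks roughly like:
//
//     assert_eq!(DecimalValue::I64(300).cast::<i32>(), Some(300i32));
//     // A value that cannot be represented by the target type comes back as `None`.
//     assert_eq!(DecimalValue::I64(i64::MAX).cast::<i8>(), None);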
+    pub fn cast(&self) -> Option {
+        match_each_decimal_value!(self, |value| { T::from(*value) })
+    }
+
+    /// Returns a reasonable precision and scale as a [`DecimalDType`] for the given
+    /// [`DecimalValue`].
+    ///
+    /// Note that this is **not** the same as [`DecimalValue::decimal_type`].
+    pub fn decimal_dtype(&self) -> DecimalDType {
+        // Default to a reasonable precision and scale.
+        match self {
+            DecimalValue::I8(_) => DecimalDType::new(3, 0),
+            DecimalValue::I16(_) => DecimalDType::new(5, 0),
+            DecimalValue::I32(_) => DecimalDType::new(10, 0),
+            DecimalValue::I64(_) => DecimalDType::new(19, 0),
+            DecimalValue::I128(_) => DecimalDType::new(38, 0),
+            DecimalValue::I256(_) => DecimalDType::new(76, 0),
+        }
+    }
+
+    /// Returns the [`DecimalType`] for the given [`DecimalValue`].
+    ///
+    /// Note that this is **not** the same as [`DecimalValue::decimal_dtype`].
+    pub fn decimal_type(&self) -> DecimalType {
+        match self {
+            DecimalValue::I8(_) => DecimalType::I8,
+            DecimalValue::I16(_) => DecimalType::I16,
+            DecimalValue::I32(_) => DecimalType::I32,
+            DecimalValue::I64(_) => DecimalType::I64,
+            DecimalValue::I128(_) => DecimalType::I128,
+            DecimalValue::I256(_) => DecimalType::I256,
+        }
+    }
+
+    /// Returns true if this decimal value is zero.
+    pub fn is_zero(&self) -> bool {
+        match self {
+            DecimalValue::I8(v) => *v == 0,
+            DecimalValue::I16(v) => *v == 0,
+            DecimalValue::I32(v) => *v == 0,
+            DecimalValue::I64(v) => *v == 0,
+            DecimalValue::I128(v) => *v == 0,
+            DecimalValue::I256(v) => *v == i256::ZERO,
+        }
+    }
+
+    /// Returns the zero value for the given [`DecimalDType`].
+    pub fn zero(decimal_type: &DecimalDType) -> Self {
+        let smallest_type = DecimalType::smallest_decimal_value_type(decimal_type);
+
+        match smallest_type {
+            DecimalType::I8 => DecimalValue::I8(0),
+            DecimalType::I16 => DecimalValue::I16(0),
+            DecimalType::I32 => DecimalValue::I32(0),
+            DecimalType::I64 => DecimalValue::I64(0),
+            DecimalType::I128 => DecimalValue::I128(0),
+            DecimalType::I256 => DecimalValue::I256(i256::ZERO),
+        }
+    }
+
+    /// Check if this decimal value fits within the precision constraints of the given decimal type.
+    ///
+    /// The precision defines the total number of significant digits that can be represented.
+    /// The stored value (regardless of scale) must fit within the range defined by precision.
+    /// For precision P, the maximum absolute stored value is 10^P - 1.
+    pub fn fits_in_precision(&self, decimal_type: DecimalDType) -> bool {
+        // Convert to i256 for comparison
+        let value_i256 = match_each_decimal_value!(self, |v| {
+            v.to_i256()
+                .vortex_expect("upcast to i256 must always succeed")
+        });
+
+        // Calculate the maximum stored value that can be represented with this precision
+        // For precision P, the max stored value is 10^P - 1
+        // This is independent of scale - scale only affects how we interpret the value
+        let ten = i256::from_i128(10);
+        let max_value = ten
+            .checked_pow(decimal_type.precision() as _)
+            .vortex_expect("precision must exist in i256");
+        let min_value = -max_value;
+
+        value_i256 > min_value && value_i256 < max_value
+    }
+
+    /// Helper function to perform a checked binary operation on two decimal values.
+    ///
+    /// Both values are upcast to i256 before the operation, and the result is returned as I256.
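// [Editorial aside, not part of the patch] Because both operands are widened to i256 before the
// operation, every checked result comes back as the `I256` variant regardless of the operand
// widths (see `test_decimal_value_mixed_types` below), e.g.:
//
//     let sum = DecimalValue::I8(10).checked_add(&DecimalValue::I128(20));
//     assert_eq!(sum, Some(DecimalValue::I256(i256::from_i128(30))));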
+ fn checked_binary_op(&self, other: &Self, op: F) -> Option + where + F: FnOnce(i256, i256) -> Option, + { + let self_upcast = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + let other_upcast = match_each_decimal_value!(other, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + + op(self_upcast, other_upcast).map(DecimalValue::I256) + } + + /// Checked addition. Returns `None` on overflow. + pub fn checked_add(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_add(&b)) + } + + /// Checked subtraction. Returns `None` on overflow. + pub fn checked_sub(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_sub(&b)) + } + + /// Checked multiplication. Returns `None` on overflow. + pub fn checked_mul(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_mul(&b)) + } + + /// Checked division. Returns `None` on overflow or division by zero. + pub fn checked_div(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_div(&b)) + } +} + +// Additional trait implementations for decimal types to ensure consistency. + +// Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space. +// Decimal values can take on any signed scalar type, but so long as their values are the same +// they are considered the same. +// DecimalScalar handles ensuring that both values being compared have the same precision/scale. +impl PartialEq for DecimalValue { + fn eq(&self, other: &Self) -> bool { + let self_upcast = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + let other_upcast = match_each_decimal_value!(other, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + + self_upcast == other_upcast + } +} + +impl Eq for DecimalValue {} + +impl PartialOrd for DecimalValue { + fn partial_cmp(&self, other: &Self) -> Option { + let self_upcast = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + let other_upcast = match_each_decimal_value!(other, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + + self_upcast.partial_cmp(&other_upcast) + } +} + +// Hashing works in the upcast space similar to the other comparison and equality operators. +impl Hash for DecimalValue { + fn hash(&self, state: &mut H) { + let self_upcast = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + self_upcast.hash(state); + } +} + +impl fmt::Display for DecimalValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + DecimalValue::I8(v8) => write!(f, "decimal8({v8})"), + DecimalValue::I16(v16) => write!(f, "decimal16({v16})"), + DecimalValue::I32(v32) => write!(f, "decimal32({v32})"), + DecimalValue::I64(v64) => write!(f, "decimal64({v64})"), + DecimalValue::I128(v128) => write!(f, "decimal128({v128})"), + DecimalValue::I256(v256) => write!(f, "decimal256({v256})"), + } + } +} diff --git a/vortex-scalar/src/typed_view/decimal/mod.rs b/vortex-scalar/src/typed_view/decimal/mod.rs new file mode 100644 index 00000000000..b4e7005ba3e --- /dev/null +++ b/vortex-scalar/src/typed_view/decimal/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! 
Definition and implementation of [`DecimalScalar`] and [`DecimalValue`]. + +mod dvalue; +mod scalar; + +pub use dvalue::DecimalValue; +pub use scalar::DecimalScalar; + +#[cfg(test)] +mod tests; diff --git a/vortex-scalar/src/decimal/scalar.rs b/vortex-scalar/src/typed_view/decimal/scalar.rs similarity index 88% rename from vortex-scalar/src/decimal/scalar.rs rename to vortex-scalar/src/typed_view/decimal/scalar.rs index 07a825ed263..5c96510586d 100644 --- a/vortex-scalar/src/decimal/scalar.rs +++ b/vortex-scalar/src/typed_view/decimal/scalar.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`DecimalScalar`] typed view implementation. + use std::cmp::Ordering; use std::fmt; @@ -9,14 +11,12 @@ use vortex_dtype::DType; use vortex_dtype::DecimalDType; use vortex_dtype::PType; use vortex_dtype::match_each_decimal_value; -use vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; use crate::DecimalValue; -use crate::InnerScalarValue; use crate::NumericOperator; use crate::Scalar; use crate::ScalarValue; @@ -24,9 +24,12 @@ use crate::ScalarValue; /// A scalar value representing a decimal number with fixed precision and scale. #[derive(Debug, Clone, Copy, Hash)] pub struct DecimalScalar<'a> { - pub(super) dtype: &'a DType, - pub(super) decimal_type: DecimalDType, - pub(super) value: Option, + /// The data type of this scalar. + dtype: &'a DType, + /// The decimal type (precision and scale). + decimal_type: DecimalDType, + /// The decimal value, or [`None`] if null. + decimal_value: Option, } impl<'a> DecimalScalar<'a> { @@ -35,14 +38,14 @@ impl<'a> DecimalScalar<'a> { /// # Errors /// /// Returns an error if the data type is not a decimal type. - pub fn try_new(dtype: &'a DType, value: &ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&ScalarValue>) -> VortexResult { let decimal_type = DecimalDType::try_from(dtype)?; - let value = value.as_decimal()?; + let value = value.map(|v| *v.as_decimal()); Ok(Self { dtype, decimal_type, - value, + decimal_value: value, }) } @@ -54,28 +57,31 @@ impl<'a> DecimalScalar<'a> { /// Returns the decimal value, or None if null. pub fn decimal_value(&self) -> Option { - self.value + self.decimal_value + } + + /// Returns whether this decimal value is zero, or `None` if null. + pub fn is_zero(&self) -> Option { + self.decimal_value.map(|v| v.is_zero()) } - /// Cast decimal scalar to another data type. + /// Casts this scalar to the given `dtype`. 
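// [Editorial aside, not part of the patch] When casting to a primitive, the stored integer is
// interpreted through the scale, i.e. the result is stored_value / 10^scale. Mirroring
// `test_decimal_cast_to_primitive` in this module (that test's exact construction is not shown
// here, so the value width below is illustrative):
//
//     // 12345 stored with DecimalDType::new(10, 2) represents 123.45.
//     let decimal = Scalar::decimal(
//         DecimalValue::I32(12345),
//         DecimalDType::new(10, 2),
//         Nullability::NonNullable,
//     );
//     let as_f64: f64 = (&decimal
//         .cast(&DType::Primitive(PType::F64, Nullability::NonNullable))
//         .unwrap())
//         .try_into()
//         .unwrap();
//     assert!((as_f64 - 123.45).abs() < 0.001);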
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { match dtype { DType::Decimal(target_dtype, target_nullability) => { // Cast between decimal types if self.decimal_type == *target_dtype { // Same decimal type, just change nullability if needed - return Ok(Scalar::new( + return Scalar::try_new( dtype.clone(), - ScalarValue(InnerScalarValue::Decimal( - self.value.unwrap_or(DecimalValue::I128(0)), - )), - )); + self.decimal_value.map(ScalarValue::Decimal), + ); } // Different precision/scale - need to implement scaling logic // For now, we'll do a simple value preservation without scaling // TODO: Implement proper decimal scaling logic - if let Some(value) = &self.value { + if let Some(value) = &self.decimal_value { Ok(Scalar::decimal(*value, *target_dtype, *target_nullability)) } else { Ok(Scalar::null(dtype.clone())) @@ -83,7 +89,7 @@ impl<'a> DecimalScalar<'a> { } DType::Primitive(ptype, nullability) => { // Cast decimal to primitive type - if let Some(decimal_value) = &self.value { + if let Some(decimal_value) = &self.decimal_value { // Convert decimal value to primitive, accounting for scale let scale_factor = 10_i128.pow(self.decimal_type.scale() as u32); @@ -94,6 +100,9 @@ impl<'a> DecimalScalar<'a> { }) })?; + // TODO(connor): A lot of questionable stuff happening here, it would be good to + // either formally prove this is all correct or use more checked methods. + // Apply scale to get the actual value. let actual_value = scaled_value as f64 / scale_factor as f64; @@ -211,7 +220,7 @@ impl<'a> DecimalScalar<'a> { }; // Handle null cases using SQL semantics - let result_value = match (self.value, other.value) { + let result_value = match (self.decimal_value, other.decimal_value) { (None, _) | (_, None) => None, (Some(lhs), Some(rhs)) => { // Perform the operation @@ -225,7 +234,7 @@ impl<'a> DecimalScalar<'a> { }?; // Check if the result fits within the precision constraints - if operation_result.fits_in_precision(self.decimal_type)? 
{ + if operation_result.fits_in_precision(self.decimal_type) { Some(operation_result) } else { // Result exceeds precision, return None (overflow) @@ -237,22 +246,14 @@ impl<'a> DecimalScalar<'a> { Some(DecimalScalar { dtype: result_dtype, decimal_type: self.decimal_type, - value: result_value, + decimal_value: result_value, }) } } -impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> { - type Error = VortexError; - - fn try_from(scalar: &'a Scalar) -> Result { - DecimalScalar::try_new(scalar.dtype(), scalar.value()) - } -} - impl PartialEq for DecimalScalar<'_> { fn eq(&self, other: &Self) -> bool { - self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value + self.dtype.eq_ignore_nullability(other.dtype) && self.decimal_value == other.decimal_value } } @@ -264,13 +265,13 @@ impl PartialOrd for DecimalScalar<'_> { if !self.dtype.eq_ignore_nullability(other.dtype) { return None; } - self.value.partial_cmp(&other.value) + self.decimal_value.partial_cmp(&other.decimal_value) } } impl fmt::Display for DecimalScalar<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Some(&decimal_value) = self.value.as_ref() else { + let Some(&decimal_value) = self.decimal_value.as_ref() else { return write!(f, "null"); }; diff --git a/vortex-scalar/src/decimal/tests.rs b/vortex-scalar/src/typed_view/decimal/tests.rs similarity index 75% rename from vortex-scalar/src/decimal/tests.rs rename to vortex-scalar/src/typed_view/decimal/tests.rs index 5388b236fa6..d37d341e50f 100644 --- a/vortex-scalar/src/decimal/tests.rs +++ b/vortex-scalar/src/typed_view/decimal/tests.rs @@ -17,7 +17,6 @@ use vortex_utils::aliases::hash_set::HashSet; use crate::DecimalValue; use crate::Scalar; -use crate::decimal::DecimalScalar; #[rstest] #[case(DecimalValue::I8(100), DecimalValue::I8(100))] @@ -57,14 +56,14 @@ fn test_decimal_cast_to_primitive() { ); // Cast to f64 should give us 123.45 - let float_result = decimal_scalar + let float_result = &decimal_scalar .cast(&DType::Primitive(PType::F64, Nullability::NonNullable)) .unwrap(); let float_value: f64 = float_result.try_into().unwrap(); assert!((float_value - 123.45).abs() < 0.001); // Cast to i32 should give us 123 (truncated) - let int_result = decimal_scalar + let int_result = &decimal_scalar .cast(&DType::Primitive(PType::I32, Nullability::NonNullable)) .unwrap(); let int_value: i32 = int_result.try_into().unwrap(); @@ -141,7 +140,7 @@ fn test_decimal_cast_negative_values() { ); // Cast to f64 should give us -56.78 - let float_result = decimal_scalar + let float_result = &decimal_scalar .cast(&DType::Primitive(PType::F64, Nullability::NonNullable)) .unwrap(); let float_value: f64 = float_result.try_into().unwrap(); @@ -180,7 +179,7 @@ fn test_decimal_cast_with_scale(#[case] value: i32, #[case] scale: i8, #[case] e Nullability::NonNullable, ); - let float_result = decimal_scalar + let float_result = &decimal_scalar .cast(&DType::Primitive(PType::F64, Nullability::NonNullable)) .unwrap(); let float_value: f64 = float_result.try_into().unwrap(); @@ -571,14 +570,14 @@ fn test_decimal_partial_ord() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I32(200), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); // Same type comparison should work assert!(scalar1 < scalar2); @@ -594,7 +593,7 @@ fn 
test_decimal_partial_ord() { DecimalDType::new(20, 4), // Different precision/scale Nullability::NonNullable, ); - let scalar3 = DecimalScalar::try_from(&decimal3).unwrap(); + let scalar3 = decimal3.as_decimal(); assert_eq!(scalar1.partial_cmp(&scalar3), None); } @@ -605,14 +604,14 @@ fn test_decimal_eq() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I32(100), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); assert_eq!(scalar1, scalar2); @@ -622,7 +621,7 @@ fn test_decimal_eq() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar3 = DecimalScalar::try_from(&decimal3).unwrap(); + let scalar3 = decimal3.as_decimal(); assert_ne!(scalar1, scalar3); } @@ -650,7 +649,7 @@ fn test_decimal_scalar_try_from_errors() { DecimalDType::new(5, 2), Nullability::NonNullable, ); - let scalar = DecimalScalar::try_from(&decimal).unwrap(); + let scalar = decimal.as_decimal(); // Try to extract as wrong type let result: Result = scalar.try_into(); @@ -661,7 +660,7 @@ fn test_decimal_scalar_try_from_errors() { DecimalDType::new(10, 2), Nullability::Nullable, )); - let null_scalar = DecimalScalar::try_from(&null_decimal).unwrap(); + let null_scalar = null_decimal.as_decimal(); let result: Result = null_scalar.try_into(); assert!(result.is_err()); @@ -683,7 +682,7 @@ fn test_decimal_cast_large_scale() { // Cast to f64 let result = decimal.cast(&DType::Primitive(PType::F64, Nullability::NonNullable)); assert!(result.is_ok()); - let f64_value: f64 = result.unwrap().try_into().unwrap(); + let f64_value: f64 = (&result.unwrap()).try_into().unwrap(); assert!((f64_value - 1234.56789012345).abs() < 0.0000000001); } @@ -699,7 +698,7 @@ fn test_decimal_cast_zero_scale() { // Cast to i32 should give exact value let result = decimal.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); assert!(result.is_ok()); - let i32_value: i32 = result.unwrap().try_into().unwrap(); + let i32_value: i32 = (&result.unwrap()).try_into().unwrap(); assert_eq!(i32_value, 123456); } @@ -719,7 +718,7 @@ fn test_decimal_cast_u64_boundary() { // Test U64 boundary case let decimal = Scalar::decimal( DecimalValue::I128(18446744073709551615_i128), // U64::MAX - DecimalDType::new(20, 0), + DecimalDType::new(21, 0), Nullability::NonNullable, ); @@ -734,7 +733,7 @@ fn test_decimal_cast_u64_boundary() { // Note: The cast logic checks the float value against U64::MAX let decimal = Scalar::decimal( DecimalValue::I128(i128::MAX), // Much larger than U64::MAX - DecimalDType::new(38, 0), + DecimalDType::new(39, 0), Nullability::NonNullable, ); @@ -767,14 +766,14 @@ fn test_decimal_scalar_checked_add() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I64(200), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Add) @@ -794,14 +793,14 @@ fn test_decimal_scalar_checked_sub() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = 
Scalar::decimal( DecimalValue::I64(200), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Sub) @@ -821,14 +820,14 @@ fn test_decimal_scalar_checked_mul() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I32(10), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Mul) @@ -848,14 +847,14 @@ fn test_decimal_scalar_checked_div() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I64(10), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Div) @@ -875,14 +874,14 @@ fn test_decimal_scalar_checked_div_by_zero() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I64(0), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div); assert_eq!(result, None); @@ -896,14 +895,14 @@ fn test_decimal_scalar_null_handling() { DecimalDType::new(10, 2), Nullability::Nullable, )); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I64(200), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); let result = scalar1 .checked_binary_numeric(&scalar2, NumericOperator::Add) @@ -921,14 +920,14 @@ fn test_decimal_scalar_precision_overflow() { DecimalDType::new(3, 0), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I16(2), DecimalDType::new(3, 0), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); // 999 + 2 = 1001 which exceeds precision 3 let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add); @@ -944,14 +943,14 @@ fn test_decimal_scalar_rsub_and_rdiv() { DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + let scalar1 = decimal1.as_decimal(); let decimal2 = Scalar::decimal( DecimalValue::I64(300), DecimalDType::new(10, 2), Nullability::NonNullable, ); - let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + let scalar2 = decimal2.as_decimal(); // RSub: 300 - 100 = 200 let result = scalar1 @@ -971,3 +970,259 @@ fn test_decimal_scalar_rsub_and_rdiv() { Some(DecimalValue::I256(i256::from_i128(3))) ); } + +#[test] +fn test_decimal_value_from_scalar() { + let value = DecimalValue::I32(12345); + let 
scalar = Scalar::from(value); + + // Test extraction + let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap(); + assert_eq!(extracted, value); + + // Test owned extraction + let extracted_owned: DecimalValue = DecimalValue::try_from(scalar).unwrap(); + assert_eq!(extracted_owned, value); +} + +#[test] +fn test_decimal_value_option_from_scalar() { + // Non-null case + let value = DecimalValue::I64(999999); + let scalar = Scalar::from(value); + + let extracted: Option = Option::try_from(&scalar).unwrap(); + assert_eq!(extracted, Some(value)); + + // Null case + let null_scalar = Scalar::null(DType::Decimal( + DecimalDType::new(10, 2), + Nullability::Nullable, + )); + + let extracted_null: Option = Option::try_from(&null_scalar).unwrap(); + assert_eq!(extracted_null, None); +} + +#[test] +fn test_decimal_value_from_conversion() { + // Test that From creates reasonable defaults + let values = vec![ + DecimalValue::I8(127), + DecimalValue::I16(32767), + DecimalValue::I32(1000000), + DecimalValue::I64(1000000000000), + DecimalValue::I128(123456789012345678901234567890), + DecimalValue::I256(i256::from_i128(987654321)), + ]; + + for value in values { + let scalar = Scalar::from(value); + assert!(!scalar.is_null()); + + // Verify we can extract it back + let extracted: DecimalValue = DecimalValue::try_from(&scalar).unwrap(); + assert_eq!(extracted, value); + } +} + +#[test] +fn test_decimal_value_checked_add() { + let a = DecimalValue::I64(100); + let b = DecimalValue::I64(200); + let result = a.checked_add(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); +} + +#[test] +fn test_decimal_value_checked_sub() { + let a = DecimalValue::I64(500); + let b = DecimalValue::I64(200); + let result = a.checked_sub(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); +} + +#[test] +fn test_decimal_value_checked_mul() { + let a = DecimalValue::I32(50); + let b = DecimalValue::I32(10); + let result = a.checked_mul(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(500))); +} + +#[test] +fn test_decimal_value_checked_div() { + let a = DecimalValue::I64(1000); + let b = DecimalValue::I64(10); + let result = a.checked_div(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(100))); +} + +#[test] +fn test_decimal_value_checked_div_by_zero() { + let a = DecimalValue::I64(1000); + let b = DecimalValue::I64(0); + let result = a.checked_div(&b); + assert_eq!(result, None); +} + +#[test] +fn test_decimal_value_mixed_types() { + // Test operations with different underlying types + let a = DecimalValue::I8(10); + let b = DecimalValue::I128(20); + let result = a.checked_add(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(30))); +} + +#[test] +fn test_fits_in_precision_exact_boundary() { + use vortex_dtype::DecimalDType; + + // Precision 3 means max value is 10^3 - 1 = 999 + let dtype = DecimalDType::new(3, 0); + + // Test exact upper boundary: 999 should fit + let value = DecimalValue::I16(999); + assert!(value.fits_in_precision(dtype)); + + // Test just beyond upper boundary: 1000 should NOT fit + let value = DecimalValue::I16(1000); + assert!(!value.fits_in_precision(dtype)); + + // Test exact lower boundary: -999 should fit + let value = DecimalValue::I16(-999); + assert!(value.fits_in_precision(dtype)); + + // Test just beyond lower boundary: -1000 should NOT fit + let value = DecimalValue::I16(-1000); + assert!(!value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_zero() { 
+ use vortex_dtype::DecimalDType; + + let dtype = DecimalDType::new(5, 2); + + // Zero should always fit + let value = DecimalValue::I8(0); + assert!(value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_small_precision() { + use vortex_dtype::DecimalDType; + + // Precision 1 means max value is 10^1 - 1 = 9 + let dtype = DecimalDType::new(1, 0); + + // Test values within range + for i in -9..=9 { + let value = DecimalValue::I8(i); + assert!( + value.fits_in_precision(dtype), + "value {} should fit in precision 1", + i + ); + } + + // Test values outside range + let value = DecimalValue::I8(10); + assert!(!value.fits_in_precision(dtype)); + let value = DecimalValue::I8(-10); + assert!(!value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_large_precision() { + use vortex_dtype::DecimalDType; + + // Precision 38 means max value is 10^38 - 1 + let dtype = DecimalDType::new(38, 0); + + // Test i128::MAX which is approximately 1.7e38 + // This should NOT fit because 10^38 - 1 < i128::MAX + let value = DecimalValue::I128(i128::MAX); + assert!(!value.fits_in_precision(dtype)); + + // Test a large value that should fit: 10^37 + let value = DecimalValue::I128(10_i128.pow(37)); + assert!(value.fits_in_precision(dtype)); + + // Test 10^38 - 1 (the exact maximum) + let max_val = i256::from_i128(10).wrapping_pow(38) - i256::from_i128(1); + let value = DecimalValue::I256(max_val); + assert!(value.fits_in_precision(dtype)); + + // Test 10^38 (just over the maximum) + let over_max = i256::from_i128(10).wrapping_pow(38); + let value = DecimalValue::I256(over_max); + assert!(!value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_max_precision() { + use vortex_dtype::DecimalDType; + + // Maximum precision is 76 + let dtype = DecimalDType::new(76, 0); + + // Test that reasonable i256 values fit + let value = DecimalValue::I256(i256::from_i128(i128::MAX)); + assert!(value.fits_in_precision(dtype)); + + // Test negative + let value = DecimalValue::I256(i256::from_i128(i128::MIN)); + assert!(value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_different_scales() { + use vortex_dtype::DecimalDType; + + // Scale doesn't affect the precision check - it's only about the stored value + let value = DecimalValue::I32(12345); + + // Precision 5 with different scales + assert!(value.fits_in_precision(DecimalDType::new(5, 0))); + assert!(value.fits_in_precision(DecimalDType::new(5, 2))); + assert!(value.fits_in_precision(DecimalDType::new(5, -2))); + + // Precision 4 should fail (max value 9999, we have 12345) + assert!(!value.fits_in_precision(DecimalDType::new(4, 0))); + assert!(!value.fits_in_precision(DecimalDType::new(4, 2))); +} + +#[test] +fn test_fits_in_precision_negative_values() { + use vortex_dtype::DecimalDType; + + let dtype = DecimalDType::new(4, 2); + + // Test negative values at boundaries + // Precision 4 means max magnitude is 9999 + let value = DecimalValue::I16(-9999); + assert!(value.fits_in_precision(dtype)); + + let value = DecimalValue::I16(-10000); + assert!(!value.fits_in_precision(dtype)); + + let value = DecimalValue::I16(-1); + assert!(value.fits_in_precision(dtype)); +} + +#[test] +fn test_fits_in_precision_mixed_decimal_value_types() { + use vortex_dtype::DecimalDType; + + let dtype = DecimalDType::new(5, 0); + + // Test that different DecimalValue types work correctly + assert!(DecimalValue::I8(99).fits_in_precision(dtype)); + assert!(DecimalValue::I16(9999).fits_in_precision(dtype)); + 
assert!(DecimalValue::I32(99999).fits_in_precision(dtype)); + assert!(!DecimalValue::I64(100000).fits_in_precision(dtype)); + assert!(DecimalValue::I128(99999).fits_in_precision(dtype)); + assert!(!DecimalValue::I256(i256::from_i128(100000)).fits_in_precision(dtype)); +} diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/typed_view/extension.rs similarity index 79% rename from vortex-scalar/src/extension.rs rename to vortex-scalar/src/typed_view/extension.rs index 5983b9a13ab..72b5cca018c 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/typed_view/extension.rs @@ -1,16 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`ExtScalar`] typed view implementation. + use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use vortex_dtype::DType; -use vortex_dtype::ExtDType; use vortex_dtype::datetime::AnyTemporal; use vortex_dtype::extension::ExtDTypeRef; -use vortex_dtype::extension::ExtDTypeVTable; -use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -23,8 +22,10 @@ use crate::ScalarValue; /// Extension types allow wrapping a storage type with custom semantics. #[derive(Debug, Clone)] pub struct ExtScalar<'a> { + /// The extension data type reference. ext_dtype: &'a ExtDTypeRef, - value: &'a ScalarValue, + /// The underlying scalar value, or [`None`] if null. + value: Option<&'a ScalarValue>, } impl Display for ExtScalar<'_> { @@ -75,12 +76,13 @@ impl Hash for ExtScalar<'_> { } impl<'a> ExtScalar<'a> { + // TODO(connor): This should really be validating the data on construction!!! /// Creates a new extension scalar from a data type and scalar value. /// /// # Errors /// /// Returns an error if the data type is not an extension type. - pub fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { let DType::Extension(ext_dtype) = dtype else { vortex_bail!("Expected extension scalar, found {}", dtype) }; @@ -90,7 +92,8 @@ impl<'a> ExtScalar<'a> { /// Returns the storage scalar of the extension scalar. pub fn storage(&self) -> Scalar { - Scalar::new(self.ext_dtype.storage_dtype().clone(), self.value.clone()) + Scalar::try_new(self.ext_dtype.storage_dtype().clone(), self.value.cloned()) + .vortex_expect("ExtScalar is invalid") } /// Returns the extension data type. @@ -98,8 +101,9 @@ impl<'a> ExtScalar<'a> { self.ext_dtype } + /// Casts this scalar to the given `dtype`. pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { - if self.value.is_null() && !dtype.is_nullable() { + if self.value.is_none() && !dtype.is_nullable() { vortex_bail!( "cannot cast extension dtype with id {} and storage type {} to {}", self.ext_dtype.id(), @@ -110,13 +114,13 @@ impl<'a> ExtScalar<'a> { if self.ext_dtype.storage_dtype().eq_ignore_nullability(dtype) { // Casting from an extension type to the underlying storage type is OK. 
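// [Editorial aside, not part of the patch] Seen from the caller's side (compare the cast tests
// later in this file), this branch lets an extension scalar be cast straight to its storage
// dtype while keeping the underlying value; the extension dtype/metadata setup is assumed:
//
//     let storage = Scalar::primitive(42i32, Nullability::NonNullable);
//     let ext_scalar = /* Scalar::extension wrapping `storage`, as in the tests below */;
//     let casted = ext_scalar
//         .as_extension()
//         .cast(&DType::Primitive(PType::I32, Nullability::NonNullable))
//         .unwrap();
//     assert_eq!(casted.dtype(), storage.dtype());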
- return Ok(Scalar::new(dtype.clone(), self.value.clone())); + return Scalar::try_new(dtype.clone(), self.value.cloned()); } if let DType::Extension(ext_dtype) = dtype && self.ext_dtype.eq_ignore_nullability(ext_dtype) { - return Ok(Scalar::new(dtype.clone(), self.value.clone())); + return Scalar::try_new(dtype.clone(), self.value.cloned()); } vortex_bail!( @@ -128,29 +132,6 @@ impl<'a> ExtScalar<'a> { } } -impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> { - type Error = VortexError; - - fn try_from(scalar: &'a Scalar) -> Result { - ExtScalar::try_new(scalar.dtype(), scalar.value()) - } -} - -impl Scalar { - /// Creates a new extension scalar wrapping the given storage value. - pub fn extension(options: V::Metadata, value: Scalar) -> Self { - let ext_dtype = ExtDType::::try_new(options, value.dtype().clone()) - .vortex_expect("Failed to create extension dtype"); - Self::new(DType::Extension(ext_dtype.erased()), value.value().clone()) - } - - /// Creates a new extension scalar wrapping the given storage value. - pub fn extension_ref(ext_dtype: ExtDTypeRef, value: Scalar) -> Self { - assert_eq!(ext_dtype.storage_dtype(), value.dtype()); - Self::new(DType::Extension(ext_dtype), value.value().clone()) - } -} - #[cfg(test)] mod tests { use vortex_dtype::DType; @@ -163,7 +144,7 @@ mod tests { use vortex_error::VortexResult; use crate::ExtScalar; - use crate::InnerScalarValue; + use crate::PValue; use crate::Scalar; use crate::ScalarValue; @@ -210,9 +191,9 @@ mod tests { Scalar::primitive(43i32, Nullability::NonNullable), ); - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - let ext3 = ExtScalar::try_from(&scalar3).unwrap(); + let ext1 = scalar1.as_extension(); + let ext2 = scalar2.as_extension(); + let ext3 = scalar3.as_extension(); assert_eq!(ext1, ext2); assert_ne!(ext1, ext3); @@ -229,8 +210,8 @@ mod tests { Scalar::primitive(20i32, Nullability::NonNullable), ); - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); + let ext1 = scalar1.as_extension(); + let ext2 = scalar2.as_extension(); assert!(ext1 < ext2); assert!(ext2 > ext1); @@ -265,8 +246,8 @@ mod tests { Scalar::primitive(20i32, Nullability::NonNullable), ); - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); + let ext1 = scalar1.as_extension(); + let ext2 = scalar2.as_extension(); // Different extension types should not be comparable assert_eq!(ext1.partial_cmp(&ext2), None); @@ -306,7 +287,7 @@ mod tests { let storage_scalar = Scalar::primitive(42i32, Nullability::NonNullable); let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar.clone()); - let ext = ExtScalar::try_from(&ext_scalar).unwrap(); + let ext = ext_scalar.as_extension(); assert_eq!(ext.storage(), storage_scalar); } @@ -318,7 +299,7 @@ mod tests { Scalar::primitive(42i32, Nullability::NonNullable), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); assert_eq!(ext.ext_dtype().id(), ext_dtype.id()); assert_eq!(ext.ext_dtype(), &ext_dtype.erased()); } @@ -330,7 +311,7 @@ mod tests { Scalar::primitive(42i32, Nullability::NonNullable), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); // Cast to storage type let casted = ext @@ -365,7 +346,7 @@ mod tests { Scalar::primitive(42i32, Nullability::NonNullable), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); let ext_dtype = ext_dtype.erased(); 
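(Aside on the constructor used above: `Scalar::try_new` takes the value as an `Option<ScalarValue>`, with `None` standing for null, and validates it against the `DType`. The following is a minimal usage sketch; the exact validation behaviour is assumed from how this diff uses the API rather than spelled out here.)

use vortex_dtype::DType;
use vortex_dtype::Nullability;
use vortex_dtype::PType;
use vortex_scalar::PValue;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

fn main() {
    let dtype = DType::Primitive(PType::I32, Nullability::Nullable);

    // A present value is wrapped in Some(...).
    let seven = Scalar::try_new(dtype.clone(), Some(ScalarValue::Primitive(PValue::I32(7)))).unwrap();
    assert_eq!(seven.as_primitive().typed_value::<i32>(), Some(7));

    // A null is simply None; no sentinel ScalarValue is needed.
    let null = Scalar::try_new(dtype, None).unwrap();
    assert_eq!(null.as_primitive().pvalue(), None);
}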
// Cast to same extension type @@ -385,7 +366,7 @@ mod tests { Scalar::primitive(42i32, Nullability::NonNullable), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); // Cast to incompatible type should fail let result = ext.cast(&DType::Utf8(Nullability::NonNullable)); @@ -399,7 +380,7 @@ mod tests { Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); // Cast null to non-nullable should fail let result = ext.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); @@ -409,9 +390,9 @@ mod tests { #[test] fn test_ext_scalar_try_new_non_extension() { let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = ExtScalar::try_new(&dtype, &value); + let result = ExtScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } @@ -440,25 +421,7 @@ mod tests { Scalar::primitive(42i32, Nullability::NonNullable), ); - let ext = ExtScalar::try_from(&scalar).unwrap(); + let ext = scalar.as_extension(); assert_eq!(ext.ext_dtype().metadata::(), &1234); } - - #[test] - fn test_ext_scalar_equality_ignores_nullability() { - let scalar1 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::NonNullable), - ); - let scalar2 = Scalar::extension::( - EmptyMetadata, - Scalar::primitive(42i32, Nullability::Nullable), - ); - - let ext1 = ExtScalar::try_from(&scalar1).unwrap(); - let ext2 = ExtScalar::try_from(&scalar2).unwrap(); - - // Equality should ignore nullability differences - assert_eq!(ext1, ext2); - } } diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/typed_view/list.rs similarity index 74% rename from vortex-scalar/src/list.rs rename to vortex-scalar/src/typed_view/list.rs index 34b829c3013..f70a490b15d 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/typed_view/list.rs @@ -1,22 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`ListScalar`] typed view implementation. + use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::sync::Arc; -use itertools::Itertools as _; +use itertools::Itertools; use vortex_dtype::DType; -use vortex_dtype::Nullability; -use vortex_error::VortexError; -use vortex_error::VortexExpect as _; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -32,9 +31,15 @@ use crate::ScalarValue; /// [`FixedSizeList`]: DType::FixedSizeList #[derive(Debug, Clone)] pub struct ListScalar<'a> { + /// The data type of this scalar. dtype: &'a DType, + /// A convenience field so that we do not have to unwrap and check the top-level `dtype` field + /// every time we want to access this. element_dtype: &'a Arc, - elements: Option>, + /// The elements of the list. `None` if the entire list is null. + /// Each element is `Option` where `None` represents a null element within the + /// list. 
+ elements: Option<&'a [Option]>, } impl Display for ListScalar<'_> { @@ -42,19 +47,19 @@ impl Display for ListScalar<'_> { match &self.elements { None => write!(f, "null"), Some(elems) => { - let fixed_size_list_str: &dyn Display = - if let DType::FixedSizeList(_, size, _) = self.dtype { - &format!("fixed_size<{size}>") - } else { - &"" - }; + let type_str: &dyn Display = if let DType::FixedSizeList(_, size, _) = self.dtype { + &format!("fixed_size<{size}>") + } else { + &"" + }; write!( f, - "{fixed_size_list_str}[{}]", + "{type_str}[{}]", elems .iter() - .map(|e| Scalar::new(self.element_dtype().clone(), e.clone())) + .map(|e| Scalar::try_new(self.element_dtype().clone(), e.clone()) + .vortex_expect("`ListScalar` is already a valid `Scalar`")) .format(", ") ) } @@ -91,6 +96,23 @@ impl Hash for ListScalar<'_> { } impl<'a> ListScalar<'a> { + /// Attempts to create a new [`ListScalar`] from a [`DType`] and optional [`ScalarValue`]. + /// + /// # Errors + /// + /// Returns an error if the data type is not a [`DType::List`] or [`DType::FixedSizeList`]. + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { + let element_dtype = dtype + .as_any_size_list_element_opt() + .ok_or_else(|| vortex_err!("Expected list scalar, found {}", dtype))?; + + Ok(Self { + dtype, + element_dtype, + elements: value.map(|v| v.as_list()), + }) + } + /// Returns the data type of this list scalar. #[inline] pub fn dtype(&self) -> &'a DType { @@ -132,20 +154,23 @@ impl<'a> ListScalar<'a> { /// /// Returns None if the list is null or the index is out of bounds. pub fn element(&self, idx: usize) -> Option { - self.elements - .as_ref() - .and_then(|l| l.get(idx)) - .map(|value| Scalar::new(self.element_dtype().clone(), value.clone())) + self.elements.and_then(|l| l.get(idx)).map(|value| { + // SAFETY: `ListScalar` is already a valid `Scalar`. + unsafe { Scalar::new_unchecked(self.element_dtype().clone(), value.clone()) } + }) } /// Returns all elements in the list as a vector of scalars. /// /// Returns None if the list is null. pub fn elements(&self) -> Option> { - self.elements.as_ref().map(|elems| { + self.elements.map(|elems| { elems .iter() - .map(|e| Scalar::new(self.element_dtype().clone(), e.clone())) + .map(|e| { + // SAFETY: `ListScalar` is already a valid `Scalar`. + unsafe { Scalar::new_unchecked(self.element_dtype().clone(), e.clone()) } + }) .collect_vec() }) } @@ -154,7 +179,7 @@ impl<'a> ListScalar<'a> { /// /// # Panics /// - /// Panics if the target [`DType`] is not a [`List`]: or [`FixedSizeList`], or if trying to cast + /// Panics if the target [`DType`] is not a [`List`] or [`FixedSizeList`], or if trying to cast /// to a [`FixedSizeList`] with the incorrect number of elements. /// /// [`List`]: DType::List @@ -180,115 +205,21 @@ impl<'a> ListScalar<'a> { ) } - Ok(Scalar::new( + Scalar::try_new( dtype.clone(), - ScalarValue(InnerScalarValue::List( + Some(ScalarValue::List( self.elements - .as_ref() - .vortex_expect("nullness handled in Scalar::cast") + .ok_or_else(|| vortex_err!("nullness should be handled in Scalar::cast"))? .iter() .map(|element| { // Recursively cast the elements of the list. - Scalar::new(DType::clone(self.element_dtype), element.clone()) + Scalar::try_new(DType::clone(self.element_dtype), element.clone())? .cast(target_element_dtype) .map(|x| x.into_value()) }) - .collect::>>()?, + .collect::>>>()?, )), - )) - } -} - -/// A helper enum for creating a [`ListScalar`]. 
-enum ListKind { - Variable, - FixedSize, -} - -/// Helper functions to create a [`ListScalar`] as a [`Scalar`]. -impl Scalar { - fn create_list( - element_dtype: impl Into>, - children: Vec, - nullability: Nullability, - list_kind: ListKind, - ) -> Self { - let element_dtype = element_dtype.into(); - - let children: Arc<[ScalarValue]> = children - .into_iter() - .map(|child| { - if child.dtype() != &*element_dtype { - vortex_panic!( - "tried to create list of {} with values of type {}", - element_dtype, - child.dtype() - ); - } - child.into_value() - }) - .collect(); - let size: u32 = children - .len() - .try_into() - .vortex_expect("tried to create a list that was too large"); - - let dtype = match list_kind { - ListKind::Variable => DType::List(element_dtype, nullability), - ListKind::FixedSize => DType::FixedSizeList(element_dtype, size, nullability), - }; - - Self::new(dtype, ScalarValue(InnerScalarValue::List(children))) - } - - /// Creates a new list scalar with the given element type and children. - /// - /// # Panics - /// - /// Panics if any child scalar has a different type than the element type, or if there are too - /// many children. - pub fn list( - element_dtype: impl Into>, - children: Vec, - nullability: Nullability, - ) -> Self { - Self::create_list(element_dtype, children, nullability, ListKind::Variable) - } - - /// Creates a new empty list scalar with the given element type. - pub fn list_empty(element_dtype: Arc, nullability: Nullability) -> Self { - Self::create_list(element_dtype, vec![], nullability, ListKind::Variable) - } - - /// Creates a new fixed-size list scalar with the given element type and children. - /// - /// # Panics - /// - /// Panics if any child scalar has a different type than the element type, or if there are too - /// many children. 
- pub fn fixed_size_list( - element_dtype: impl Into>, - children: Vec, - nullability: Nullability, - ) -> Self { - Self::create_list(element_dtype, children, nullability, ListKind::FixedSize) - } -} - -impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - let element_dtype = value - .dtype() - .as_any_size_list_element_opt() - .ok_or_else(|| vortex_err!("Expected list scalar, found {}", value.dtype()))?; - - Ok(Self { - dtype: value.dtype(), - element_dtype, - elements: value.value().as_list()?.cloned(), - }) + ) } } @@ -312,7 +243,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); assert_eq!(list.len(), 3); assert!(!list.is_empty()); assert!(!list.is_null()); @@ -323,7 +254,7 @@ mod tests { let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)); let list_scalar = Scalar::list_empty(element_dtype, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); assert_eq!(list.len(), 0); assert!(list.is_empty()); assert!(!list.is_null()); @@ -334,7 +265,7 @@ mod tests { let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::Nullable)); let null = Scalar::null(DType::List(element_dtype, Nullability::Nullable)); - let list = ListScalar::try_from(&null).unwrap(); + let list = null.as_list(); assert!(list.is_empty()); assert!(list.is_null()); } @@ -349,7 +280,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); // Test element access let elem0 = list.element(0).unwrap(); @@ -374,7 +305,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); let elements = list.elements().unwrap(); assert_eq!(elements.len(), 2); @@ -397,7 +328,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); let display = format!("{list}"); assert!(display.contains("1")); assert!(display.contains("2")); @@ -418,8 +349,8 @@ mod tests { ]; let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable); - let list1 = ListScalar::try_from(&list_scalar1).unwrap(); - let list2 = ListScalar::try_from(&list_scalar2).unwrap(); + let list1 = list_scalar1.as_list(); + let list2 = list_scalar2.as_list(); assert_eq!(list1, list2); } @@ -439,8 +370,8 @@ mod tests { ]; let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable); - let list1 = ListScalar::try_from(&list_scalar1).unwrap(); - let list2 = ListScalar::try_from(&list_scalar2).unwrap(); + let list1 = list_scalar1.as_list(); + let list2 = list_scalar2.as_list(); assert_ne!(list1, list2); } @@ -455,8 +386,8 @@ mod tests { let children2 = vec![Scalar::primitive(2i32, Nullability::NonNullable)]; let list_scalar2 = Scalar::list(element_dtype, children2, Nullability::NonNullable); - let list1 = ListScalar::try_from(&list_scalar1).unwrap(); - let list2 = ListScalar::try_from(&list_scalar2).unwrap(); + let list1 = list_scalar1.as_list(); + let list2 = list_scalar2.as_list(); assert!(list1 < list2); } @@ 
-472,8 +403,8 @@ mod tests { let children2 = vec![Scalar::primitive(1i64, Nullability::NonNullable)]; let list_scalar2 = Scalar::list(element_dtype2, children2, Nullability::NonNullable); - let list1 = ListScalar::try_from(&list_scalar1).unwrap(); - let list2 = ListScalar::try_from(&list_scalar2).unwrap(); + let list1 = list_scalar1.as_list(); + let list2 = list_scalar2.as_list(); assert!(list1.partial_cmp(&list2).is_none()); } @@ -491,7 +422,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); let mut hasher1 = DefaultHasher::new(); list.hash(&mut hasher1); @@ -536,7 +467,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); // Cast to list with i64 elements let target_dtype = DType::List( @@ -545,7 +476,7 @@ mod tests { ); let casted = list.cast(&target_dtype).unwrap(); - let casted_list = ListScalar::try_from(&casted).unwrap(); + let casted_list = casted.as_list(); assert_eq!(casted_list.len(), 2); let elem0 = casted_list.element(0).unwrap(); @@ -565,8 +496,7 @@ mod tests { #[test] fn test_try_from_wrong_dtype() { let scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let result = ListScalar::try_from(&scalar); - assert!(result.is_err()); + assert!(scalar.as_list_opt().is_none()); } #[test] @@ -578,7 +508,7 @@ mod tests { ]; let list_scalar = Scalar::list(element_dtype, children, Nullability::NonNullable); - let list = ListScalar::try_from(&list_scalar).unwrap(); + let list = list_scalar.as_list(); assert_eq!(list.len(), 2); let elem0 = list.element(0).unwrap(); @@ -620,11 +550,11 @@ mod tests { Nullability::NonNullable, ); - let list = ListScalar::try_from(&outer_list).unwrap(); + let list = outer_list.as_list(); assert_eq!(list.len(), 2); let nested_list = list.element(0).unwrap(); - let nested = ListScalar::try_from(&nested_list).unwrap(); + let nested = nested_list.as_list(); assert_eq!(nested.len(), 2); } } diff --git a/vortex-scalar/src/typed_view/mod.rs b/vortex-scalar/src/typed_view/mod.rs new file mode 100644 index 00000000000..067e3a4d4a5 --- /dev/null +++ b/vortex-scalar/src/typed_view/mod.rs @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Definitions of typed views into the [`Scalar`] type. +//! +//! Since the [`Scalar`] type is dynamically typed, it is useful to have a typed version of it when +//! we know we are working with a specific kind of [`Scalar`]. +//! +//! All the types defined in this module are either typed views into [`Scalar`] or +//! easier-to-work-with value types ([`PValue`] and [`DecimalValue`]). +//! +//! Note that we do **not** have a typed scalar for `FixedSizeList`, as a singular list value has no +//! notion of a "fixed size" in isolation. We use the same [`ListScalar`] for both `FixedSizeList` +//! and `List` `DType`s. +//! +//! 
[`Scalar`]: crate::Scalar + +mod binary; +mod bool; +mod decimal; +mod extension; +mod list; +mod primitive; +mod struct_; +mod utf8; + +pub use binary::*; +pub use bool::*; +pub use decimal::*; +pub use extension::*; +pub use list::*; +pub use primitive::*; +pub use struct_::*; +pub use utf8::*; diff --git a/vortex-scalar/src/typed_view/primitive/mod.rs b/vortex-scalar/src/typed_view/primitive/mod.rs new file mode 100644 index 00000000000..2305934b864 --- /dev/null +++ b/vortex-scalar/src/typed_view/primitive/mod.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Definition and implementation of [`PrimitiveScalar`] and [`PValue`]. + +mod numeric_operator; +mod pvalue; +mod scalar; + +pub use numeric_operator::NumericOperator; +pub use pvalue::PValue; +pub use scalar::PrimitiveScalar; + +#[cfg(test)] +mod tests; diff --git a/vortex-scalar/src/typed_view/primitive/numeric_operator.rs b/vortex-scalar/src/typed_view/primitive/numeric_operator.rs new file mode 100644 index 00000000000..03df9da8a59 --- /dev/null +++ b/vortex-scalar/src/typed_view/primitive/numeric_operator.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! [`NumericOperator`] enum for arithmetic operations on primitive scalars. + +use std::fmt; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// Binary element-wise operations on two arrays or two scalars. +pub enum NumericOperator { + /// Binary element-wise addition of two arrays or of two scalars. + /// + /// Errs at runtime if the sum would overflow or underflow. + Add, + /// Binary element-wise subtraction of two arrays or of two scalars. + Sub, + /// Same as [NumericOperator::Sub] but with the parameters flipped: `right - left`. + RSub, + /// Binary element-wise multiplication of two arrays or of two scalars. + Mul, + /// Binary element-wise division of two arrays or of two scalars. + Div, + /// Same as [NumericOperator::Div] but with the parameters flipped: `right / left`. + RDiv, + // Missing from arrow-rs: + // Min, + // Max, + // Pow, +} + +impl fmt::Display for NumericOperator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl NumericOperator { + /// Returns the operator with swapped operands (e.g., Sub becomes RSub). + pub fn swap(self) -> Self { + match self { + NumericOperator::Add => NumericOperator::Add, + NumericOperator::Sub => NumericOperator::RSub, + NumericOperator::RSub => NumericOperator::Sub, + NumericOperator::Mul => NumericOperator::Mul, + NumericOperator::Div => NumericOperator::RDiv, + NumericOperator::RDiv => NumericOperator::Div, + } + } +} diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/typed_view/primitive/pvalue.rs similarity index 57% rename from vortex-scalar/src/pvalue.rs rename to vortex-scalar/src/typed_view/primitive/pvalue.rs index 021979010b5..3f09f9addcd 100644 --- a/vortex-scalar/src/pvalue.rs +++ b/vortex-scalar/src/typed_view/primitive/pvalue.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`PValue`] enum representing a typed primitive value. 
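To make the typed-view module described above concrete, here is a short usage sketch of going from a dynamically typed `Scalar` to its `PrimitiveScalar` view. It only uses calls that appear elsewhere in this diff (`Scalar::primitive`, `as_primitive`, `typed_value`), so treat it as an illustration rather than exhaustive API documentation.

use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

fn main() {
    // A Scalar is dynamically typed; once the DType is known to be primitive,
    // the typed view gives checked access to the underlying value.
    let scalar = Scalar::primitive(42i32, Nullability::NonNullable);
    let primitive = scalar.as_primitive();
    assert_eq!(primitive.typed_value::<i32>(), Some(42));

    // The other views follow the same pattern, e.g. `as_list()` and `as_extension()`,
    // as exercised by the tests in this change.
}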
+ use core::fmt::Display; use std::cmp::Ordering; use std::hash::Hash; @@ -8,6 +10,7 @@ use std::hash::Hasher; use num_traits::NumCast; use num_traits::ToPrimitive; +use num_traits::Zero; use paste::paste; use vortex_dtype::NativePType; use vortex_dtype::PType; @@ -124,6 +127,7 @@ impl ToBytes for PValue { } } +/// Generates an `as_` accessor method on [`PValue`]. macro_rules! as_primitive { ($T:ty, $PT:tt) => { paste! { @@ -136,8 +140,25 @@ macro_rules! as_primitive { } impl PValue { + /// Returns true if this decimal value is zero. + pub fn is_zero(&self) -> bool { + matches!( + self, + PValue::U8(0) + | PValue::U16(0) + | PValue::U32(0) + | PValue::U64(0) + | PValue::I8(0) + | PValue::I16(0) + | PValue::I32(0) + | PValue::I64(0) + ) || matches!(self, PValue::F16(f) if f.to_f32().is_some_and(|f| f.is_zero())) + || matches!(self, PValue::F32(f) if f.is_zero()) + || matches!(self, PValue::F64(f) if f.is_zero()) + } + /// Creates a zero value for the given primitive type. - pub fn zero(ptype: PType) -> PValue { + pub fn zero(ptype: &PType) -> PValue { match ptype { PType::U8 => PValue::U8(0), PType::U16 => PValue::U16(0), @@ -321,6 +342,7 @@ macro_rules! int_pvalue { }; } +/// Implements [`TryFrom`] for a floating-point type. macro_rules! float_pvalue { ($T:ty, $PT:tt) => { impl TryFrom for $T { @@ -359,6 +381,7 @@ float_pvalue!(f16, F16); float_pvalue!(f32, F32); float_pvalue!(f64, F64); +/// Implements [`From`] for [`PValue`]. macro_rules! impl_pvalue { ($T:ty, $PT:tt) => { impl From<$T> for PValue { @@ -406,6 +429,7 @@ impl Display for PValue { } } +/// Coercion trait for widening or reinterpreting a [`PValue`] into a concrete type. pub(super) trait CoercePValue: Sized { /// Coerce value from a compatible bit representation using into given type. /// @@ -414,6 +438,7 @@ pub(super) trait CoercePValue: Sized { fn coerce(value: PValue) -> VortexResult; } +/// Implements [`CoercePValue`] for an integer type. macro_rules! 
int_coerce { ($T:ty) => { impl CoercePValue for $T { @@ -534,400 +559,3 @@ impl CoercePValue for f64 { } } } - -#[cfg(test)] -mod test { - use std::cmp::Ordering; - - use vortex_dtype::FromPrimitiveOrF16; - use vortex_dtype::PType; - use vortex_dtype::ToBytes; - use vortex_dtype::half::f16; - use vortex_utils::aliases::hash_set::HashSet; - - use crate::PValue; - use crate::pvalue::CoercePValue; - - #[test] - pub fn test_is_instance_of() { - assert!(PValue::U8(10).is_instance_of(&PType::U8)); - assert!(!PValue::U8(10).is_instance_of(&PType::U16)); - assert!(!PValue::U8(10).is_instance_of(&PType::I8)); - assert!(!PValue::U8(10).is_instance_of(&PType::F16)); - - assert!(PValue::I8(10).is_instance_of(&PType::I8)); - assert!(!PValue::I8(10).is_instance_of(&PType::I16)); - assert!(!PValue::I8(10).is_instance_of(&PType::U8)); - assert!(!PValue::I8(10).is_instance_of(&PType::F16)); - - assert!(PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F16)); - assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F32)); - assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::U16)); - assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::I16)); - } - - #[test] - fn test_compare_different_types() { - assert_eq!( - PValue::I8(4).partial_cmp(&PValue::I8(5)), - Some(Ordering::Less) - ); - assert_eq!( - PValue::I8(4).partial_cmp(&PValue::I64(5)), - Some(Ordering::Less) - ); - } - - #[test] - fn test_hash() { - let set = HashSet::from([ - PValue::U8(1), - PValue::U16(1), - PValue::U32(1), - PValue::U64(1), - PValue::I8(1), - PValue::I16(1), - PValue::I32(1), - PValue::I64(1), - PValue::I8(-1), - PValue::I16(-1), - PValue::I32(-1), - PValue::I64(-1), - ]); - assert_eq!(set.len(), 2); - } - - #[test] - fn test_zero_values() { - assert_eq!(PValue::zero(PType::U8), PValue::U8(0)); - assert_eq!(PValue::zero(PType::U16), PValue::U16(0)); - assert_eq!(PValue::zero(PType::U32), PValue::U32(0)); - assert_eq!(PValue::zero(PType::U64), PValue::U64(0)); - assert_eq!(PValue::zero(PType::I8), PValue::I8(0)); - assert_eq!(PValue::zero(PType::I16), PValue::I16(0)); - assert_eq!(PValue::zero(PType::I32), PValue::I32(0)); - assert_eq!(PValue::zero(PType::I64), PValue::I64(0)); - assert_eq!(PValue::zero(PType::F16), PValue::F16(f16::from_f32(0.0))); - assert_eq!(PValue::zero(PType::F32), PValue::F32(0.0)); - assert_eq!(PValue::zero(PType::F64), PValue::F64(0.0)); - } - - #[test] - fn test_ptype() { - assert_eq!(PValue::U8(10).ptype(), PType::U8); - assert_eq!(PValue::U16(10).ptype(), PType::U16); - assert_eq!(PValue::U32(10).ptype(), PType::U32); - assert_eq!(PValue::U64(10).ptype(), PType::U64); - assert_eq!(PValue::I8(10).ptype(), PType::I8); - assert_eq!(PValue::I16(10).ptype(), PType::I16); - assert_eq!(PValue::I32(10).ptype(), PType::I32); - assert_eq!(PValue::I64(10).ptype(), PType::I64); - assert_eq!(PValue::F16(f16::from_f32(10.0)).ptype(), PType::F16); - assert_eq!(PValue::F32(10.0).ptype(), PType::F32); - assert_eq!(PValue::F64(10.0).ptype(), PType::F64); - } - - #[test] - fn test_reinterpret_cast_same_type() { - let value = PValue::U32(42); - assert_eq!(value.reinterpret_cast(PType::U32), value); - } - - #[test] - fn test_reinterpret_cast_u8_i8() { - let value = PValue::U8(255); - let casted = value.reinterpret_cast(PType::I8); - assert_eq!(casted, PValue::I8(-1)); - } - - #[test] - fn test_reinterpret_cast_u16_types() { - let value = PValue::U16(12345); - - // U16 -> I16 - let as_i16 = value.reinterpret_cast(PType::I16); - assert_eq!(as_i16, PValue::I16(12345)); - - // U16 -> F16 - 
let as_f16 = value.reinterpret_cast(PType::F16); - assert_eq!(as_f16, PValue::F16(f16::from_bits(12345))); - } - - #[test] - fn test_reinterpret_cast_u32_types() { - let value = PValue::U32(0x3f800000); // 1.0 in float bits - - // U32 -> F32 - let as_f32 = value.reinterpret_cast(PType::F32); - assert_eq!(as_f32, PValue::F32(1.0)); - - // U32 -> I32 - let value2 = PValue::U32(0x80000000); - let as_i32 = value2.reinterpret_cast(PType::I32); - assert_eq!(as_i32, PValue::I32(i32::MIN)); - } - - #[test] - fn test_reinterpret_cast_f32_to_u32() { - let value = PValue::F32(1.0); - let as_u32 = value.reinterpret_cast(PType::U32); - assert_eq!(as_u32, PValue::U32(0x3f800000)); - } - - #[test] - fn test_reinterpret_cast_f64_to_i64() { - let value = PValue::F64(1.0); - let as_i64 = value.reinterpret_cast(PType::I64); - assert_eq!(as_i64, PValue::I64(0x3ff0000000000000_i64)); - } - - #[test] - #[should_panic(expected = "Cannot reinterpret cast between types of different widths")] - fn test_reinterpret_cast_different_widths() { - let value = PValue::U8(42); - let _ = value.reinterpret_cast(PType::U16); - } - - #[test] - fn test_as_primitive_conversions() { - // Test as_u8 - assert_eq!(PValue::U8(42).as_u8(), Some(42)); - assert_eq!(PValue::I8(42).as_u8(), Some(42)); - assert_eq!(PValue::U16(255).as_u8(), Some(255)); - assert_eq!(PValue::U16(256).as_u8(), None); // Overflow - - // Test as_i32 - assert_eq!(PValue::I32(42).as_i32(), Some(42)); - assert_eq!(PValue::U32(42).as_i32(), Some(42)); - assert_eq!(PValue::I64(42).as_i32(), Some(42)); - assert_eq!(PValue::U64(u64::MAX).as_i32(), None); // Overflow - - // Test as_f64 - assert_eq!(PValue::F64(42.5).as_f64(), Some(42.5)); - assert_eq!(PValue::F32(42.5).as_f64(), Some(42.5f64)); - assert_eq!(PValue::I32(42).as_f64(), Some(42.0)); - } - - #[test] - fn test_try_from_pvalue_integers() { - // Test u8 conversion - assert_eq!(u8::try_from(PValue::U8(42)).unwrap(), 42); - assert_eq!(u8::try_from(PValue::I8(42)).unwrap(), 42); - assert!(u8::try_from(PValue::I8(-1)).is_err()); - assert!(u8::try_from(PValue::U16(256)).is_err()); - - // Test i32 conversion - assert_eq!(i32::try_from(PValue::I32(42)).unwrap(), 42); - assert_eq!(i32::try_from(PValue::I16(-100)).unwrap(), -100); - assert!(i32::try_from(PValue::U64(u64::MAX)).is_err()); - - // Float to int should fail - assert!(i32::try_from(PValue::F32(42.5)).is_err()); - } - - #[test] - fn test_try_from_pvalue_floats() { - // Test f32 conversion - assert_eq!(f32::try_from(PValue::F32(42.5)).unwrap(), 42.5); - assert_eq!(f32::try_from(PValue::I32(42)).unwrap(), 42.0); - assert_eq!(f32::try_from(PValue::U8(255)).unwrap(), 255.0); - - // Test f64 conversion - assert_eq!(f64::try_from(PValue::F64(42.5)).unwrap(), 42.5); - assert_eq!(f64::try_from(PValue::F32(42.5)).unwrap(), 42.5f64); - assert_eq!(f64::try_from(PValue::I64(-100)).unwrap(), -100.0); - } - - #[test] - fn test_from_usize() { - let value: PValue = 42usize.into(); - assert_eq!(value, PValue::U64(42)); - - let max_value: PValue = usize::MAX.into(); - assert_eq!(max_value, PValue::U64(usize::MAX as u64)); - } - - #[test] - fn test_equality_cross_types() { - // Same numeric value, different types - assert_eq!(PValue::U8(42), PValue::U16(42)); - assert_eq!(PValue::U8(42), PValue::U32(42)); - assert_eq!(PValue::U8(42), PValue::U64(42)); - assert_eq!(PValue::I8(42), PValue::I16(42)); - assert_eq!(PValue::I8(42), PValue::I32(42)); - assert_eq!(PValue::I8(42), PValue::I64(42)); - - // Unsigned vs signed with same value (they compare equal even though different 
categories) - assert_eq!(PValue::U8(42), PValue::I8(42)); - assert_eq!(PValue::U32(42), PValue::I32(42)); - - // Float equality - assert_eq!(PValue::F32(42.0), PValue::F32(42.0)); - assert_eq!(PValue::F64(42.0), PValue::F64(42.0)); - assert_ne!(PValue::F32(42.0), PValue::F64(42.0)); // Different types - - // Float vs int should not be equal - assert_ne!(PValue::F32(42.0), PValue::I32(42)); - } - - #[test] - fn test_partial_ord_cross_types() { - // Unsigned comparisons - assert_eq!( - PValue::U8(10).partial_cmp(&PValue::U16(20)), - Some(Ordering::Less) - ); - assert_eq!( - PValue::U32(30).partial_cmp(&PValue::U8(20)), - Some(Ordering::Greater) - ); - - // Signed comparisons - assert_eq!( - PValue::I8(-10).partial_cmp(&PValue::I64(0)), - Some(Ordering::Less) - ); - assert_eq!( - PValue::I32(10).partial_cmp(&PValue::I16(10)), - Some(Ordering::Equal) - ); - - // Float comparisons (same type only) - assert_eq!( - PValue::F32(1.0).partial_cmp(&PValue::F32(2.0)), - Some(Ordering::Less) - ); - assert_eq!( - PValue::F64(2.0).partial_cmp(&PValue::F64(1.0)), - Some(Ordering::Greater) - ); - - // Cross-category comparisons - unsigned vs signed work, float vs int don't - assert_eq!( - PValue::U32(42).partial_cmp(&PValue::I32(42)), - Some(Ordering::Equal) - ); // Actually works - assert_eq!(PValue::F32(42.0).partial_cmp(&PValue::I32(42)), None); - assert_eq!(PValue::F32(42.0).partial_cmp(&PValue::F64(42.0)), None); - } - - #[test] - fn test_to_le_bytes() { - assert_eq!(PValue::U8(0x12).to_le_bytes(), &[0x12]); - assert_eq!(PValue::U16(0x1234).to_le_bytes(), &[0x34, 0x12]); - assert_eq!( - PValue::U32(0x12345678).to_le_bytes(), - &[0x78, 0x56, 0x34, 0x12] - ); - - assert_eq!(PValue::I8(-1).to_le_bytes(), &[0xFF]); - assert_eq!(PValue::I16(-1).to_le_bytes(), &[0xFF, 0xFF]); - - let f32_bytes = PValue::F32(1.0).to_le_bytes(); - assert_eq!(f32_bytes.len(), 4); - - let f64_bytes = PValue::F64(1.0).to_le_bytes(); - assert_eq!(f64_bytes.len(), 8); - } - - #[test] - fn test_f16_special_values() { - // Test F16 NaN handling - let nan = f16::NAN; - let nan_value = PValue::F16(nan); - assert!(nan_value.as_f16().unwrap().is_nan()); - - // Test F16 infinity - let inf = f16::INFINITY; - let inf_value = PValue::F16(inf); - assert!(inf_value.as_f16().unwrap().is_infinite()); - - // Test F16 comparison with NaN - assert_eq!( - PValue::F16(nan).partial_cmp(&PValue::F16(nan)), - Some(Ordering::Equal) - ); - } - - #[test] - fn test_coerce_pvalue() { - // Test integer coercion - assert_eq!(u32::coerce(PValue::U16(42)).unwrap(), 42u32); - assert_eq!(i64::coerce(PValue::I32(-42)).unwrap(), -42i64); - - // Test float coercion from bits - assert_eq!(f32::coerce(PValue::U32(0x3f800000)).unwrap(), 1.0f32); - assert_eq!( - f64::coerce(PValue::U64(0x3ff0000000000000)).unwrap(), - 1.0f64 - ); - } - - #[test] - fn test_coerce_f16_beyond_u16_max() { - // Test U32 to f16 coercion within valid range - assert!(f16::coerce(PValue::U32(u16::MAX as u32)).is_ok()); - assert_eq!( - f16::coerce(PValue::U32(0x3C00)).unwrap(), - f16::from_bits(0x3C00) // 1.0 in f16 - ); - - // Test U32 to f16 coercion beyond u16::MAX - should fail - assert!(f16::coerce(PValue::U32((u16::MAX as u32) + 1)).is_err()); - assert!(f16::coerce(PValue::U32(u32::MAX)).is_err()); - - // Test U64 to f16 coercion within valid range - assert!(f16::coerce(PValue::U64(u16::MAX as u64)).is_ok()); - assert_eq!( - f16::coerce(PValue::U64(0x3C00)).unwrap(), - f16::from_bits(0x3C00) // 1.0 in f16 - ); - - // Test U64 to f16 coercion beyond u16::MAX - should fail - 
assert!(f16::coerce(PValue::U64((u16::MAX as u64) + 1)).is_err()); - assert!(f16::coerce(PValue::U64(u32::MAX as u64)).is_err()); - assert!(f16::coerce(PValue::U64(u64::MAX)).is_err()); - } - - #[test] - fn test_coerce_f32_beyond_u32_max() { - // Test U64 to f32 coercion within valid range - assert!(f32::coerce(PValue::U64(u32::MAX as u64)).is_ok()); - assert_eq!( - f32::coerce(PValue::U64(0x3f800000)).unwrap(), - 1.0f32 // 0x3f800000 is 1.0 in f32 - ); - - // Test U64 to f32 coercion beyond u32::MAX - should fail - assert!(f32::coerce(PValue::U64((u32::MAX as u64) + 1)).is_err()); - assert!(f32::coerce(PValue::U64(u64::MAX)).is_err()); - - // Test smaller types still work - assert!(f32::coerce(PValue::U8(255)).is_ok()); - assert!(f32::coerce(PValue::U16(u16::MAX)).is_ok()); - assert!(f32::coerce(PValue::U32(u32::MAX)).is_ok()); - } - - #[test] - fn test_coerce_f64_all_unsigned() { - // Test f64 can accept all unsigned integer values as bit patterns - assert!(f64::coerce(PValue::U8(u8::MAX)).is_ok()); - assert!(f64::coerce(PValue::U16(u16::MAX)).is_ok()); - assert!(f64::coerce(PValue::U32(u32::MAX)).is_ok()); - assert!(f64::coerce(PValue::U64(u64::MAX)).is_ok()); - - // Verify specific bit patterns - assert_eq!( - f64::coerce(PValue::U64(0x3ff0000000000000)).unwrap(), - 1.0f64 // 0x3ff0000000000000 is 1.0 in f64 - ); - } - - #[test] - fn test_f16_nans_equal() { - let nan1 = f16::from_le_bytes([154, 253]); - assert!(nan1.is_nan()); - let nan3 = f16::from_f16(nan1).unwrap(); - assert_eq!(nan1.to_bits(), nan3.to_bits(),); - } -} diff --git a/vortex-scalar/src/typed_view/primitive/scalar.rs b/vortex-scalar/src/typed_view/primitive/scalar.rs new file mode 100644 index 00000000000..da2fa8ccd12 --- /dev/null +++ b/vortex-scalar/src/typed_view/primitive/scalar.rs @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! [`PrimitiveScalar`] typed view implementation. + +use std::any::type_name; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::fmt::Display; +use std::fmt::Formatter; +use std::ops::Add; +use std::ops::Sub; + +use num_traits::CheckedAdd; +use num_traits::CheckedDiv; +use num_traits::CheckedMul; +use num_traits::CheckedSub; +use vortex_dtype::DType; +use vortex_dtype::FromPrimitiveOrF16; +use vortex_dtype::NativePType; +use vortex_dtype::PType; +use vortex_dtype::match_each_native_ptype; +use vortex_error::VortexError; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_error::vortex_panic; + +use super::pvalue::CoercePValue; +use crate::NumericOperator; +use crate::PValue; +use crate::Scalar; +use crate::ScalarValue; + +/// A scalar value representing a primitive type. +/// +/// This type provides a view into a primitive scalar value of any primitive type +/// (integers, floats) with various bit widths. +#[derive(Debug, Clone, Copy, Hash)] +pub struct PrimitiveScalar<'a> { + /// The data type of this scalar. + dtype: &'a DType, + /// The primitive type. + ptype: PType, + /// The primitive value, or [`None`] if null. 
+ pvalue: Option<PValue>, +} + +impl Display for PrimitiveScalar<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self.pvalue { + None => write!(f, "null"), + Some(pv) => write!(f, "{pv}"), + } + } +} + +impl PartialEq for PrimitiveScalar<'_> { + fn eq(&self, other: &Self) -> bool { + self.dtype.eq_ignore_nullability(other.dtype) && self.pvalue == other.pvalue + } +} + +impl Eq for PrimitiveScalar<'_> {} + +/// Ord is not implemented since it's undefined for different PTypes +impl PartialOrd for PrimitiveScalar<'_> { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + if !self.dtype.eq_ignore_nullability(other.dtype) { + return None; + } + self.pvalue.partial_cmp(&other.pvalue) + } +} + +impl<'a> PrimitiveScalar<'a> { + /// Creates a new primitive scalar from a data type and scalar value. + /// + /// # Errors + /// + /// Returns an error if the data type is not a primitive type or if the value + /// cannot be converted to the expected primitive type. + pub fn try_new(dtype: &'a DType, value: Option<&ScalarValue>) -> VortexResult<Self> { + let ptype = PType::try_from(dtype)?; + + // Read the serialized value into the correct PValue. + // The serialized form may come back over the wire as e.g. any integer type. + let pvalue = match value { + None => None, + Some(v) => { + let pv = v.as_primitive(); + match_each_native_ptype!(ptype, |T| { Some(PValue::from(<T as CoercePValue>::coerce(*pv)?)) }) + } + }; + + Ok(Self { + dtype, + ptype, + pvalue, + }) + } + + /// Returns the data type of this primitive scalar. + #[inline] + pub fn dtype(&self) -> &'a DType { + self.dtype + } + + /// Returns the primitive type of this scalar. + #[inline] + pub fn ptype(&self) -> PType { + self.ptype + } + + /// Returns the primitive value, or None if null. + #[inline] + pub fn pvalue(&self) -> Option<PValue> { + self.pvalue + } + + // TODO(connor): This should probably be deprecated for `try_typed_value`. + /// Returns the value as a specific native primitive type. + /// + /// Returns `None` if the scalar is null, otherwise returns `Some(value)` where + /// value is the underlying primitive value cast to the requested type `T`. + /// + /// # Panics + /// + /// Panics if the primitive type of this scalar does not match the requested type. + pub fn typed_value<T: NativePType>(&self) -> Option<T> { + assert_eq!( + self.ptype, + T::PTYPE, + "Attempting to read {} scalar as {}", + self.ptype, + T::PTYPE + ); + + // TODO(connor): This should really use `cast_opt`... + self.pvalue.map(|pv| pv.cast::<T>()) + } + + /// Returns the value as a specific native primitive type. + /// + /// Returns `Ok(None)` if the scalar is null, otherwise returns `Ok(Some(value))` where + /// value is the underlying primitive value cast to the requested type `T`. + /// + /// # Errors + /// + /// Returns an error if the primitive type of this scalar does not match the requested type. + pub fn try_typed_value<T: NativePType>(&self) -> VortexResult<Option<T>> { + vortex_ensure!( + self.ptype == T::PTYPE, + "Attempting to read {} scalar as {}", + self.ptype, + T::PTYPE + ); + + // TODO(connor): This should really use `cast_opt`... + Ok(self.pvalue.map(|pv| pv.cast::<T>())) + } + + /// Casts this scalar to the given `dtype`.
+ pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> { + let ptype = PType::try_from(dtype)?; + let pvalue = self + .pvalue + .vortex_expect("nullness handled in Scalar::cast"); + Ok(match_each_native_ptype!(ptype, |Q| { + Scalar::primitive( + pvalue + .cast_opt::<Q>() + .ok_or_else(|| vortex_err!("Cannot cast {} to {}", self.ptype, dtype))?, + dtype.nullability(), + ) + })) + } + + /// Returns true if the scalar is nan. + pub fn is_nan(&self) -> bool { + self.pvalue.as_ref().is_some_and(|p| p.is_nan()) + } + + /// Returns whether this primitive value is zero, or `None` if null. + pub fn is_zero(&self) -> Option<bool> { + self.pvalue.map(|v| v.is_zero()) + } + + /// Attempts to extract the primitive value as the given type. + /// + /// # Panics + /// + /// Panics if the cast fails due to overflow or type incompatibility. See also + /// `as_opt` for the checked version that does not panic. + /// + /// # Examples + /// + /// ```should_panic + /// # use vortex_dtype::{DType, PType}; + /// # use vortex_scalar::Scalar; + /// let wide = Scalar::primitive(1000i32, false.into()); + /// + /// // This succeeds + /// let narrow = wide.as_primitive().as_::<i16>(); + /// assert_eq!(narrow, Some(1000i16)); + /// + /// // This also succeeds + /// let null = Scalar::null(DType::Primitive(PType::I16, true.into())); + /// assert_eq!(null.as_primitive().as_::<i16>(), None); + /// + /// // This will panic, because 1000 does not fit in i8 + /// wide.as_primitive().as_::<i8>(); + /// ``` + pub fn as_<T: FromPrimitiveOrF16>(&self) -> Option<T> { + self.as_opt::<T>().unwrap_or_else(|| { + vortex_panic!( + "cast {} to {}: value out of range", + self.ptype, + type_name::<T>() + ) + }) + } + + /// Returns the inner value cast to the desired type. + /// + /// If the cast fails, `None` is returned. If the scalar represents a null, `Some(None)` + /// is returned. Otherwise, `Some(Some(T))` is returned for a successful non-null conversion.
+ /// + /// + /// # Examples + /// + /// ``` + /// # use vortex_dtype::{DType, PType}; + /// # use vortex_scalar::Scalar; + /// + /// // Non-null values + /// let scalar = Scalar::primitive(100i32, false.into()); + /// let primitive = scalar.as_primitive(); + /// assert_eq!(primitive.as_opt::<i8>(), Some(Some(100i8))); + /// + /// // Null value + /// let scalar = Scalar::null(DType::Primitive(PType::I32, true.into())); + /// let primitive = scalar.as_primitive(); + /// assert_eq!(primitive.as_opt::<i8>(), Some(None)); + /// + /// // Failed conversion: 1000 cannot fit in an i8 + /// let scalar = Scalar::primitive(1000i32, false.into()); + /// let primitive = scalar.as_primitive(); + /// assert_eq!(primitive.as_opt::<i8>(), None); + /// ``` + pub fn as_opt<T: FromPrimitiveOrF16>(&self) -> Option<Option<T>> { + if let Some(pv) = self.pvalue { + match pv { + PValue::U8(v) => T::from_u8(v), + PValue::U16(v) => T::from_u16(v), + PValue::U32(v) => T::from_u32(v), + PValue::U64(v) => T::from_u64(v), + PValue::I8(v) => T::from_i8(v), + PValue::I16(v) => T::from_i16(v), + PValue::I32(v) => T::from_i32(v), + PValue::I64(v) => T::from_i64(v), + PValue::F16(v) => T::from_f16(v), + PValue::F32(v) => T::from_f32(v), + PValue::F64(v) => T::from_f64(v), + } + .map(Some) + } else { + Some(None) + } + } +} + +impl Sub for PrimitiveScalar<'_> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + self.checked_sub(&rhs) + .vortex_expect("PrimitiveScalar subtract: overflow or underflow") + } +} + +impl CheckedSub for PrimitiveScalar<'_> { + fn checked_sub(&self, rhs: &Self) -> Option<Self> { + self.checked_binary_numeric(rhs, NumericOperator::Sub) + } +} + +impl Add for PrimitiveScalar<'_> { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + self.checked_add(&rhs) + .vortex_expect("PrimitiveScalar add: overflow or underflow") + } +} + +impl CheckedAdd for PrimitiveScalar<'_> { + fn checked_add(&self, rhs: &Self) -> Option<Self> { + self.checked_binary_numeric(rhs, NumericOperator::Add) + } +} + +impl<'a> PrimitiveScalar<'a> { + /// Apply the (checked) operator to self and other using SQL-style null semantics. + /// + /// If the operation overflows, `None` is returned (for integral types only). + /// + /// Note: Floating-point operations cannot overflow in the traditional sense. + /// Instead, they may return `Some(Inf)` or `Some(NaN)` for operations that + /// would overflow or are undefined (e.g., `0.0 / 0.0`). + /// + /// If the types are incompatible (ignoring nullability), this function panics. + /// + /// If either value is null, the result is null. + pub fn checked_binary_numeric( + &self, + other: &PrimitiveScalar<'a>, + op: NumericOperator, + ) -> Option<PrimitiveScalar<'a>> { + if !self.dtype().eq_ignore_nullability(other.dtype()) { + vortex_panic!("types must match: {} {}", self.dtype(), other.dtype()); + } + let result_dtype = if self.dtype().is_nullable() { + self.dtype() + } else { + other.dtype() + }; + let ptype = self.ptype(); + + match_each_native_ptype!( + self.ptype(), + integral: |P| { + self.checked_integral_numeric_operator::<P>
(other, result_dtype, ptype, op) + }, + floating: |P| { + let lhs = self.typed_value::<P>
(); + let rhs = other.typed_value::<P>
(); + let value_or_null = match (lhs, rhs) { + (_, None) | (None, _) => None, + (Some(lhs), Some(rhs)) => match op { + NumericOperator::Add => Some(lhs + rhs), + NumericOperator::Sub => Some(lhs - rhs), + NumericOperator::RSub => Some(rhs - lhs), + NumericOperator::Mul => Some(lhs * rhs), + NumericOperator::Div => Some(lhs / rhs), + NumericOperator::RDiv => Some(rhs / lhs), + } + }; + Some(Self { dtype: result_dtype, ptype, pvalue: value_or_null.map(PValue::from) }) + } + ) + } + + /// Applies a checked arithmetic operation between two integral primitive scalars. + fn checked_integral_numeric_operator< + P: NativePType + + TryFrom<PValue, Error = VortexError> + + CheckedSub + + CheckedAdd + + CheckedMul + + CheckedDiv, + >( + &self, + other: &PrimitiveScalar<'a>, + result_dtype: &'a DType, + ptype: PType, + op: NumericOperator, + ) -> Option<PrimitiveScalar<'a>> + where + PValue: From<P>
, + { + let lhs = self.typed_value::<P>
(); + let rhs = other.typed_value::<P>
(); + let value_or_null_or_overflow = match (lhs, rhs) { + (_, None) | (None, _) => Some(None), + (Some(lhs), Some(rhs)) => match op { + NumericOperator::Add => lhs.checked_add(&rhs).map(Some), + NumericOperator::Sub => lhs.checked_sub(&rhs).map(Some), + NumericOperator::RSub => rhs.checked_sub(&lhs).map(Some), + NumericOperator::Mul => lhs.checked_mul(&rhs).map(Some), + NumericOperator::Div => lhs.checked_div(&rhs).map(Some), + NumericOperator::RDiv => rhs.checked_div(&lhs).map(Some), + }, + }; + + value_or_null_or_overflow.map(|value_or_null| Self { + dtype: result_dtype, + ptype, + pvalue: value_or_null.map(PValue::from), + }) + } +} diff --git a/vortex-scalar/src/typed_view/primitive/tests.rs b/vortex-scalar/src/typed_view/primitive/tests.rs new file mode 100644 index 00000000000..69b71587d71 --- /dev/null +++ b/vortex-scalar/src/typed_view/primitive/tests.rs @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::cmp::Ordering; + +use num_traits::CheckedSub; +use rstest::rstest; +use vortex_dtype::DType; +use vortex_dtype::FromPrimitiveOrF16; +use vortex_dtype::Nullability; +use vortex_dtype::PType; +use vortex_dtype::ToBytes; +use vortex_dtype::half::f16; +use vortex_error::VortexExpect; +use vortex_utils::aliases::hash_set::HashSet; + +use super::pvalue::CoercePValue; +use super::*; +use crate::PValue; +use crate::PrimitiveScalar; +use crate::ScalarValue; + +#[test] +fn test_integer_subtract() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(5)); + let value2 = ScalarValue::Primitive(PValue::I32(4)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2); + let value_or_null_or_type_error = pscalar_or_overflow.unwrap().as_::(); + assert_eq!(value_or_null_or_type_error.unwrap(), 1); + + assert_eq!((p_scalar1 - p_scalar2).as_::().unwrap(), 1); +} + +#[test] +#[should_panic(expected = "PrimitiveScalar subtract: overflow or underflow")] +fn test_integer_subtract_overflow() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(i32::MIN)); + let value2 = ScalarValue::Primitive(PValue::I32(i32::MAX)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + let _ = p_scalar1 - p_scalar2; +} + +#[test] +fn test_float_subtract() { + let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::F32(1.99f32)); + let value2 = ScalarValue::Primitive(PValue::F32(1.0f32)); + let p_scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let p_scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + let pscalar_or_overflow = p_scalar1.checked_sub(&p_scalar2).unwrap(); + let value_or_null_or_type_error = pscalar_or_overflow.as_::(); + assert_eq!(value_or_null_or_type_error.unwrap(), 0.99f32); + + assert_eq!((p_scalar1 - p_scalar2).as_::().unwrap(), 0.99f32); +} + +#[test] +fn test_primitive_scalar_equality() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(42)); + let value2 = ScalarValue::Primitive(PValue::I32(42)); + let value3 = ScalarValue::Primitive(PValue::I32(43)); + let scalar1 = 
PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + let scalar3 = PrimitiveScalar::try_new(&dtype, Some(&value3)).unwrap(); + + assert_eq!(scalar1, scalar2); + assert_ne!(scalar1, scalar3); +} + +#[test] +fn test_primitive_scalar_partial_ord() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + assert!(scalar1 < scalar2); + assert!(scalar2 > scalar1); + assert_eq!(scalar1.partial_cmp(&scalar1), Some(Ordering::Equal)); +} + +#[test] +fn test_primitive_scalar_null_handling() { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let null_scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); + + assert_eq!(null_scalar.pvalue(), None); + assert_eq!(null_scalar.typed_value::(), None); +} + +#[test] +fn test_typed_value_correct_type() { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let value = ScalarValue::Primitive(PValue::F64(3.5)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + assert_eq!(scalar.typed_value::(), Some(3.5)); +} + +#[test] +#[should_panic(expected = "Attempting to read")] +fn test_typed_value_wrong_type() { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let value = ScalarValue::Primitive(PValue::F64(3.5)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + let _ = scalar.typed_value::(); +} + +#[rstest] +#[case(PType::I8, 127i32, PType::I16, true)] +#[case(PType::I8, 127i32, PType::I32, true)] +#[case(PType::I8, 127i32, PType::I64, true)] +#[case(PType::U8, 255i32, PType::U16, true)] +#[case(PType::U8, 255i32, PType::U32, true)] +#[case(PType::I32, 42i32, PType::F32, true)] +#[case(PType::I32, 42i32, PType::F64, true)] +// Overflow cases +#[case(PType::I32, 300i32, PType::U8, false)] +#[case(PType::I32, -1i32, PType::U32, false)] +#[case(PType::I32, 256i32, PType::I8, false)] +#[case(PType::U16, 65535i32, PType::I8, false)] +fn test_primitive_cast( + #[case] source_type: PType, + #[case] source_value: i32, + #[case] target_type: PType, + #[case] should_succeed: bool, +) { + let source_pvalue = match source_type { + PType::I8 => PValue::I8(i8::try_from(source_value).vortex_expect("cannot cast")), + PType::U8 => PValue::U8(u8::try_from(source_value).vortex_expect("cannot cast")), + PType::U16 => PValue::U16(u16::try_from(source_value).vortex_expect("cannot cast")), + PType::I32 => PValue::I32(source_value), + _ => unreachable!("Test case uses unexpected source type"), + }; + + let dtype = DType::Primitive(source_type, Nullability::NonNullable); + let value = ScalarValue::Primitive(source_pvalue); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + let target_dtype = DType::Primitive(target_type, Nullability::NonNullable); + let result = scalar.cast(&target_dtype); + + if should_succeed { + assert!( + result.is_ok(), + "Cast from {:?} to {:?} should succeed", + source_type, + target_type + ); + } else { + assert!( + result.is_err(), + "Cast from {:?} to {:?} should fail due to overflow", + source_type, + target_type + ); + } +} + +#[test] +fn test_as_conversion_success() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value = 
ScalarValue::Primitive(PValue::I32(42)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + assert_eq!(scalar.as_::(), Some(42i64)); + assert_eq!(scalar.as_::(), Some(42.0)); +} + +#[test] +fn test_as_conversion_overflow() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value = ScalarValue::Primitive(PValue::I32(-1)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + // Converting -1 to u32 should fail + let result = scalar.as_opt::(); + assert!(result.is_none()); +} + +#[test] +fn test_as_conversion_null() { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); + + assert_eq!(scalar.as_::(), None); + assert_eq!(scalar.as_::(), None); +} + +#[test] +fn test_numeric_operator_swap() { + assert_eq!(NumericOperator::Add.swap(), NumericOperator::Add); + assert_eq!(NumericOperator::Sub.swap(), NumericOperator::RSub); + assert_eq!(NumericOperator::RSub.swap(), NumericOperator::Sub); + assert_eq!(NumericOperator::Mul.swap(), NumericOperator::Mul); + assert_eq!(NumericOperator::Div.swap(), NumericOperator::RDiv); + assert_eq!(NumericOperator::RDiv.swap(), NumericOperator::Div); +} + +#[test] +fn test_checked_binary_numeric_add() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Add) + .unwrap(); + assert_eq!(result.typed_value::(), Some(30)); +} + +#[test] +fn test_checked_binary_numeric_overflow() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(i32::MAX)); + let value2 = ScalarValue::Primitive(PValue::I32(1)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + // Add should overflow and return None + let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add); + assert!(result.is_none()); +} + +#[test] +fn test_checked_binary_numeric_with_null() { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let value = ScalarValue::Primitive(PValue::I32(10)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + let null_scalar = PrimitiveScalar::try_new(&dtype, None).unwrap(); + + // Operation with null should return null + let result = scalar1 + .checked_binary_numeric(&null_scalar, NumericOperator::Add) + .unwrap(); + assert_eq!(result.pvalue(), None); +} + +#[test] +fn test_checked_binary_numeric_mul() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(5)); + let value2 = ScalarValue::Primitive(PValue::I32(6)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Mul) + .unwrap(); + assert_eq!(result.typed_value::(), Some(30)); +} + +#[test] +fn test_checked_binary_numeric_div() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(20)); + let value2 = 
ScalarValue::Primitive(PValue::I32(4)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Div) + .unwrap(); + assert_eq!(result.typed_value::(), Some(5)); +} + +#[test] +fn test_checked_binary_numeric_rdiv() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(4)); + let value2 = ScalarValue::Primitive(PValue::I32(20)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + // RDiv means right / left, so 20 / 4 = 5 + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::RDiv) + .unwrap(); + assert_eq!(result.typed_value::(), Some(5)); +} + +#[test] +fn test_checked_binary_numeric_div_by_zero() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::I32(0)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + // Division by zero should return None for integers + let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div); + assert!(result.is_none()); +} + +#[test] +fn test_checked_binary_numeric_float_ops() { + let dtype = DType::Primitive(PType::F32, Nullability::NonNullable); + let value1 = ScalarValue::Primitive(PValue::F32(10.0)); + let value2 = ScalarValue::Primitive(PValue::F32(2.5)); + let scalar1 = PrimitiveScalar::try_new(&dtype, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype, Some(&value2)).unwrap(); + + // Test all operations with floats + let add_result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Add) + .unwrap(); + assert_eq!(add_result.typed_value::(), Some(12.5)); + + let sub_result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Sub) + .unwrap(); + assert_eq!(sub_result.typed_value::(), Some(7.5)); + + let mul_result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Mul) + .unwrap(); + assert_eq!(mul_result.typed_value::(), Some(25.0)); + + let div_result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Div) + .unwrap(); + assert_eq!(div_result.typed_value::(), Some(4.0)); +} + +#[test] +fn test_from_primitive_or_f16() { + // Test f16 to f32 conversion + let f16_val = f16::from_f32(3.5); + assert!(f32::from_f16(f16_val).is_some()); + + // Test f16 to f64 conversion + assert!(f64::from_f16(f16_val).is_some()); + + // Test PValue::F16(f16) to integer conversion (should fail) + assert!(i32::try_from(PValue::from(f16_val)).is_err()); + assert!(u32::try_from(PValue::from(f16_val)).is_err()); +} + +#[test] +fn test_partial_ord_different_types() { + let dtype1 = DType::Primitive(PType::I32, Nullability::NonNullable); + let dtype2 = DType::Primitive(PType::F32, Nullability::NonNullable); + + let value1 = ScalarValue::Primitive(PValue::I32(10)); + let value2 = ScalarValue::Primitive(PValue::F32(10.0)); + let scalar1 = PrimitiveScalar::try_new(&dtype1, Some(&value1)).unwrap(); + let scalar2 = PrimitiveScalar::try_new(&dtype2, Some(&value2)).unwrap(); + + // Different types should not be comparable + assert_eq!(scalar1.partial_cmp(&scalar2), None); +} + +#[test] +fn test_scalar_value_from_usize() { + let value: 
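// Illustrative sketch (std-only, not the Vortex API; div_rdiv_sketch is a made-up name):
// the Div / RDiv distinction and the division-by-zero behaviour exercised by the tests
// above amount to applying the operator with optionally swapped operands, with integer
// division by zero reported as None rather than panicking.
fn div_rdiv_sketch(lhs: i32, rhs: i32, reversed: bool) -> Option<i32> {
    let (a, b) = if reversed { (rhs, lhs) } else { (lhs, rhs) };
    a.checked_div(b)
}
// div_rdiv_sketch(20, 4, false) == Some(5)   // Div:  lhs / rhs
// div_rdiv_sketch(4, 20, true)  == Some(5)   // RDiv: rhs / lhs
// div_rdiv_sketch(10, 0, false) == None      // division by zero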
ScalarValue = 42usize.into(); + assert!(matches!(value, ScalarValue::Primitive(PValue::U64(42)))); +} + +#[test] +fn test_getters() { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let value = ScalarValue::Primitive(PValue::I32(42)); + let scalar = PrimitiveScalar::try_new(&dtype, Some(&value)).unwrap(); + + assert_eq!(scalar.dtype(), &dtype); + assert_eq!(scalar.ptype(), PType::I32); + assert_eq!(scalar.pvalue(), Some(PValue::I32(42))); +} + +#[test] +pub fn test_is_instance_of() { + assert!(PValue::U8(10).is_instance_of(&PType::U8)); + assert!(!PValue::U8(10).is_instance_of(&PType::U16)); + assert!(!PValue::U8(10).is_instance_of(&PType::I8)); + assert!(!PValue::U8(10).is_instance_of(&PType::F16)); + + assert!(PValue::I8(10).is_instance_of(&PType::I8)); + assert!(!PValue::I8(10).is_instance_of(&PType::I16)); + assert!(!PValue::I8(10).is_instance_of(&PType::U8)); + assert!(!PValue::I8(10).is_instance_of(&PType::F16)); + + assert!(PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F16)); + assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::F32)); + assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::U16)); + assert!(!PValue::F16(f16::from_f32(10.0)).is_instance_of(&PType::I16)); +} + +#[test] +fn test_compare_different_types() { + assert_eq!( + PValue::I8(4).partial_cmp(&PValue::I8(5)), + Some(Ordering::Less) + ); + assert_eq!( + PValue::I8(4).partial_cmp(&PValue::I64(5)), + Some(Ordering::Less) + ); +} + +#[test] +fn test_hash() { + let set = HashSet::from([ + PValue::U8(1), + PValue::U16(1), + PValue::U32(1), + PValue::U64(1), + PValue::I8(1), + PValue::I16(1), + PValue::I32(1), + PValue::I64(1), + PValue::I8(-1), + PValue::I16(-1), + PValue::I32(-1), + PValue::I64(-1), + ]); + assert_eq!(set.len(), 2); +} + +#[test] +fn test_zero_values() { + assert_eq!(PValue::zero(&PType::U8), PValue::U8(0)); + assert_eq!(PValue::zero(&PType::U16), PValue::U16(0)); + assert_eq!(PValue::zero(&PType::U32), PValue::U32(0)); + assert_eq!(PValue::zero(&PType::U64), PValue::U64(0)); + assert_eq!(PValue::zero(&PType::I8), PValue::I8(0)); + assert_eq!(PValue::zero(&PType::I16), PValue::I16(0)); + assert_eq!(PValue::zero(&PType::I32), PValue::I32(0)); + assert_eq!(PValue::zero(&PType::I64), PValue::I64(0)); + assert_eq!(PValue::zero(&PType::F16), PValue::F16(f16::from_f32(0.0))); + assert_eq!(PValue::zero(&PType::F32), PValue::F32(0.0)); + assert_eq!(PValue::zero(&PType::F64), PValue::F64(0.0)); +} + +#[test] +fn test_ptype() { + assert_eq!(PValue::U8(10).ptype(), PType::U8); + assert_eq!(PValue::U16(10).ptype(), PType::U16); + assert_eq!(PValue::U32(10).ptype(), PType::U32); + assert_eq!(PValue::U64(10).ptype(), PType::U64); + assert_eq!(PValue::I8(10).ptype(), PType::I8); + assert_eq!(PValue::I16(10).ptype(), PType::I16); + assert_eq!(PValue::I32(10).ptype(), PType::I32); + assert_eq!(PValue::I64(10).ptype(), PType::I64); + assert_eq!(PValue::F16(f16::from_f32(10.0)).ptype(), PType::F16); + assert_eq!(PValue::F32(10.0).ptype(), PType::F32); + assert_eq!(PValue::F64(10.0).ptype(), PType::F64); +} + +#[test] +fn test_reinterpret_cast_same_type() { + let value = PValue::U32(42); + assert_eq!(value.reinterpret_cast(PType::U32), value); +} + +#[test] +fn test_reinterpret_cast_u8_i8() { + let value = PValue::U8(255); + let casted = value.reinterpret_cast(PType::I8); + assert_eq!(casted, PValue::I8(-1)); +} + +#[test] +fn test_reinterpret_cast_u16_types() { + let value = PValue::U16(12345); + + // U16 -> I16 + let as_i16 = value.reinterpret_cast(PType::I16); + 
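// Illustrative sketch (std-only, not the Vortex API): the reinterpret_cast tests keep the
// bit pattern and only change its interpretation; at equal widths this matches `as` casts
// between same-width integers and to_bits/from_bits on floats.
fn reinterpret_sketch() {
    assert_eq!(255u8 as i8, -1);                          // same bits, signed view
    assert_eq!(f32::from_bits(0x3f80_0000), 1.0f32);      // u32 bits viewed as f32
    assert_eq!(1.0f64.to_bits(), 0x3ff0_0000_0000_0000);  // f64 viewed as u64 bits
}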
assert_eq!(as_i16, PValue::I16(12345)); + + // U16 -> F16 + let as_f16 = value.reinterpret_cast(PType::F16); + assert_eq!(as_f16, PValue::F16(f16::from_bits(12345))); +} + +#[test] +fn test_reinterpret_cast_u32_types() { + let value = PValue::U32(0x3f800000); // 1.0 in float bits + + // U32 -> F32 + let as_f32 = value.reinterpret_cast(PType::F32); + assert_eq!(as_f32, PValue::F32(1.0)); + + // U32 -> I32 + let value2 = PValue::U32(0x80000000); + let as_i32 = value2.reinterpret_cast(PType::I32); + assert_eq!(as_i32, PValue::I32(i32::MIN)); +} + +#[test] +fn test_reinterpret_cast_f32_to_u32() { + let value = PValue::F32(1.0); + let as_u32 = value.reinterpret_cast(PType::U32); + assert_eq!(as_u32, PValue::U32(0x3f800000)); +} + +#[test] +fn test_reinterpret_cast_f64_to_i64() { + let value = PValue::F64(1.0); + let as_i64 = value.reinterpret_cast(PType::I64); + assert_eq!(as_i64, PValue::I64(0x3ff0000000000000_i64)); +} + +#[test] +#[should_panic(expected = "Cannot reinterpret cast between types of different widths")] +fn test_reinterpret_cast_different_widths() { + let value = PValue::U8(42); + let _ = value.reinterpret_cast(PType::U16); +} + +#[test] +fn test_as_primitive_conversions() { + // Test as_u8 + assert_eq!(PValue::U8(42).as_u8(), Some(42)); + assert_eq!(PValue::I8(42).as_u8(), Some(42)); + assert_eq!(PValue::U16(255).as_u8(), Some(255)); + assert_eq!(PValue::U16(256).as_u8(), None); // Overflow + + // Test as_i32 + assert_eq!(PValue::I32(42).as_i32(), Some(42)); + assert_eq!(PValue::U32(42).as_i32(), Some(42)); + assert_eq!(PValue::I64(42).as_i32(), Some(42)); + assert_eq!(PValue::U64(u64::MAX).as_i32(), None); // Overflow + + // Test as_f64 + assert_eq!(PValue::F64(42.5).as_f64(), Some(42.5)); + assert_eq!(PValue::F32(42.5).as_f64(), Some(42.5f64)); + assert_eq!(PValue::I32(42).as_f64(), Some(42.0)); +} + +#[test] +fn test_try_from_pvalue_integers() { + // Test u8 conversion + assert_eq!(u8::try_from(PValue::U8(42)).unwrap(), 42); + assert_eq!(u8::try_from(PValue::I8(42)).unwrap(), 42); + assert!(u8::try_from(PValue::I8(-1)).is_err()); + assert!(u8::try_from(PValue::U16(256)).is_err()); + + // Test i32 conversion + assert_eq!(i32::try_from(PValue::I32(42)).unwrap(), 42); + assert_eq!(i32::try_from(PValue::I16(-100)).unwrap(), -100); + assert!(i32::try_from(PValue::U64(u64::MAX)).is_err()); + + // Float to int should fail + assert!(i32::try_from(PValue::F32(42.5)).is_err()); +} + +#[test] +fn test_try_from_pvalue_floats() { + // Test f32 conversion + assert_eq!(f32::try_from(PValue::F32(42.5)).unwrap(), 42.5); + assert_eq!(f32::try_from(PValue::I32(42)).unwrap(), 42.0); + assert_eq!(f32::try_from(PValue::U8(255)).unwrap(), 255.0); + + // Test f64 conversion + assert_eq!(f64::try_from(PValue::F64(42.5)).unwrap(), 42.5); + assert_eq!(f64::try_from(PValue::F32(42.5)).unwrap(), 42.5f64); + assert_eq!(f64::try_from(PValue::I64(-100)).unwrap(), -100.0); +} + +#[test] +fn test_from_usize() { + let value: PValue = 42usize.into(); + assert_eq!(value, PValue::U64(42)); + + let max_value: PValue = usize::MAX.into(); + assert_eq!(max_value, PValue::U64(usize::MAX as u64)); +} + +#[test] +fn test_equality_cross_types() { + // Same numeric value, different types + assert_eq!(PValue::U8(42), PValue::U16(42)); + assert_eq!(PValue::U8(42), PValue::U32(42)); + assert_eq!(PValue::U8(42), PValue::U64(42)); + assert_eq!(PValue::I8(42), PValue::I16(42)); + assert_eq!(PValue::I8(42), PValue::I32(42)); + assert_eq!(PValue::I8(42), PValue::I64(42)); + + // Unsigned vs signed with same value (they compare 
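// Illustrative sketch (std-only): the cross-width equalities in test_equality_cross_types
// behave as if both integer operands were widened to a common type first, so U8(42) and
// U64(42) compare equal, while floats are never bridged to integers.
fn cross_width_equality_sketch() {
    assert_eq!(i128::from(42u8), i128::from(42u64));
    assert_eq!(i128::from(42u32), i128::from(42i32));
    assert_eq!(i128::from(-1i8), i128::from(-1i64));
}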
equal even though different categories) + assert_eq!(PValue::U8(42), PValue::I8(42)); + assert_eq!(PValue::U32(42), PValue::I32(42)); + + // Float equality + assert_eq!(PValue::F32(42.0), PValue::F32(42.0)); + assert_eq!(PValue::F64(42.0), PValue::F64(42.0)); + assert_ne!(PValue::F32(42.0), PValue::F64(42.0)); // Different types + + // Float vs int should not be equal + assert_ne!(PValue::F32(42.0), PValue::I32(42)); +} + +#[test] +fn test_partial_ord_cross_types() { + // Unsigned comparisons + assert_eq!( + PValue::U8(10).partial_cmp(&PValue::U16(20)), + Some(Ordering::Less) + ); + assert_eq!( + PValue::U32(30).partial_cmp(&PValue::U8(20)), + Some(Ordering::Greater) + ); + + // Signed comparisons + assert_eq!( + PValue::I8(-10).partial_cmp(&PValue::I64(0)), + Some(Ordering::Less) + ); + assert_eq!( + PValue::I32(10).partial_cmp(&PValue::I16(10)), + Some(Ordering::Equal) + ); + + // Float comparisons (same type only) + assert_eq!( + PValue::F32(1.0).partial_cmp(&PValue::F32(2.0)), + Some(Ordering::Less) + ); + assert_eq!( + PValue::F64(2.0).partial_cmp(&PValue::F64(1.0)), + Some(Ordering::Greater) + ); + + // Cross-category comparisons - unsigned vs signed work, float vs int don't + assert_eq!( + PValue::U32(42).partial_cmp(&PValue::I32(42)), + Some(Ordering::Equal) + ); // Actually works + assert_eq!(PValue::F32(42.0).partial_cmp(&PValue::I32(42)), None); + assert_eq!(PValue::F32(42.0).partial_cmp(&PValue::F64(42.0)), None); +} + +#[test] +fn test_to_le_bytes() { + assert_eq!(PValue::U8(0x12).to_le_bytes(), &[0x12]); + assert_eq!(PValue::U16(0x1234).to_le_bytes(), &[0x34, 0x12]); + assert_eq!( + PValue::U32(0x12345678).to_le_bytes(), + &[0x78, 0x56, 0x34, 0x12] + ); + + assert_eq!(PValue::I8(-1).to_le_bytes(), &[0xFF]); + assert_eq!(PValue::I16(-1).to_le_bytes(), &[0xFF, 0xFF]); + + let f32_bytes = PValue::F32(1.0).to_le_bytes(); + assert_eq!(f32_bytes.len(), 4); + + let f64_bytes = PValue::F64(1.0).to_le_bytes(); + assert_eq!(f64_bytes.len(), 8); +} + +#[test] +fn test_f16_special_values() { + // Test F16 NaN handling + let nan = f16::NAN; + let nan_value = PValue::F16(nan); + assert!(nan_value.as_f16().unwrap().is_nan()); + + // Test F16 infinity + let inf = f16::INFINITY; + let inf_value = PValue::F16(inf); + assert!(inf_value.as_f16().unwrap().is_infinite()); + + // Test F16 comparison with NaN + assert_eq!( + PValue::F16(nan).partial_cmp(&PValue::F16(nan)), + Some(Ordering::Equal) + ); +} + +#[test] +fn test_coerce_pvalue() { + // Test integer coercion + assert_eq!(u32::coerce(PValue::U16(42)).unwrap(), 42u32); + assert_eq!(i64::coerce(PValue::I32(-42)).unwrap(), -42i64); + + // Test float coercion from bits + assert_eq!(f32::coerce(PValue::U32(0x3f800000)).unwrap(), 1.0f32); + assert_eq!( + f64::coerce(PValue::U64(0x3ff0000000000000)).unwrap(), + 1.0f64 + ); +} + +#[test] +fn test_coerce_f16_beyond_u16_max() { + // Test U32 to f16 coercion within valid range + assert!(f16::coerce(PValue::U32(u16::MAX as u32)).is_ok()); + assert_eq!( + f16::coerce(PValue::U32(0x3C00)).unwrap(), + f16::from_bits(0x3C00) // 1.0 in f16 + ); + + // Test U32 to f16 coercion beyond u16::MAX - should fail + assert!(f16::coerce(PValue::U32((u16::MAX as u32) + 1)).is_err()); + assert!(f16::coerce(PValue::U32(u32::MAX)).is_err()); + + // Test U64 to f16 coercion within valid range + assert!(f16::coerce(PValue::U64(u16::MAX as u64)).is_ok()); + assert_eq!( + f16::coerce(PValue::U64(0x3C00)).unwrap(), + f16::from_bits(0x3C00) // 1.0 in f16 + ); + + // Test U64 to f16 coercion beyond u16::MAX - should fail + 
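// Illustrative sketch (std-only; `coerce` is a Vortex trait whose real signature is not
// reproduced here): the f16/f32 coerce cases in these tests accept a wider unsigned bit
// pattern only when it fits the target's bit width, which is why u64 values above
// u32::MAX are rejected for f32.
fn coerce_u64_to_f32_bits_sketch(v: u64) -> Option<f32> {
    u32::try_from(v).ok().map(f32::from_bits)
}
// coerce_u64_to_f32_bits_sketch(0x3f80_0000)         == Some(1.0)
// coerce_u64_to_f32_bits_sketch(u32::MAX as u64 + 1) == None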
assert!(f16::coerce(PValue::U64((u16::MAX as u64) + 1)).is_err()); + assert!(f16::coerce(PValue::U64(u32::MAX as u64)).is_err()); + assert!(f16::coerce(PValue::U64(u64::MAX)).is_err()); +} + +#[test] +fn test_coerce_f32_beyond_u32_max() { + // Test U64 to f32 coercion within valid range + assert!(f32::coerce(PValue::U64(u32::MAX as u64)).is_ok()); + assert_eq!( + f32::coerce(PValue::U64(0x3f800000)).unwrap(), + 1.0f32 // 0x3f800000 is 1.0 in f32 + ); + + // Test U64 to f32 coercion beyond u32::MAX - should fail + assert!(f32::coerce(PValue::U64((u32::MAX as u64) + 1)).is_err()); + assert!(f32::coerce(PValue::U64(u64::MAX)).is_err()); + + // Test smaller types still work + assert!(f32::coerce(PValue::U8(255)).is_ok()); + assert!(f32::coerce(PValue::U16(u16::MAX)).is_ok()); + assert!(f32::coerce(PValue::U32(u32::MAX)).is_ok()); +} + +#[test] +fn test_coerce_f64_all_unsigned() { + // Test f64 can accept all unsigned integer values as bit patterns + assert!(f64::coerce(PValue::U8(u8::MAX)).is_ok()); + assert!(f64::coerce(PValue::U16(u16::MAX)).is_ok()); + assert!(f64::coerce(PValue::U32(u32::MAX)).is_ok()); + assert!(f64::coerce(PValue::U64(u64::MAX)).is_ok()); + + // Verify specific bit patterns + assert_eq!( + f64::coerce(PValue::U64(0x3ff0000000000000)).unwrap(), + 1.0f64 // 0x3ff0000000000000 is 1.0 in f64 + ); +} + +#[test] +fn test_f16_nans_equal() { + let nan1 = f16::from_le_bytes([154, 253]); + assert!(nan1.is_nan()); + let nan3 = f16::from_f16(nan1).unwrap(); + assert_eq!(nan1.to_bits(), nan3.to_bits(),); +} diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/typed_view/struct_.rs similarity index 84% rename from vortex-scalar/src/struct_.rs rename to vortex-scalar/src/typed_view/struct_.rs index d8ef1af7dd9..254dd1b4223 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/typed_view/struct_.rs @@ -1,27 +1,25 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`StructScalar`] typed view implementation. + use std::cmp::Ordering; use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hash; use std::hash::Hasher; -use std::ops::Deref; -use std::sync::Arc; use itertools::Itertools; use vortex_dtype::DType; use vortex_dtype::FieldName; use vortex_dtype::FieldNames; use vortex_dtype::StructFields; -use vortex_error::VortexError; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; @@ -31,8 +29,10 @@ use crate::ScalarValue; /// named fields with different types, or be null. #[derive(Debug, Clone)] pub struct StructScalar<'a> { + /// The data type of this scalar. dtype: &'a DType, - fields: Option<&'a Arc<[ScalarValue]>>, + /// The field values, or [`None`] if the entire struct is null. 
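// Illustrative sketch (plain std types, not the real ScalarValue): the field layout
// declared just below distinguishes a null struct (outer None) from a present struct
// whose individual fields may be null (inner None), and the PartialEq impl further below
// only compares zipped fields when both structs are present.
fn struct_eq_sketch<T: PartialEq>(a: Option<&[Option<T>]>, b: Option<&[Option<T>]>) -> bool {
    match (a, b) {
        (Some(l), Some(r)) => l.len() == r.len() && l.iter().zip(r).all(|(x, y)| x == y),
        (None, None) => true,
        _ => false,
    }
}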
+ fields: Option<&'a [Option]>, } impl Display for StructScalar<'_> { @@ -47,7 +47,8 @@ impl Display for StructScalar<'_> { .zip_eq(self.struct_fields().fields()) .zip_eq(fields.iter()) .map(|((name, dtype), value)| { - let val = Scalar::new(dtype, value.clone()); + let val = Scalar::try_new(dtype, value.clone()) + .vortex_expect("unable to construct a struct `Scalar`"); format!("{name}: {val}") }) .format(", "); @@ -64,7 +65,7 @@ impl PartialEq for StructScalar<'_> { return false; } - match (self.fields(), other.fields()) { + match (self.fields_iter(), other.fields_iter()) { (Some(lhs), Some(rhs)) => lhs.zip(rhs).all(|(l_s, r_s)| l_s == r_s), (None, None) => true, (Some(_), None) | (None, Some(_)) => false, @@ -81,7 +82,7 @@ impl PartialOrd for StructScalar<'_> { return None; } - match (self.fields(), other.fields()) { + match (self.fields_iter(), other.fields_iter()) { (Some(lhs), Some(rhs)) => { for (l_s, r_s) in lhs.zip(rhs) { match l_s.partial_cmp(&r_s)? { @@ -103,7 +104,7 @@ impl PartialOrd for StructScalar<'_> { impl Hash for StructScalar<'_> { fn hash(&self, state: &mut H) { self.dtype.hash(state); - if let Some(fields) = self.fields() { + if let Some(fields) = self.fields_iter() { for f in fields { f.hash(state); } @@ -112,15 +113,16 @@ impl Hash for StructScalar<'_> { } impl<'a> StructScalar<'a> { + /// Creates a new [`StructScalar`] from a [`DType`] and optional [`ScalarValue`]. #[inline] - pub(crate) fn try_new(dtype: &'a DType, value: &'a ScalarValue) -> VortexResult { + pub(crate) fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Struct(..)) { vortex_bail!("Expected struct scalar, found {}", dtype) } Ok(Self { dtype, - fields: value.as_list()?, + fields: value.map(|value| value.as_list()), }) } @@ -156,6 +158,8 @@ impl<'a> StructScalar<'a> { self.field_by_idx(idx) } + // TODO(connor): This should have the opposite behavior: It should panic if the field index is + // out of bounds, and it should return None if it is null. /// Returns the field at the given index as a scalar. /// /// Returns None if the index is out of bounds. @@ -167,27 +171,32 @@ impl<'a> StructScalar<'a> { let fields = self .fields .vortex_expect("Can't take field out of null struct scalar"); - Some(Scalar::new( - self.struct_fields().field_by_index(idx)?, - fields[idx].clone(), - )) + Some( + // SAFETY: We assume that the struct `DType` correctly describes the struct values. + unsafe { + Scalar::new_unchecked( + self.struct_fields().field_by_index(idx)?, + fields[idx].clone(), + ) + }, + ) } /// Returns the fields of the struct scalar, or None if the scalar is null. - pub fn fields(&self) -> Option> { + pub fn fields_iter(&self) -> Option> { let fields = self.fields?; Some( fields .iter() .zip(self.struct_fields().fields()) - .map(|(v, dtype)| Scalar::new(dtype, v.clone())), + .map(|(v, dtype)| { + // SAFETY: We assume that the struct `DType` correctly describes the struct + // values. + unsafe { Scalar::new_unchecked(dtype, v.clone()) } + }), ) } - pub(crate) fn field_values(&self) -> Option<&[ScalarValue]> { - self.fields.map(Arc::deref) - } - /// Casts this struct scalar to another struct type. /// /// # Errors @@ -210,17 +219,17 @@ impl<'a> StructScalar<'a> { ); } - if let Some(fs) = self.field_values() { + if let Some(fs) = self.fields { let fields = fs .iter() .enumerate() .map(|(i, f)| { - Scalar::new( + Scalar::try_new( own_st .field_by_index(i) .vortex_expect("Iterating over scalar fields"), f.clone(), - ) + )? 
.cast( &st.field_by_index(i) .vortex_expect("Iterating over scalar fields"), @@ -228,10 +237,7 @@ impl<'a> StructScalar<'a> { .map(|s| s.into_value()) }) .collect::>>()?; - Ok(Scalar::new( - dtype.clone(), - ScalarValue(InnerScalarValue::List(fields.into())), - )) + Scalar::try_new(dtype.clone(), Some(ScalarValue::List(fields))) } else { Ok(Scalar::null(dtype.clone())) } @@ -247,26 +253,28 @@ impl<'a> StructScalar<'a> { .dtype .as_struct_fields_opt() .ok_or_else(|| vortex_err!("Not a struct dtype"))?; - let projected_dtype = struct_dtype.project(projection)?; - let new_fields = if let Some(fs) = self.field_values() { - ScalarValue(InnerScalarValue::List( - projection - .iter() - .map(|name| { - struct_dtype - .find(name) - .vortex_expect("DType has been successfully projected already") - }) - .map(|i| fs[i].clone()) - .collect(), - )) - } else { - ScalarValue(InnerScalarValue::Null) + let projected_dtype = DType::Struct( + struct_dtype.project(projection)?, + self.dtype().nullability(), + ); + + let Some(fs) = self.fields else { + return Ok(Scalar::null(projected_dtype)); }; - Ok(Scalar::new( - DType::Struct(projected_dtype, self.dtype().nullability()), - new_fields, - )) + + let new_fields = ScalarValue::List( + projection + .iter() + .map(|name| { + struct_dtype + .find(name) + .vortex_expect("DType has been successfully projected already") + }) + .map(|i| fs[i].clone()) + .collect(), + ); + + Scalar::try_new(projected_dtype, Some(new_fields)) } } @@ -300,18 +308,8 @@ impl Scalar { let mut value_children = Vec::with_capacity(children.len()); value_children.extend(children.into_iter().map(|x| x.into_value())); - Self::new( - dtype, - ScalarValue(InnerScalarValue::List(value_children.into())), - ) - } -} - -impl<'a> TryFrom<&'a Scalar> for StructScalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - Self::try_new(value.dtype(), value.value()) + Self::try_new(dtype, Some(ScalarValue::List(value_children))) + .vortex_expect("unable to construct a struct `Scalar`") } } @@ -323,6 +321,7 @@ mod tests { use vortex_dtype::StructFields; use super::*; + use crate::PValue; fn setup_types() -> (DType, DType, DType) { let f0_dt = DType::Primitive(I32, Nullability::NonNullable); @@ -361,13 +360,13 @@ mod tests { let scalar_f0 = scalar.as_struct().field_by_idx(0); assert!(scalar_f0.is_some()); let scalar_f0 = scalar_f0.unwrap(); - assert_eq!(scalar_f0, f0_val_null); + assert_eq!(scalar_f0.value(), f0_val_null.value()); assert_eq!(scalar_f0.dtype(), &f0_dt); let scalar_f1 = scalar.as_struct().field_by_idx(1); assert!(scalar_f1.is_some()); let scalar_f1 = scalar_f1.unwrap(); - assert_eq!(scalar_f1, f1_val_null); + assert_eq!(scalar_f1.value(), f1_val_null.value()); assert_eq!(scalar_f1.dtype(), &f1_dt); } @@ -421,7 +420,10 @@ mod tests { let field_b = scalar.as_struct().field("b"); assert!(field_b.is_some()); - assert_eq!(field_b.unwrap().as_utf8().value().unwrap(), "world".into()); + assert_eq!( + field_b.unwrap().as_utf8().value().cloned().unwrap(), + "world".into() + ); // Non-existent field let field_c = scalar.as_struct().field("c"); @@ -436,10 +438,14 @@ mod tests { let scalar = Scalar::struct_(dtype, vec![f0_val, f1_val]); - let fields = scalar.as_struct().fields().unwrap().collect::>(); + let fields = scalar + .as_struct() + .fields_iter() + .unwrap() + .collect::>(); assert_eq!(fields.len(), 2); assert_eq!(fields[0].as_primitive().typed_value::().unwrap(), 100); - assert_eq!(fields[1].as_utf8().value().unwrap(), "test".into()); + 
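// Illustrative sketch (std-only; names and types are stand-ins): projection resolves each
// requested field name to its index in the source struct and carries that field's value
// over, as the implementation above does via struct_dtype.find.
fn project_sketch(names: &[&str], values: &[i32], projection: &[&str]) -> Option<Vec<i32>> {
    projection
        .iter()
        .map(|p| names.iter().position(|n| n == p).map(|i| values[i]))
        .collect()
}
// project_sketch(&["a", "b", "c"], &[1, 2, 3], &["c", "a"]) == Some(vec![3, 1])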
assert_eq!(fields[1].as_utf8().value().cloned().unwrap(), "test".into()); } #[test] @@ -448,8 +454,8 @@ mod tests { let null_scalar = Scalar::null(dtype); assert!(null_scalar.as_struct().is_null()); - assert!(null_scalar.as_struct().fields().is_none()); - assert!(null_scalar.as_struct().field_values().is_none()); + assert!(null_scalar.as_struct().fields_iter().is_none()); + assert!(null_scalar.as_struct().fields.is_none()); } #[test] @@ -482,7 +488,11 @@ mod tests { let result = source_scalar.as_struct().cast(&target_dtype).unwrap(); assert_eq!(result.dtype(), &target_dtype); - let fields = result.as_struct().fields().unwrap().collect::>(); + let fields = result + .as_struct() + .fields_iter() + .unwrap() + .collect::>(); assert_eq!(fields[0].as_primitive().typed_value::().unwrap(), 42); assert_eq!(fields[1].as_primitive().typed_value::().unwrap(), 123); } @@ -545,7 +555,7 @@ mod tests { assert_eq!(projected_struct.names().len(), 1); assert_eq!(projected_struct.names()[0].as_ref(), "b"); - let fields = projected_struct.fields().unwrap().collect::>(); + let fields = projected_struct.fields_iter().unwrap().collect::>(); assert_eq!(fields.len(), 1); assert_eq!(fields[0].as_utf8().value().unwrap().as_str(), "hello"); } @@ -668,9 +678,9 @@ mod tests { #[test] fn test_struct_try_new_non_struct_dtype() { let dtype = DType::Primitive(I32, Nullability::NonNullable); - let value = ScalarValue(InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = ScalarValue::Primitive(PValue::I32(42)); - let result = StructScalar::try_new(&dtype, &value); + let result = StructScalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/typed_view/utf8.rs similarity index 55% rename from vortex-scalar/src/utf8.rs rename to vortex-scalar/src/typed_view/utf8.rs index c274ac6b735..165b2a45bd6 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/typed_view/utf8.rs @@ -1,101 +1,33 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! [`Utf8Scalar`] typed view implementation. + use std::cmp; use std::fmt; use std::fmt::Display; use std::fmt::Formatter; -use std::sync::Arc; use vortex_buffer::BufferString; use vortex_dtype::DType; -use vortex_dtype::Nullability; -use vortex_dtype::Nullability::NonNullable; -use vortex_error::VortexError; -use vortex_error::VortexExpect as _; +use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; -use vortex_error::vortex_err; use vortex_utils::aliases::StringEscape; -use crate::InnerScalarValue; use crate::Scalar; use crate::ScalarValue; -/// Types that can hold a valid UTF-8 string. -pub trait StringLike: private::Sealed + Sized { - /// Replace the last codepoint in the string with the next codepoint. - /// - /// This operation will attempt to reuse the original memory. - /// - /// If incrementing the last char fails, or if the string is empty, - /// we return an Err with the original unmodified string. 
- fn increment(self) -> Result; -} - -mod private { - use vortex_buffer::BufferString; - - use crate::StringLike; - - pub trait Sealed {} - - impl Sealed for String {} - - impl StringLike for String { - fn increment(mut self) -> Result { - let Some(last_char) = self.pop() else { - return Ok(self); - }; - - if let Some(next_char) = char::from_u32(last_char as u32 + 1) { - self.push(next_char); - Ok(self) - } else { - // Return the original string - self.push(last_char); - Err(self) - } - } - } - - impl Sealed for BufferString {} - - impl StringLike for BufferString { - #[allow(clippy::unwrap_in_result, clippy::expect_used)] - fn increment(self) -> Result { - if self.is_empty() { - return Err(self); - } - - // Chop off the last char and return it here. - let (last_idx, last_char) = self.char_indices().last().expect("non-empty"); - if let Some(next_char) = char::from_u32(last_char as u32 + 1) - && next_char.len_utf8() == last_char.len_utf8() - { - // Because the next char has the same byte width as the last char, we can overwrite - // the memory directly. - let mut bytes = self.into_inner().into_mut(); - next_char.encode_utf8(&mut bytes.as_mut()[last_idx..]); - - // SAFETY: we overwrite the last valid char with new valid char, so - // the buffer continues to hold valid UTF-8 data. - unsafe { Ok(BufferString::new_unchecked(bytes.freeze())) } - } else { - Err(self) - } - } - } -} - /// A scalar value representing a UTF-8 encoded string. /// /// This type provides a view into a UTF-8 string scalar value, which can be either /// a valid UTF-8 string or null. #[derive(Debug, Clone, Hash, Eq)] pub struct Utf8Scalar<'a> { + /// The data type of this scalar. dtype: &'a DType, - value: Option>, + /// The string value, or [`None`] if null. + value: Option<&'a BufferString>, } impl Display for Utf8Scalar<'_> { @@ -131,13 +63,14 @@ impl<'a> Utf8Scalar<'a> { /// # Errors /// /// Returns an error if the data type is not a UTF-8 type. - pub fn from_scalar_value(dtype: &'a DType, value: ScalarValue) -> VortexResult { + pub fn try_new(dtype: &'a DType, value: Option<&'a ScalarValue>) -> VortexResult { if !matches!(dtype, DType::Utf8(..)) { vortex_bail!("Can only construct utf8 scalar from utf8 dtype, found {dtype}") } + Ok(Self { dtype, - value: value.as_buffer_string()?, + value: value.map(|value| value.as_utf8()), }) } @@ -147,90 +80,64 @@ impl<'a> Utf8Scalar<'a> { self.dtype } - /// Returns the string value, or None if null. - pub fn value(&self) -> Option { - self.value.as_ref().map(|v| v.as_ref().clone()) - } - /// Returns a reference to the string value, or None if null. /// This avoids cloning the underlying BufferString. - pub fn value_ref(&self) -> Option<&BufferString> { - self.value.as_ref().map(|v| v.as_ref()) + pub fn value(&self) -> Option<&'a BufferString> { + self.value } - /// Constructs the next scalar at most `max_length` bytes that's lexicographically greater than - /// this. + /// Constructs the next [`Scalar`] at most `max_length` bytes that's lexicographically greater + /// than this. /// - /// Returns None if constructing a greater value would overflow. - pub fn upper_bound(self, max_length: usize) -> Option { - if let Some(value) = self.value { - if value.len() > max_length { - let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) - .rfind(|p| value.is_char_boundary(*p)) - .vortex_expect("Failed to find utf8 character boundary"); + /// Returns `None` if the value is null or if constructing a greater value would overflow. 
+ pub fn upper_bound(&self, max_length: usize) -> Option { + let value = self.value()?; + let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) + .rfind(|p| value.is_char_boundary(*p)) + .vortex_expect("Failed to find utf8 character boundary"); - let sliced = value.inner().slice(..utf8_split_pos); - drop(value); - - // SAFETY: we slice to a char boundary so the sliced range contains valid UTF-8. - let sliced_buf = unsafe { BufferString::new_unchecked(sliced) }; - let incremented = sliced_buf.increment().ok()?; - Some(Self { - dtype: self.dtype, - value: Some(Arc::new(incremented)), - }) - } else { - Some(Self { - dtype: self.dtype, - value: Some(value), - }) - } - } else { - Some(self) - } + // SAFETY: we slice to a char boundary so the sliced range contains valid UTF-8. + let sliced = unsafe { BufferString::new_unchecked(value.inner().slice(..utf8_split_pos)) }; + let incremented = sliced.increment().ok()?; + Some(Scalar::utf8(incremented, self.dtype().nullability())) } - /// Construct a value at most `max_length` in size that's less than ourselves. - pub fn lower_bound(self, max_length: usize) -> Self { - if let Some(value) = self.value { - if value.len() > max_length { - // UTF8 characters are at most 4 bytes, since we know that BufferString is UTF8 we must have a valid character boundary + /// Construct a [`Scalar`] at most `max_length` in size that's less than or equal to + /// ourselves. + /// + /// Returns a null [`Scalar`] if the value is null. + pub fn lower_bound(&self, max_length: usize) -> Scalar { + match self.value() { + Some(value) => { + // UTF-8 characters are at most 4 bytes. Since we know that `BufferString` is + // valid UTF-8, we must have a valid character boundary. let utf8_split_pos = (max_length.saturating_sub(3)..=max_length) .rfind(|p| value.is_char_boundary(*p)) .vortex_expect("Failed to find utf8 character boundary"); - Self { - dtype: self.dtype, - value: Some(Arc::new(unsafe { - BufferString::new_unchecked(value.inner().slice(0..utf8_split_pos)) - })), - } - } else { - Self { - dtype: self.dtype, - value: Some(value), - } + let sliced = + unsafe { BufferString::new_unchecked(value.inner().slice(0..utf8_split_pos)) }; + Scalar::utf8(sliced, self.dtype().nullability()) } - } else { - self + None => Scalar::null(self.dtype().clone()), } } + /// Casts this scalar to the given `dtype`. pub(crate) fn cast(&self, dtype: &DType) -> VortexResult { if !matches!(dtype, DType::Utf8(..)) { vortex_bail!( "Cannot cast utf8 to {dtype}: UTF-8 scalars can only be cast to UTF-8 types with different nullability" ) } - Ok(Scalar::new( + Scalar::try_new( dtype.clone(), - ScalarValue(InnerScalarValue::BufferString( - self.value - .as_ref() - .vortex_expect("nullness handled in Scalar::cast") - .clone(), + Some(ScalarValue::Utf8( + self.value() + .cloned() + .vortex_expect("nullness handled in Scalar::cast"), )), - )) + ) } /// Length of the scalar value or None if value is null @@ -244,156 +151,74 @@ impl<'a> Utf8Scalar<'a> { } } -impl Scalar { - /// Creates a new UTF-8 scalar from a string-like value. - /// - /// # Panics +/// Types that can hold a valid UTF-8 string. +pub trait StringLike: private::Sealed + Sized { + /// Replace the last codepoint in the string with the next codepoint. /// - /// Panics if the input cannot be converted to a valid UTF-8 string. - pub fn utf8(str: B, nullability: Nullability) -> Self - where - B: Into, - { - Self::try_utf8(str, nullability).unwrap() - } - - /// Tries to create a new UTF-8 scalar from a string-like value. 
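// Illustrative sketch (std-only, not the Vortex API; this version also clamps to the
// string length, which the implementation above does not need to): because a UTF-8 code
// point is at most 4 bytes, the window max_length-3..=max_length always contains a char
// boundary, which is what lower_bound and upper_bound rely on when truncating.
fn truncate_to_char_boundary(s: &str, max_length: usize) -> &str {
    let end = max_length.min(s.len());
    let split = (end.saturating_sub(3)..=end)
        .rev()
        .find(|&p| s.is_char_boundary(p))
        .expect("a char boundary exists in any 4-byte window");
    &s[..split]
}
// truncate_to_char_boundary("snowman⛄️snowman", 9) == "snowman"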
+ /// This operation will attempt to reuse the original memory. /// + /// If incrementing the last char fails, or if the string is empty, + /// we return an Err with the original unmodified string. /// # Errors /// - /// Returns an error if the input cannot be converted to a valid UTF-8 string. - pub fn try_utf8( - str: B, - nullability: Nullability, - ) -> Result>::Error> - where - B: TryInto, - { - Ok(Self::new( - DType::Utf8(nullability), - ScalarValue(InnerScalarValue::BufferString(Arc::new(str.try_into()?))), - )) - } -} - -impl<'a> TryFrom<&'a Scalar> for Utf8Scalar<'a> { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - if !matches!(value.dtype(), DType::Utf8(_)) { - vortex_bail!("Expected utf8 scalar, found {}", value.dtype()) - } - Ok(Self { - dtype: value.dtype(), - value: value.value().as_buffer_string()?, - }) - } -} - -impl<'a> TryFrom<&'a Scalar> for String { - type Error = VortexError; - - fn try_from(value: &'a Scalar) -> Result { - Ok(BufferString::try_from(value)?.to_string()) - } -} - -impl TryFrom for String { - type Error = VortexError; - - fn try_from(value: Scalar) -> Result { - Ok(BufferString::try_from(value)?.to_string()) - } -} - -impl From<&str> for Scalar { - fn from(value: &str) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new( - value.to_string().into(), - ))), - ) - } -} - -impl From for Scalar { - fn from(value: String) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new(value.into()))), - ) - } -} - -impl From for Scalar { - fn from(value: BufferString) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(Arc::new(value))), - ) - } -} - -impl From> for Scalar { - fn from(value: Arc) -> Self { - Self::new( - DType::Utf8(NonNullable), - ScalarValue(InnerScalarValue::BufferString(value)), - ) - } + /// Returns `Err(self)` if the string is empty or if incrementing the last char overflows. + fn increment(self) -> Result; } -impl<'a> TryFrom<&'a Scalar> for BufferString { - type Error = VortexError; +/// Sealed trait implementation module for [`StringLike`]. +mod private { + use vortex_buffer::BufferString; - fn try_from(scalar: &'a Scalar) -> VortexResult { - >::try_from(scalar)? - .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) - } -} + use crate::StringLike; -impl TryFrom for BufferString { - type Error = VortexError; + /// Prevents external implementations of [`StringLike`]. 
+ pub trait Sealed {} - fn try_from(scalar: Scalar) -> Result { - Self::try_from(&scalar) - } -} + impl Sealed for String {} -impl<'a> TryFrom<&'a Scalar> for Option { - type Error = VortexError; + impl StringLike for String { + fn increment(mut self) -> Result { + let Some(last_char) = self.pop() else { + return Ok(self); + }; - fn try_from(scalar: &'a Scalar) -> Result { - Ok(Utf8Scalar::try_from(scalar)?.value()) + if let Some(next_char) = char::from_u32(last_char as u32 + 1) { + self.push(next_char); + Ok(self) + } else { + // Return the original string + self.push(last_char); + Err(self) + } + } } -} - -impl TryFrom for Option { - type Error = VortexError; - fn try_from(scalar: Scalar) -> Result { - Self::try_from(&scalar) - } -} + impl Sealed for BufferString {} -impl From<&str> for ScalarValue { - fn from(value: &str) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new( - value.to_string().into(), - ))) - } -} + impl StringLike for BufferString { + #[allow(clippy::unwrap_in_result, clippy::expect_used)] + fn increment(self) -> Result { + if self.is_empty() { + return Err(self); + } -impl From for ScalarValue { - fn from(value: String) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new(value.into()))) - } -} + // Chop off the last char and return it here. + let (last_idx, last_char) = self.char_indices().last().expect("non-empty"); + if let Some(next_char) = char::from_u32(last_char as u32 + 1) + && next_char.len_utf8() == last_char.len_utf8() + { + // Because the next char has the same byte width as the last char, we can overwrite + // the memory directly. + let mut bytes = self.into_inner().into_mut(); + next_char.encode_utf8(&mut bytes.as_mut()[last_idx..]); -impl From for ScalarValue { - fn from(value: BufferString) -> Self { - ScalarValue(InnerScalarValue::BufferString(Arc::new(value))) + // SAFETY: we overwrite the last valid char with new valid char, so + // the buffer continues to hold valid UTF-8 data. 
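// Worked example (std-only sketch, not the Vortex API): the upper-bound construction
// truncates to a char boundary within the byte budget and then bumps the final code
// point, mirroring the String increment impl above. "char🪩" limited to 5 bytes truncates
// to "char" and increments to "chas", matching the upper_bound test further below.
fn next_after_truncation(s: &str, max_length: usize) -> Option<String> {
    let end = max_length.min(s.len());
    let split = (end.saturating_sub(3)..=end).rev().find(|&p| s.is_char_boundary(p))?;
    let mut out = s[..split].to_string();
    let last = out.pop()?;
    out.push(char::from_u32(last as u32 + 1)?);
    Some(out)
}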
+ unsafe { Ok(BufferString::new_unchecked(bytes.freeze())) } + } else { + Err(self) + } + } } } @@ -403,7 +228,6 @@ mod tests { use rstest::rstest; use vortex_dtype::Nullability; - use vortex_error::VortexExpect; use crate::Scalar; use crate::Utf8Scalar; @@ -412,37 +236,20 @@ mod tests { fn lower_bound() { let utf8 = Scalar::utf8("snowman⛄️snowman", Nullability::NonNullable); let expected = Scalar::utf8("snowman", Nullability::NonNullable); - assert_eq!( - Utf8Scalar::try_from(&utf8) - .vortex_expect("utf8 scalar conversion should succeed") - .lower_bound(9), - Utf8Scalar::try_from(&expected).vortex_expect("utf8 scalar conversion should succeed") - ); + assert_eq!(utf8.as_utf8().lower_bound(9), expected,); } #[test] fn upper_bound() { let utf8 = Scalar::utf8("char🪩", Nullability::NonNullable); let expected = Scalar::utf8("chas", Nullability::NonNullable); - assert_eq!( - Utf8Scalar::try_from(&utf8) - .vortex_expect("utf8 scalar conversion should succeed") - .upper_bound(5) - .vortex_expect("must have upper bound"), - Utf8Scalar::try_from(&expected).vortex_expect("utf8 scalar conversion should succeed") - ); + assert_eq!(utf8.as_utf8().upper_bound(5).unwrap(), expected,); } #[test] fn upper_bound_overflow() { let utf8 = Scalar::utf8("🂑🂒🂓", Nullability::NonNullable); - - assert!( - Utf8Scalar::try_from(&utf8) - .vortex_expect("utf8 scalar conversion should succeed") - .upper_bound(2) - .is_none() - ); + assert!(utf8.as_utf8().upper_bound(2).is_none()); } #[rstest] @@ -454,8 +261,8 @@ mod tests { let scalar1 = Scalar::utf8(str1, Nullability::NonNullable); let scalar2 = Scalar::utf8(str2, Nullability::NonNullable); - let utf8_scalar1 = Utf8Scalar::try_from(&scalar1).unwrap(); - let utf8_scalar2 = Utf8Scalar::try_from(&scalar2).unwrap(); + let utf8_scalar1 = scalar1.as_utf8(); + let utf8_scalar2 = scalar2.as_utf8(); assert_eq!(utf8_scalar1 == utf8_scalar2, expected); } @@ -474,8 +281,8 @@ mod tests { let scalar1 = Scalar::utf8(str1, Nullability::NonNullable); let scalar2 = Scalar::utf8(str2, Nullability::NonNullable); - let utf8_scalar1 = Utf8Scalar::try_from(&scalar1).unwrap(); - let utf8_scalar2 = Utf8Scalar::try_from(&scalar2).unwrap(); + let utf8_scalar1 = scalar1.as_utf8(); + let utf8_scalar2 = scalar2.as_utf8(); assert_eq!(utf8_scalar1.partial_cmp(&utf8_scalar2), Some(expected)); } @@ -483,10 +290,10 @@ mod tests { #[test] fn test_utf8_null_value() { let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); - let scalar = Utf8Scalar::try_from(&null_utf8).unwrap(); + let scalar = null_utf8.as_utf8(); assert!(scalar.value().is_none()); - assert!(scalar.value_ref().is_none()); + assert!(scalar.value().is_none()); assert!(scalar.len().is_none()); assert!(scalar.is_empty().is_none()); } @@ -496,11 +303,11 @@ mod tests { let empty = Scalar::utf8("", Nullability::NonNullable); let non_empty = Scalar::utf8("hello", Nullability::NonNullable); - let empty_scalar = Utf8Scalar::try_from(&empty).unwrap(); + let empty_scalar = empty.as_utf8(); assert_eq!(empty_scalar.len(), Some(0)); assert_eq!(empty_scalar.is_empty(), Some(true)); - let non_empty_scalar = Utf8Scalar::try_from(&non_empty).unwrap(); + let non_empty_scalar = non_empty.as_utf8(); assert_eq!(non_empty_scalar.len(), Some(5)); assert_eq!(non_empty_scalar.is_empty(), Some(false)); } @@ -509,10 +316,10 @@ mod tests { fn test_utf8_value_ref() { let data = "test string"; let utf8 = Scalar::utf8(data, Nullability::NonNullable); - let scalar = Utf8Scalar::try_from(&utf8).unwrap(); + let scalar = utf8.as_utf8(); // value_ref 
should not clone - let value_ref = scalar.value_ref().unwrap(); + let value_ref = scalar.value().unwrap(); assert_eq!(value_ref.as_str(), data); // value should clone @@ -526,13 +333,13 @@ mod tests { use vortex_dtype::Nullability; let utf8 = Scalar::utf8("test", Nullability::NonNullable); - let scalar = Utf8Scalar::try_from(&utf8).unwrap(); + let scalar = utf8.as_utf8(); // Cast to nullable utf8 let result = scalar.cast(&DType::Utf8(Nullability::Nullable)).unwrap(); assert_eq!(result.dtype(), &DType::Utf8(Nullability::Nullable)); - let casted = Utf8Scalar::try_from(&result).unwrap(); + let casted = result.as_utf8(); assert_eq!(casted.value().unwrap().as_str(), "test"); } @@ -543,22 +350,22 @@ mod tests { use vortex_dtype::PType; let utf8 = Scalar::utf8("test", Nullability::NonNullable); - let scalar = Utf8Scalar::try_from(&utf8).unwrap(); + let scalar = utf8.as_utf8(); let result = scalar.cast(&DType::Primitive(PType::I32, Nullability::NonNullable)); assert!(result.is_err()); } #[test] - fn test_from_scalar_value_non_utf8_dtype() { + fn test_try_new_non_utf8_dtype() { use vortex_dtype::DType; use vortex_dtype::Nullability; use vortex_dtype::PType; let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let value = crate::ScalarValue(crate::InnerScalarValue::Primitive(crate::PValue::I32(42))); + let value = crate::ScalarValue::Primitive(crate::PValue::I32(42)); - let result = Utf8Scalar::from_scalar_value(&dtype, value); + let result = Utf8Scalar::try_new(&dtype, Some(&value)); assert!(result.is_err()); } @@ -567,47 +374,21 @@ mod tests { use vortex_dtype::Nullability; let scalar = Scalar::primitive(42i32, Nullability::NonNullable); - let result = Utf8Scalar::try_from(&scalar); - assert!(result.is_err()); + assert!(scalar.as_utf8_opt().is_none()); } #[test] fn test_upper_bound_null() { let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); - let scalar = Utf8Scalar::try_from(&null_utf8).unwrap(); - - let result = scalar.upper_bound(10); - assert!(result.is_some()); - assert!(result.unwrap().value().is_none()); + let scalar = null_utf8.as_utf8(); + assert!(scalar.upper_bound(10).is_none()); } #[test] fn test_lower_bound_null() { let null_utf8 = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); - let scalar = Utf8Scalar::try_from(&null_utf8).unwrap(); - - let result = scalar.lower_bound(10); - assert!(result.value().is_none()); - } - - #[test] - fn test_upper_bound_exact_length() { - let utf8 = Scalar::utf8("abc", Nullability::NonNullable); - let scalar = Utf8Scalar::try_from(&utf8).unwrap(); - - let result = scalar.upper_bound(3); - assert!(result.is_some()); - let upper = result.unwrap(); - assert_eq!(upper.value().unwrap().as_str(), "abc"); - } - - #[test] - fn test_lower_bound_exact_length() { - let utf8 = Scalar::utf8("abc", Nullability::NonNullable); - let scalar = Utf8Scalar::try_from(&utf8).unwrap(); - - let result = scalar.lower_bound(3); - assert_eq!(result.value().unwrap().as_str(), "abc"); + let scalar = null_utf8.as_utf8(); + assert!(scalar.lower_bound(10).is_null()); } #[test] @@ -619,7 +400,7 @@ mod tests { scalar.dtype(), &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let utf8 = scalar.as_utf8(); assert_eq!(utf8.value().unwrap().as_str(), data); } @@ -632,7 +413,7 @@ mod tests { scalar.dtype(), &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let utf8 = scalar.as_utf8(); 
assert_eq!(utf8.value().unwrap().as_str(), "hello world"); } @@ -647,24 +428,7 @@ mod tests { scalar.dtype(), &vortex_dtype::DType::Utf8(Nullability::NonNullable) ); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); - assert_eq!(utf8.value().unwrap().as_str(), "test"); - } - - #[test] - fn test_from_arc_buffer_string() { - use std::sync::Arc; - - use vortex_buffer::BufferString; - - let data = Arc::new(BufferString::from("test")); - let scalar: Scalar = data.into(); - - assert_eq!( - scalar.dtype(), - &vortex_dtype::DType::Utf8(Nullability::NonNullable) - ); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let utf8 = scalar.as_utf8(); assert_eq!(utf8.value().unwrap().as_str(), "test"); } @@ -730,8 +494,11 @@ mod tests { let data = "test"; let value: crate::ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let scalar = Scalar::new( + vortex_dtype::DType::Utf8(Nullability::NonNullable), + Some(value), + ); + let utf8 = scalar.as_utf8(); assert_eq!(utf8.value().unwrap().as_str(), data); } @@ -740,8 +507,11 @@ mod tests { let data = String::from("test"); let value: crate::ScalarValue = data.clone().into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let scalar = Scalar::new( + vortex_dtype::DType::Utf8(Nullability::NonNullable), + Some(value), + ); + let utf8 = scalar.as_utf8(); assert_eq!(utf8.value().unwrap().as_str(), &data); } @@ -752,8 +522,11 @@ mod tests { let data = BufferString::from("test"); let value: crate::ScalarValue = data.into(); - let scalar = Scalar::new(vortex_dtype::DType::Utf8(Nullability::NonNullable), value); - let utf8 = Utf8Scalar::try_from(&scalar).unwrap(); + let scalar = Scalar::new( + vortex_dtype::DType::Utf8(Nullability::NonNullable), + Some(value), + ); + let utf8 = scalar.as_utf8(); assert_eq!(utf8.value().unwrap().as_str(), "test"); } @@ -761,7 +534,7 @@ mod tests { fn test_utf8_with_emoji() { let emoji_str = "Hello 👋 World 🌍!"; let scalar = Scalar::utf8(emoji_str, Nullability::NonNullable); - let utf8_scalar = Utf8Scalar::try_from(&scalar).unwrap(); + let utf8_scalar = scalar.as_utf8(); assert_eq!(utf8_scalar.value().unwrap().as_str(), emoji_str); assert!(utf8_scalar.len().unwrap() > emoji_str.chars().count()); // Byte length > char count @@ -772,8 +545,8 @@ mod tests { let null_scalar = Scalar::null(vortex_dtype::DType::Utf8(Nullability::Nullable)); let non_null_scalar = Scalar::utf8("test", Nullability::Nullable); - let null = Utf8Scalar::try_from(&null_scalar).unwrap(); - let non_null = Utf8Scalar::try_from(&non_null_scalar).unwrap(); + let null = null_scalar.as_utf8(); + let non_null = non_null_scalar.as_utf8(); // Null < Some("test") assert!(null < non_null);
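// Illustrative sketch (std-only): the null-before-value ordering asserted just above
// matches std's Option ordering, where None sorts before any Some.
fn none_sorts_first_sketch() {
    assert!(None::<&str> < Some("test"));
    assert!(Some("a") < Some("b"));
}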