Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions encodings/alp/src/alp/compute/between.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ impl BetweenKernel for ALPVTable {
match_each_alp_float_ptype!(array.ptype(), |F| {
between_impl::<F>(
array,
F::try_from(lower)?,
F::try_from(upper)?,
F::try_from(&lower)?,
F::try_from(&upper)?,
nullability,
options,
)
Expand Down
10 changes: 8 additions & 2 deletions encodings/alp/src/alp/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use vortex_array::register_kernel;
use vortex_dtype::NativePType;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_scalar::PrimitiveScalar;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;

use crate::ALPArray;
Expand All @@ -42,7 +42,13 @@ impl CompareKernel for ALPVTable {
}

if let Some(const_scalar) = rhs.as_constant() {
let pscalar = PrimitiveScalar::try_from(&const_scalar)?;
let pscalar = const_scalar.as_primitive_opt().ok_or_else(|| {
vortex_err!(
"ALP Compare RHS had the wrong type {}, expected {}",
const_scalar,
const_scalar.dtype()
)
})?;

match_each_alp_float_ptype!(pscalar.ptype(), |T| {
match pscalar.typed_value::<T>() {
Expand Down
6 changes: 2 additions & 4 deletions encodings/alp/src/alp/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ impl OperationsVTable<ALPVTable> for ALPVTable {
let encoded_val = array.encoded().scalar_at(index)?;

Ok(match_each_alp_float_ptype!(array.ptype(), |T| {
let encoded_val: <T as ALPFloat>::ALPInt = encoded_val
.as_ref()
.try_into()
.vortex_expect("invalid ALPInt");
let encoded_val: <T as ALPFloat>::ALPInt =
(&encoded_val).try_into().vortex_expect("invalid ALPInt");
Scalar::primitive(
<T as ALPFloat>::decode_single(encoded_val, array.exponents()),
array.dtype().nullability(),
Expand Down
3 changes: 1 addition & 2 deletions encodings/alp/src/alp_rd/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use vortex_array::arrays::TakeExecute;
use vortex_array::compute::fill_null;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::ALPRDArray;
use crate::ALPRDVTable;
Expand All @@ -36,7 +35,7 @@ impl TakeExecute for ALPRDVTable {
.transpose()?;
let right_parts = fill_null(
&array.right_parts().take(indices.to_array())?,
&Scalar::new(array.right_parts().dtype().clone(), ScalarValue::from(0)),
&Scalar::zero_value(array.right_parts().dtype()),
)?;

Ok(Some(
Expand Down
2 changes: 1 addition & 1 deletion encodings/datetime-parts/src/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ fn try_extract_days_constant(array: &ArrayRef) -> Option<i64> {
fn is_constant_zero(array: &ArrayRef) -> bool {
array
.as_opt::<ConstantVTable>()
.is_some_and(|c| c.scalar().is_zero())
.is_some_and(|c| c.scalar().is_zero() == Some(true))
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ impl CompareKernel for DecimalBytePartsVTable {
.vortex_expect("checked for null in entry func");

match decimal_value_wrapper_to_primitive(rhs_decimal, lhs.msp.as_primitive_typed().ptype())
.map(|value| Scalar::new(scalar_type.clone(), value))
{
Ok(encoded_scalar) => {
Ok(value) => {
let encoded_scalar = Scalar::try_new(scalar_type, Some(value))?;
let encoded_const = ConstantArray::new(encoded_scalar, rhs.len());
compare(&lhs.msp, &encoded_const.to_array(), operator).map(Some)
}
Expand Down Expand Up @@ -165,7 +165,10 @@ mod tests {
)
.unwrap()
.to_array();
let rhs = ConstantArray::new(Scalar::new(dtype, DecimalValue::I64(400).into()), lhs.len());
let rhs = ConstantArray::new(
Scalar::try_new(dtype, Some(DecimalValue::I64(400).into())).unwrap(),
lhs.len(),
);

let res = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

Expand Down Expand Up @@ -215,10 +218,11 @@ mod tests {
.to_array();
// This cannot be converted to a i32.
let rhs = ConstantArray::new(
Scalar::new(
Scalar::try_new(
dtype.clone(),
DecimalValue::I128(-9999999999999965304).into(),
),
Some(DecimalValue::I128(-9999999999999965304).into()),
)
.unwrap(),
lhs.len(),
);

Expand All @@ -236,7 +240,7 @@ mod tests {

// This cannot be converted to a i32.
let rhs = ConstantArray::new(
Scalar::new(dtype, DecimalValue::I128(9999999999999965304).into()),
Scalar::try_new(dtype, Some(DecimalValue::I128(9999999999999965304).into())).unwrap(),
lhs.len(),
);

Expand Down
16 changes: 11 additions & 5 deletions encodings/decimal-byte-parts/src/decimal_byte_parts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use vortex_error::vortex_bail;
use vortex_error::vortex_ensure;
use vortex_scalar::DecimalValue;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;
use vortex_session::VortexSession;

use crate::decimal_byte_parts::compute::kernel::PARENT_KERNELS;
Expand Down Expand Up @@ -285,10 +286,10 @@ impl OperationsVTable<DecimalBytePartsVTable> for DecimalBytePartsVTable {
let primitive_scalar = scalar.as_primitive();
// TODO(joe): extend this to support multiple parts.
let value = primitive_scalar.as_::<i64>().vortex_expect("non-null");
Ok(Scalar::new(
Scalar::try_new(
array.dtype.clone(),
DecimalValue::I64(value).into(),
))
Some(ScalarValue::Decimal(DecimalValue::I64(value))),
)
}
}

Expand Down Expand Up @@ -319,6 +320,7 @@ mod tests {
use vortex_dtype::Nullability;
use vortex_scalar::DecimalValue;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::DecimalBytePartsArray;

Expand All @@ -339,11 +341,15 @@ mod tests {

assert_eq!(Scalar::null(dtype.clone()), array.scalar_at(0).unwrap());
assert_eq!(
Scalar::new(dtype.clone(), DecimalValue::I64(200).into()),
Scalar::try_new(
dtype.clone(),
Some(ScalarValue::Decimal(DecimalValue::I64(200)))
)
.unwrap(),
array.scalar_at(1).unwrap()
);
assert_eq!(
Scalar::new(dtype, DecimalValue::I64(400).into()),
Scalar::try_new(dtype, Some(ScalarValue::Decimal(DecimalValue::I64(400)))).unwrap(),
array.scalar_at(2).unwrap()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ mod tests {
.iter()
.enumerate()
.for_each(|(i, v)| {
let scalar: u16 = unpack_single(&compressed, i).try_into().unwrap();
let scalar: u16 = (&unpack_single(&compressed, i)).try_into().unwrap();
assert_eq!(scalar, *v);
});
}
Expand Down
5 changes: 1 addition & 4 deletions encodings/fastlanes/src/for/array/for_compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ mod test {
.iter()
.enumerate()
.for_each(|(i, v)| {
assert_eq!(
*v,
i8::try_from(compressed.scalar_at(i).unwrap().as_ref()).unwrap()
);
assert_eq!(*v, i8::try_from(&compressed.scalar_at(i).unwrap()).unwrap());
});
assert_arrays_eq!(decompressed, array);
Ok(())
Expand Down
3 changes: 1 addition & 2 deletions encodings/fastlanes/src/for/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use vortex_error::VortexError;
use vortex_error::VortexExpect as _;
use vortex_error::VortexResult;
use vortex_scalar::PValue;
use vortex_scalar::PrimitiveScalar;
use vortex_scalar::Scalar;

use crate::FoRArray;
Expand All @@ -33,7 +32,7 @@ impl CompareKernel for FoRVTable {
operator: Operator,
) -> VortexResult<Option<ArrayRef>> {
if let Some(constant) = rhs.as_constant()
&& let Ok(constant) = PrimitiveScalar::try_from(&constant)
&& let Some(constant) = constant.as_primitive_opt()
{
match_each_integer_ptype!(constant.ptype(), |T| {
return compare_constant(
Expand Down
43 changes: 7 additions & 36 deletions encodings/fastlanes/src/for/vtable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::fmt::Debug;
use std::fmt::Formatter;

use vortex_array::ArrayRef;
use vortex_array::DeserializeMetadata;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::SerializeMetadata;
use vortex_array::buffer::BufferHandle;
use vortex_array::serde::ArrayChildren;
use vortex_array::vtable;
Expand Down Expand Up @@ -41,7 +38,7 @@ vtable!(FoR);
impl VTable for FoRVTable {
type Array = FoRArray;

type Metadata = ScalarValueMetadata;
type Metadata = Scalar;

type ArrayVTable = Self;
type OperationsVTable = Self;
Expand All @@ -68,22 +65,21 @@ impl VTable for FoRVTable {
}

fn metadata(array: &FoRArray) -> VortexResult<Self::Metadata> {
Ok(ScalarValueMetadata(
array.reference_scalar().value().clone(),
))
Ok(array.reference_scalar().clone())
}

fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(metadata.serialize()))
Ok(Some(ScalarValue::to_proto_bytes(metadata.value())))
}

fn deserialize(
bytes: &[u8],
_dtype: &DType,
dtype: &DType,
_len: usize,
_session: &VortexSession,
) -> VortexResult<Self::Metadata> {
ScalarValueMetadata::deserialize(bytes)
let scalar_value = ScalarValue::from_proto_bytes(bytes, dtype)?;
Scalar::try_new(dtype.clone(), scalar_value)
}

fn build(
Expand All @@ -101,9 +97,8 @@ impl VTable for FoRVTable {
}

let encoded = children.get(0, dtype, len)?;
let reference = Scalar::new(dtype.clone(), metadata.0.clone());

FoRArray::try_new(encoded, reference)
FoRArray::try_new(encoded, metadata.clone())
}

fn reduce_parent(
Expand Down Expand Up @@ -134,27 +129,3 @@ pub struct FoRVTable;
impl FoRVTable {
pub const ID: ArrayId = ArrayId::new_ref("fastlanes.for");
}

#[derive(Clone)]
pub struct ScalarValueMetadata(pub ScalarValue);

impl SerializeMetadata for ScalarValueMetadata {
fn serialize(self) -> Vec<u8> {
self.0.to_protobytes()
}
}

impl DeserializeMetadata for ScalarValueMetadata {
type Output = ScalarValueMetadata;

fn deserialize(metadata: &[u8]) -> VortexResult<Self::Output> {
let scalar_value = ScalarValue::from_protobytes(metadata)?;
Ok(ScalarValueMetadata(scalar_value))
}
}

impl Debug for ScalarValueMetadata {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", &self.0)
}
}
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/rle/vtable/operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl OperationsVTable<RLEVTable> for RLEVTable {
.values()
.scalar_at(value_idx_offset + chunk_relative_idx)?;

Ok(Scalar::new(array.dtype().clone(), scalar.into_value()))
Scalar::try_new(array.dtype().clone(), scalar.into_value())
}
}

Expand Down
6 changes: 3 additions & 3 deletions encodings/fsst/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ fn compare_fsst_constant(
_ => unreachable!("FSSTArray can only have string or binary data type"),
};

let encoded_scalar = Scalar::new(
DType::Binary(left.dtype().nullability() | right.dtype().nullability()),
encoded_buffer.into(),
let encoded_scalar = Scalar::binary(
encoded_buffer,
left.dtype().nullability() | right.dtype().nullability(),
);

let rhs = ConstantArray::new(encoded_scalar, left.len());
Expand Down
6 changes: 1 addition & 5 deletions encodings/fsst/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::FSSTArray;
use crate::FSSTVTable;
Expand All @@ -41,10 +40,7 @@ impl TakeExecute for FSSTVTable {
.map_err(|_| vortex_err!("take for codes must return varbin array"))?,
fill_null(
&array.uncompressed_lengths().take(indices.to_array())?,
&Scalar::new(
array.uncompressed_lengths_dtype().clone(),
ScalarValue::from(0),
),
&Scalar::zero_value(&array.uncompressed_lengths_dtype().clone()),
)?,
)?
.into_array(),
Expand Down
3 changes: 1 addition & 2 deletions encodings/fsst/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ impl OperationsVTable<FSSTVTable> for FSSTVTable {
let compressed = array.codes().scalar_at(index)?;
let binary_datum = compressed.as_binary().value().vortex_expect("non-null");

let decoded_buffer =
ByteBuffer::from(array.decompressor().decompress(binary_datum.as_slice()));
let decoded_buffer = ByteBuffer::from(array.decompressor().decompress(binary_datum));
Ok(varbin_scalar(decoded_buffer, array.dtype()))
}
}
6 changes: 3 additions & 3 deletions encodings/runend/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,13 @@ impl RunEndArray {

// Validate the offset and length are valid for the given ends and values
if offset != 0 && length != 0 {
let first_run_end: usize = ends.scalar_at(0)?.as_ref().try_into()?;
let first_run_end = usize::try_from(&ends.scalar_at(0)?)?;
if first_run_end <= offset {
vortex_bail!("First run end {first_run_end} must be bigger than offset {offset}");
}
}

let last_run_end: usize = ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?;
let last_run_end = usize::try_from(&ends.scalar_at(ends.len() - 1)?)?;
let min_required_end = offset + length;
if last_run_end < min_required_end {
vortex_bail!("Last run end {last_run_end} must be >= offset+length {min_required_end}");
Expand Down Expand Up @@ -302,7 +302,7 @@ impl RunEndArray {
let length: usize = if ends.is_empty() {
0
} else {
ends.scalar_at(ends.len() - 1)?.as_ref().try_into()?
usize::try_from(&ends.scalar_at(ends.len() - 1)?)?
};

Self::try_new_offset_length(ends, values, 0, length)
Expand Down
Loading
Loading