diff --git a/Cargo.lock b/Cargo.lock
index 7c0e7149ba..6e056205ee 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4618,6 +4618,7 @@ dependencies = [
  "datafusion-common",
  "flatbuffers",
  "flexbuffers",
+ "half",
  "itertools 0.13.0",
  "jiff",
  "num-traits",
diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml
index f840e81873..669bb2745b 100644
--- a/vortex-scalar/Cargo.toml
+++ b/vortex-scalar/Cargo.toml
@@ -19,6 +19,7 @@ arrow-array = { workspace = true }
 datafusion-common = { workspace = true, optional = true }
 flatbuffers = { workspace = true, optional = true }
 flexbuffers = { workspace = true, optional = true }
+half = { workspace = true, optional = true }
 itertools = { workspace = true }
 jiff = { workspace = true }
 num-traits = { workspace = true }
@@ -54,6 +55,7 @@ flatbuffers = [
 proto = [
     "dep:prost",
     "dep:prost-types",
+    "dep:half",
     "vortex-dtype/proto",
     "vortex-proto/scalar",
 ]
diff --git a/vortex-scalar/src/serde/proto.rs b/vortex-scalar/src/serde/proto.rs
index 0e4872edda..52f973eee4 100644
--- a/vortex-scalar/src/serde/proto.rs
+++ b/vortex-scalar/src/serde/proto.rs
@@ -1,6 +1,7 @@
+use half::f16;
 use vortex_buffer::{Buffer, BufferString};
-use vortex_dtype::DType;
-use vortex_error::{vortex_err, VortexError};
+use vortex_dtype::{DType, PType};
+use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
 use vortex_proto::scalar as pb;
 use vortex_proto::scalar::scalar_value::Kind;
 use vortex_proto::scalar::ListValue;
@@ -97,7 +98,8 @@ impl TryFrom<&pb::Scalar> for Scalar {
                 .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
         )?;
 
-        let value = ScalarValue::try_from(
+        let value = deserialize_scalar_value(
+            &dtype,
             value
                 .value
                 .as_ref()
@@ -108,34 +110,66 @@ impl TryFrom<&pb::Scalar> for Scalar {
     }
 }
 
-impl TryFrom<&pb::ScalarValue> for ScalarValue {
-    type Error = VortexError;
-
-    fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
-        let kind = value
-            .kind
-            .as_ref()
-            .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
-
-        Ok(match kind {
-            Kind::NullValue(_) => ScalarValue::Null,
-            Kind::BoolValue(v) => ScalarValue::Bool(*v),
-            Kind::Int32Value(v) => ScalarValue::Primitive(PValue::I32(*v)),
-            Kind::Int64Value(v) => ScalarValue::Primitive(PValue::I64(*v)),
-            Kind::Uint32Value(v) => ScalarValue::Primitive(PValue::U32(*v)),
-            Kind::Uint64Value(v) => ScalarValue::Primitive(PValue::U64(*v)),
-            Kind::FloatValue(v) => ScalarValue::Primitive(PValue::F32(*v)),
-            Kind::DoubleValue(v) => ScalarValue::Primitive(PValue::F64(*v)),
-            Kind::StringValue(v) => ScalarValue::BufferString(BufferString::from(v.clone())),
-            Kind::BytesValue(v) => ScalarValue::Buffer(Buffer::from(v.as_slice())),
-            Kind::ListValue(v) => {
-                let mut values = Vec::with_capacity(v.values.len());
-                for elem in v.values.iter() {
-                    values.push(ScalarValue::try_from(elem)?);
+/// Decodes a protobuf scalar value into a [`ScalarValue`], guided by the expected `dtype`.
+///
+/// The `dtype` resolves encodings the protobuf kind alone leaves ambiguous: a `FloatValue`
+/// becomes an `f16` when the dtype is `F16`, and a `ListValue` is decoded field-by-field for
+/// a struct dtype or element-by-element for a list dtype.
+///
+/// # Errors
+///
+/// Fails if `value` has no kind, if a `FloatValue` or `ListValue` is paired with a dtype it
+/// cannot represent, or if a struct value's field count differs from the dtype's.
+fn deserialize_scalar_value(dtype: &DType, value: &pb::ScalarValue) -> VortexResult<ScalarValue> {
+    let kind = value
+        .kind
+        .as_ref()
+        .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;
+
+    match kind {
+        Kind::NullValue(_) => Ok(ScalarValue::Null),
+        Kind::BoolValue(v) => Ok(ScalarValue::Bool(*v)),
+        Kind::Int32Value(v) => Ok(ScalarValue::Primitive(PValue::I32(*v))),
+        Kind::Int64Value(v) => Ok(ScalarValue::Primitive(PValue::I64(*v))),
+        Kind::Uint32Value(v) => Ok(ScalarValue::Primitive(PValue::U32(*v))),
+        Kind::Uint64Value(v) => Ok(ScalarValue::Primitive(PValue::U64(*v))),
+        Kind::FloatValue(v) => match dtype {
+            DType::Primitive(PType::F16, _) => {
+                Ok(ScalarValue::Primitive(PValue::F16(f16::from_f32(*v))))
+            }
+            DType::Primitive(PType::F32, _) => Ok(ScalarValue::Primitive(PValue::F32(*v))),
+            _ => vortex_bail!("invalid dtype for f32 value {}", dtype),
+        },
+        Kind::DoubleValue(v) => Ok(ScalarValue::Primitive(PValue::F64(*v))),
+        Kind::StringValue(v) => Ok(ScalarValue::BufferString(BufferString::from(v.clone()))),
+        Kind::BytesValue(v) => Ok(ScalarValue::Buffer(Buffer::from(v.as_slice()))),
+        Kind::ListValue(v) => {
+            let mut values = Vec::with_capacity(v.values.len());
+            match dtype {
+                DType::Struct(structdt, _) => {
+                    // `zip` stops at the shorter side; reject a field-count mismatch
+                    // instead of silently dropping values or fields.
+                    if v.values.len() != structdt.dtypes().len() {
+                        vortex_bail!(
+                            "struct value has {} fields but dtype {} has {}",
+                            v.values.len(),
+                            dtype,
+                            structdt.dtypes().len()
+                        );
+                    }
+                    for (elem, dtype) in v.values.iter().zip(structdt.dtypes().iter()) {
+                        values.push(deserialize_scalar_value(dtype, elem)?);
+                    }
                 }
-                ScalarValue::List(values.into())
+                DType::List(elementdt, _) => {
+                    for elem in v.values.iter() {
+                        values.push(deserialize_scalar_value(elementdt, elem)?);
+                    }
+                }
+                _ => vortex_bail!("invalid dtype for list value {}", dtype),
             }
-        })
+            Ok(ScalarValue::List(values.into()))
+        }
     }
 }
 
@@ -143,15 +177,19 @@ impl TryFrom<&pb::ScalarValue> for ScalarValue {
 mod test {
     use std::sync::Arc;
 
+    use half::f16;
     use vortex_buffer::BufferString;
-    use vortex_dtype::PType::I32;
+    use vortex_dtype::PType::{self, I32};
     use vortex_dtype::{DType, Nullability};
     use vortex_proto::scalar as pb;
 
-    use crate::{Scalar, ScalarValue};
+    use crate::{PValue, Scalar, ScalarValue};
 
     fn round_trip(scalar: Scalar) {
-        Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap();
+        assert_eq!(
+            Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap(),
+            scalar
+        );
     }
 
     #[test]
@@ -207,4 +245,12 @@ mod test {
             ),
         ));
     }
+
+    #[test]
+    fn test_f16() {
+        round_trip(Scalar::new(
+            DType::Primitive(PType::F16, Nullability::Nullable),
+            ScalarValue::Primitive(PValue::F16(f16::from_f32(0.42))),
+        ));
+    }
 }