Skip to content

Commit

Permalink
fix: teach protobuf how to deserialize f16 (#991)
Browse files Browse the repository at this point in the history
  • Loading branch information
danking authored Oct 7, 2024
1 parent 7ca0928 commit 8e3d227
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 32 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions vortex-scalar/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ arrow-array = { workspace = true }
datafusion-common = { workspace = true, optional = true }
flatbuffers = { workspace = true, optional = true }
flexbuffers = { workspace = true, optional = true }
half = { workspace = true, optional = true }
itertools = { workspace = true }
jiff = { workspace = true }
num-traits = { workspace = true }
Expand Down Expand Up @@ -54,6 +55,7 @@ flatbuffers = [
proto = [
"dep:prost",
"dep:prost-types",
"dep:half",
"vortex-dtype/proto",
"vortex-proto/scalar",
]
Expand Down
90 changes: 58 additions & 32 deletions vortex-scalar/src/serde/proto.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use half::f16;
use vortex_buffer::{Buffer, BufferString};
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexError};
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};
use vortex_proto::scalar as pb;
use vortex_proto::scalar::scalar_value::Kind;
use vortex_proto::scalar::ListValue;
Expand Down Expand Up @@ -97,7 +98,8 @@ impl TryFrom<&pb::Scalar> for Scalar {
.ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?,
)?;

let value = ScalarValue::try_from(
let value = deserialize_scalar_value(
&dtype,
value
.value
.as_ref()
Expand All @@ -108,50 +110,66 @@ impl TryFrom<&pb::Scalar> for Scalar {
}
}

impl TryFrom<&pb::ScalarValue> for ScalarValue {
type Error = VortexError;

fn try_from(value: &pb::ScalarValue) -> Result<Self, Self::Error> {
let kind = value
.kind
.as_ref()
.ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;

Ok(match kind {
Kind::NullValue(_) => ScalarValue::Null,
Kind::BoolValue(v) => ScalarValue::Bool(*v),
Kind::Int32Value(v) => ScalarValue::Primitive(PValue::I32(*v)),
Kind::Int64Value(v) => ScalarValue::Primitive(PValue::I64(*v)),
Kind::Uint32Value(v) => ScalarValue::Primitive(PValue::U32(*v)),
Kind::Uint64Value(v) => ScalarValue::Primitive(PValue::U64(*v)),
Kind::FloatValue(v) => ScalarValue::Primitive(PValue::F32(*v)),
Kind::DoubleValue(v) => ScalarValue::Primitive(PValue::F64(*v)),
Kind::StringValue(v) => ScalarValue::BufferString(BufferString::from(v.clone())),
Kind::BytesValue(v) => ScalarValue::Buffer(Buffer::from(v.as_slice())),
Kind::ListValue(v) => {
let mut values = Vec::with_capacity(v.values.len());
for elem in v.values.iter() {
values.push(ScalarValue::try_from(elem)?);
fn deserialize_scalar_value(dtype: &DType, value: &pb::ScalarValue) -> VortexResult<ScalarValue> {
let kind = value
.kind
.as_ref()
.ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?;

match kind {
Kind::NullValue(_) => Ok(ScalarValue::Null),
Kind::BoolValue(v) => Ok(ScalarValue::Bool(*v)),
Kind::Int32Value(v) => Ok(ScalarValue::Primitive(PValue::I32(*v))),
Kind::Int64Value(v) => Ok(ScalarValue::Primitive(PValue::I64(*v))),
Kind::Uint32Value(v) => Ok(ScalarValue::Primitive(PValue::U32(*v))),
Kind::Uint64Value(v) => Ok(ScalarValue::Primitive(PValue::U64(*v))),
Kind::FloatValue(v) => match dtype {
DType::Primitive(PType::F16, _) => {
Ok(ScalarValue::Primitive(PValue::F16(f16::from_f32(*v))))
}
DType::Primitive(PType::F32, _) => Ok(ScalarValue::Primitive(PValue::F32(*v))),
_ => vortex_bail!("invalid dtype for f32 value {}", dtype),
},
Kind::DoubleValue(v) => Ok(ScalarValue::Primitive(PValue::F64(*v))),
Kind::StringValue(v) => Ok(ScalarValue::BufferString(BufferString::from(v.clone()))),
Kind::BytesValue(v) => Ok(ScalarValue::Buffer(Buffer::from(v.as_slice()))),
Kind::ListValue(v) => {
let mut values = Vec::with_capacity(v.values.len());
match dtype {
DType::Struct(structdt, _) => {
for (elem, dtype) in v.values.iter().zip(structdt.dtypes().iter()) {
values.push(deserialize_scalar_value(dtype, elem)?);
}
}
ScalarValue::List(values.into())
DType::List(elementdt, _) => {
for elem in v.values.iter() {
values.push(deserialize_scalar_value(elementdt, elem)?);
}
}
_ => vortex_bail!("invalid dtype for list value {}", dtype),
}
})
Ok(ScalarValue::List(values.into()))
}
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use half::f16;
use vortex_buffer::BufferString;
use vortex_dtype::PType::I32;
use vortex_dtype::PType::{self, I32};
use vortex_dtype::{DType, Nullability};
use vortex_proto::scalar as pb;

use crate::{Scalar, ScalarValue};
use crate::{PValue, Scalar, ScalarValue};

/// Encodes `scalar` to its protobuf form, decodes it again, and asserts that the decoded
/// scalar equals the input.
///
/// Panics (failing the calling test) if decoding returns an error or the values differ.
fn round_trip(scalar: Scalar) {
    // One encode/decode pass is enough: the assertion below covers both "decoding
    // succeeds" and "decoding yields the same scalar".
    let decoded = Scalar::try_from(&pb::Scalar::from(&scalar)).unwrap();
    assert_eq!(decoded, scalar);
}

#[test]
Expand Down Expand Up @@ -207,4 +225,12 @@ mod test {
),
));
}

#[test]
fn test_f16() {
    // A nullable half-precision scalar must survive the protobuf round trip unchanged.
    let dtype = DType::Primitive(PType::F16, Nullability::Nullable);
    let value = ScalarValue::Primitive(PValue::F16(f16::from_f32(0.42)));
    round_trip(Scalar::new(dtype, value));
}
}

0 comments on commit 8e3d227

Please sign in to comment.