diff --git a/vortex-dtype/src/dtype.rs b/vortex-dtype/src/dtype.rs index 03864b73bb..6bde6e1b47 100644 --- a/vortex-dtype/src/dtype.rs +++ b/vortex-dtype/src/dtype.rs @@ -89,7 +89,30 @@ impl DType { /// Check if `self` and `other` are equal, ignoring nullability pub fn eq_ignore_nullability(&self, other: &Self) -> bool { - self.as_nullable().eq(&other.as_nullable()) + match (self, other) { + (Null, Null) => true, + (Null, _) => false, + (Bool(_), Bool(_)) => true, + (Bool(_), _) => false, + (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype, + (Primitive(..), _) => false, + (Utf8(_), Utf8(_)) => true, + (Utf8(_), _) => false, + (Binary(_), Binary(_)) => true, + (Binary(_), _) => false, + (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype), + (List(..), _) => false, + (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => { + (lhs_dtype.names() == rhs_dtype.names()) + && (lhs_dtype + .dtypes() + .zip_eq(rhs_dtype.dtypes()) + .all(|(l, r)| l.eq_ignore_nullability(&r))) + } + (Struct(..), _) => false, + (Extension(lhs_extdtype), Extension(rhs_extdtype)) => lhs_extdtype == rhs_extdtype, + (Extension(_), _) => false, + } } /// Check if `self` is a `StructDType` diff --git a/vortex-dtype/src/extension.rs b/vortex-dtype/src/extension.rs index 066c3df0af..305b3dff4d 100644 --- a/vortex-dtype/src/extension.rs +++ b/vortex-dtype/src/extension.rs @@ -58,7 +58,7 @@ impl From<&[u8]> for ExtMetadata { } /// A type descriptor for an extension type -#[derive(Debug, Clone, PartialOrd, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialOrd, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct ExtDType { id: ExtID, @@ -66,6 +66,18 @@ pub struct ExtDType { metadata: Option, } +impl PartialEq for ExtDType { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl std::hash::Hash for ExtDType { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } +} + impl ExtDType { /// Creates a new `ExtDType`. ///