diff --git a/math/src/elliptic_curve/short_weierstrass/point.rs b/math/src/elliptic_curve/short_weierstrass/point.rs index 13acb3ca7..388190994 100644 --- a/math/src/elliptic_curve/short_weierstrass/point.rs +++ b/math/src/elliptic_curve/short_weierstrass/point.rs @@ -212,8 +212,7 @@ impl IsGroup for ShortWeierstrassProjectivePoint { #[derive(PartialEq)] pub enum PointFormat { Projective, - // TO DO: - // Uncompressed, + Uncompressed, // Compressed, } @@ -233,7 +232,7 @@ where { /// Serialize the points in the given format #[cfg(feature = "std")] - pub fn serialize(&self, _point_format: PointFormat, endianness: Endianness) -> Vec { + pub fn serialize(&self, point_format: PointFormat, endianness: Endianness) -> Vec { // TODO: Add more compact serialization formats // Uncompressed affine / Compressed @@ -242,59 +241,101 @@ where let y_bytes: Vec; let z_bytes: Vec; - let [x, y, z] = self.coordinates(); - if endianness == Endianness::BigEndian { - x_bytes = x.to_bytes_be(); - y_bytes = y.to_bytes_be(); - z_bytes = z.to_bytes_be(); - } else { - x_bytes = x.to_bytes_le(); - y_bytes = y.to_bytes_le(); - z_bytes = z.to_bytes_le(); + match point_format { + PointFormat::Projective => { + let [x, y, z] = self.coordinates(); + if endianness == Endianness::BigEndian { + x_bytes = x.to_bytes_be(); + y_bytes = y.to_bytes_be(); + z_bytes = z.to_bytes_be(); + } else { + x_bytes = x.to_bytes_le(); + y_bytes = y.to_bytes_le(); + z_bytes = z.to_bytes_le(); + } + bytes.extend(&x_bytes); + bytes.extend(&y_bytes); + bytes.extend(&z_bytes); + } + PointFormat::Uncompressed => { + let affine_representation = self.to_affine(); + let [x, y, _z] = affine_representation.coordinates(); + if endianness == Endianness::BigEndian { + x_bytes = x.to_bytes_be(); + y_bytes = y.to_bytes_be(); + } else { + x_bytes = x.to_bytes_le(); + y_bytes = y.to_bytes_le(); + } + bytes.extend(&x_bytes); + bytes.extend(&y_bytes); + } } - - bytes.extend(&x_bytes); - bytes.extend(&y_bytes); - bytes.extend(&z_bytes); - bytes } pub fn deserialize( bytes: &[u8], - _point_format: PointFormat, + point_format: PointFormat, endianness: Endianness, ) -> Result { - if bytes.len() % 3 != 0 { - return Err(DeserializationError::InvalidAmountOfBytes); - } + match point_format { + PointFormat::Projective => { + if bytes.len() % 3 != 0 { + return Err(DeserializationError::InvalidAmountOfBytes); + } - let len = bytes.len() / 3; - let x: FieldElement; - let y: FieldElement; - let z: FieldElement; + let len = bytes.len() / 3; + let x: FieldElement; + let y: FieldElement; + let z: FieldElement; - if endianness == Endianness::BigEndian { - x = ByteConversion::from_bytes_be(&bytes[..len])?; - y = ByteConversion::from_bytes_be(&bytes[len..len * 2])?; - z = ByteConversion::from_bytes_be(&bytes[len * 2..])?; - } else { - x = ByteConversion::from_bytes_le(&bytes[..len])?; - y = ByteConversion::from_bytes_le(&bytes[len..len * 2])?; - z = ByteConversion::from_bytes_le(&bytes[len * 2..])?; - } + if endianness == Endianness::BigEndian { + x = ByteConversion::from_bytes_be(&bytes[..len])?; + y = ByteConversion::from_bytes_be(&bytes[len..len * 2])?; + z = ByteConversion::from_bytes_be(&bytes[len * 2..])?; + } else { + x = ByteConversion::from_bytes_le(&bytes[..len])?; + y = ByteConversion::from_bytes_le(&bytes[len..len * 2])?; + z = ByteConversion::from_bytes_le(&bytes[len * 2..])?; + } - if z == FieldElement::zero() { - let point = Self::new([x, y, z]); - if point.is_neutral_element() { - Ok(point) - } else { - Err(DeserializationError::FieldFromBytesError) + if z == FieldElement::zero() { + let point = Self::new([x, y, z]); + if point.is_neutral_element() { + Ok(point) + } else { + Err(DeserializationError::FieldFromBytesError) + } + } else if E::defining_equation(&(&x / &z), &(&y / &z)) == FieldElement::zero() { + Ok(Self::new([x, y, z])) + } else { + Err(DeserializationError::FieldFromBytesError) + } + } + PointFormat::Uncompressed => { + if bytes.len() % 2 != 0 { + return Err(DeserializationError::InvalidAmountOfBytes); + } + + let len = bytes.len() / 2; + let x: FieldElement; + let y: FieldElement; + + if endianness == Endianness::BigEndian { + x = ByteConversion::from_bytes_be(&bytes[..len])?; + y = ByteConversion::from_bytes_be(&bytes[len..])?; + } else { + x = ByteConversion::from_bytes_le(&bytes[..len])?; + y = ByteConversion::from_bytes_le(&bytes[len..])?; + } + + if E::defining_equation(&x, &y) == FieldElement::zero() { + Ok(Self::new([x, y, FieldElement::one()])) + } else { + Err(DeserializationError::FieldFromBytesError) + } } - } else if E::defining_equation(&(&x / &z), &(&y / &z)) == FieldElement::zero() { - Ok(Self::new([x, y, z])) - } else { - Err(DeserializationError::FieldFromBytesError) } } } @@ -347,7 +388,7 @@ mod tests { #[cfg(feature = "std")] #[test] - fn byte_conversion_from_and_to_be() { + fn byte_conversion_from_and_to_be_projective() { let expected_point = point(); let bytes_be = expected_point.serialize(PointFormat::Projective, Endianness::BigEndian); @@ -361,7 +402,20 @@ mod tests { #[cfg(feature = "std")] #[test] - fn byte_conversion_from_and_to_le() { + fn byte_conversion_from_and_to_be_uncompressed() { + let expected_point = point(); + let bytes_be = expected_point.serialize(PointFormat::Uncompressed, Endianness::BigEndian); + let result = ShortWeierstrassProjectivePoint::deserialize( + &bytes_be, + PointFormat::Uncompressed, + Endianness::BigEndian, + ); + assert_eq!(expected_point, result.unwrap()); + } + + #[cfg(feature = "std")] + #[test] + fn byte_conversion_from_and_to_le_projective() { let expected_point = point(); let bytes_be = expected_point.serialize(PointFormat::Projective, Endianness::LittleEndian); @@ -375,7 +429,22 @@ mod tests { #[cfg(feature = "std")] #[test] - fn byte_conversion_from_and_to_with_mixed_le_and_be_does_not_work() { + fn byte_conversion_from_and_to_le_uncompressed() { + let expected_point = point(); + let bytes_be = + expected_point.serialize(PointFormat::Uncompressed, Endianness::LittleEndian); + + let result = ShortWeierstrassProjectivePoint::deserialize( + &bytes_be, + PointFormat::Uncompressed, + Endianness::LittleEndian, + ); + assert_eq!(expected_point, result.unwrap()); + } + + #[cfg(feature = "std")] + #[test] + fn byte_conversion_from_and_to_with_mixed_le_and_be_does_not_work_projective() { let bytes = point().serialize(PointFormat::Projective, Endianness::LittleEndian); let result = ShortWeierstrassProjectivePoint::::deserialize( @@ -392,7 +461,24 @@ mod tests { #[cfg(feature = "std")] #[test] - fn byte_conversion_from_and_to_with_mixed_be_and_le_does_not_work() { + fn byte_conversion_from_and_to_with_mixed_le_and_be_does_not_work_uncompressed() { + let bytes = point().serialize(PointFormat::Uncompressed, Endianness::LittleEndian); + + let result = ShortWeierstrassProjectivePoint::::deserialize( + &bytes, + PointFormat::Uncompressed, + Endianness::BigEndian, + ); + + assert_eq!( + result.unwrap_err(), + DeserializationError::FieldFromBytesError + ); + } + + #[cfg(feature = "std")] + #[test] + fn byte_conversion_from_and_to_with_mixed_be_and_le_does_not_work_projective() { let bytes = point().serialize(PointFormat::Projective, Endianness::BigEndian); let result = ShortWeierstrassProjectivePoint::::deserialize( @@ -407,8 +493,25 @@ mod tests { ); } + #[cfg(feature = "std")] + #[test] + fn byte_conversion_from_and_to_with_mixed_be_and_le_does_not_work_uncompressed() { + let bytes = point().serialize(PointFormat::Uncompressed, Endianness::BigEndian); + + let result = ShortWeierstrassProjectivePoint::::deserialize( + &bytes, + PointFormat::Uncompressed, + Endianness::LittleEndian, + ); + + assert_eq!( + result.unwrap_err(), + DeserializationError::FieldFromBytesError + ); + } + #[test] - fn cannot_create_point_from_wrong_number_of_bytes_le() { + fn cannot_create_point_from_wrong_number_of_bytes_le_projective() { let bytes = &[0_u8; 13]; let result = ShortWeierstrassProjectivePoint::::deserialize( @@ -424,7 +527,23 @@ mod tests { } #[test] - fn cannot_create_point_from_wrong_number_of_bytes_be() { + fn cannot_create_point_from_wrong_number_of_bytes_le_uncompressed() { + let bytes = &[0_u8; 13]; + + let result = ShortWeierstrassProjectivePoint::::deserialize( + bytes, + PointFormat::Uncompressed, + Endianness::LittleEndian, + ); + + assert_eq!( + result.unwrap_err(), + DeserializationError::InvalidAmountOfBytes + ); + } + + #[test] + fn cannot_create_point_from_wrong_number_of_bytes_be_projective() { let bytes = &[0_u8; 13]; let result = ShortWeierstrassProjectivePoint::::deserialize( @@ -438,4 +557,20 @@ mod tests { DeserializationError::InvalidAmountOfBytes ); } + + #[test] + fn cannot_create_point_from_wrong_number_of_bytes_be_uncompressed() { + let bytes = &[0_u8; 13]; + + let result = ShortWeierstrassProjectivePoint::::deserialize( + bytes, + PointFormat::Uncompressed, + Endianness::BigEndian, + ); + + assert_eq!( + result.unwrap_err(), + DeserializationError::InvalidAmountOfBytes + ); + } }