diff --git a/packages/mocks/src/encoding.rs b/packages/mocks/src/encoding.rs index ffb2d03..ae18730 100644 --- a/packages/mocks/src/encoding.rs +++ b/packages/mocks/src/encoding.rs @@ -9,9 +9,18 @@ use storey_encoding::{Cover, DecodableWithImpl, EncodableWithImpl, Encoding}; pub struct TestEncoding; +#[derive(Debug, PartialEq)] +pub struct MockError; + +impl std::fmt::Display for MockError { + fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + impl Encoding for TestEncoding { - type DecodeError = (); - type EncodeError = (); + type DecodeError = MockError; + type EncodeError = MockError; } // This is how we would implement `EncodableWith` and `DecodableWith` for @@ -39,16 +48,16 @@ where // Imagine `MyTestEncoding` is a third-party trait that we don't control. trait MyTestEncoding: Sized { - fn my_encode(&self) -> Result, ()>; - fn my_decode(data: &[u8]) -> Result; + fn my_encode(&self) -> Result, MockError>; + fn my_decode(data: &[u8]) -> Result; } impl MyTestEncoding for u64 { - fn my_encode(&self) -> Result, ()> { + fn my_encode(&self) -> Result, MockError> { Ok(self.to_le_bytes().to_vec()) } - fn my_decode(data: &[u8]) -> Result { + fn my_decode(data: &[u8]) -> Result { let mut bytes = [0u8; 8]; bytes.copy_from_slice(data); Ok(u64::from_le_bytes(bytes)) diff --git a/packages/storey-encoding/src/lib.rs b/packages/storey-encoding/src/lib.rs index 15317f5..8c65c9e 100644 --- a/packages/storey-encoding/src/lib.rs +++ b/packages/storey-encoding/src/lib.rs @@ -1,9 +1,9 @@ pub trait Encoding { /// The error type returned when encoding fails. - type EncodeError; + type EncodeError: std::fmt::Display; /// The error type returned when decoding fails. - type DecodeError; + type DecodeError: std::fmt::Display; } pub trait EncodableWith: sealed::SealedE { diff --git a/packages/storey/src/containers/column.rs b/packages/storey/src/containers/column.rs index 449819e..a9f3a77 100644 --- a/packages/storey/src/containers/column.rs +++ b/packages/storey/src/containers/column.rs @@ -365,7 +365,7 @@ where Ok(id) } - /// Update the value associated with the given ID. + /// Set the value associated with the given ID. /// /// # Example /// ``` @@ -380,13 +380,11 @@ where /// access.push(&1337).unwrap(); /// assert_eq!(access.get(1).unwrap(), Some(1337)); /// - /// access.update(1, &9001).unwrap(); + /// access.set(1, &9001).unwrap(); /// assert_eq!(access.get(1).unwrap(), Some(9001)); /// ``` - pub fn update(&mut self, id: u32, value: &T) -> Result<(), UpdateError> { - self.storage - .get(&encode_id(id)) - .ok_or(UpdateError::NotFound)?; + pub fn set(&mut self, id: u32, value: &T) -> Result<(), SetError> { + self.storage.get(&encode_id(id)).ok_or(SetError::NotFound)?; let bytes = value.encode()?; @@ -395,6 +393,44 @@ where Ok(()) } + /// Update the value associated with the given ID by applying a function to it. + /// + /// The provided function is called with the current value, if it exists, and should return the + /// new value. If the function returns `None`, the value is removed. + /// + /// # Example + /// ``` + /// # use mocks::encoding::TestEncoding; + /// # use mocks::backend::TestStorage; + /// use storey::containers::Column; + /// + /// let mut storage = TestStorage::new(); + /// let column = Column::::new(0); + /// let mut access = column.access(&mut storage); + /// + /// access.push(&1337).unwrap(); + /// assert_eq!(access.get(1).unwrap(), Some(1337)); + /// + /// access.update(1, |value| value.map(|v| v + 1)).unwrap(); + /// assert_eq!(access.get(1).unwrap(), Some(1338)); + /// ``` + pub fn update( + &mut self, + id: u32, + f: F, + ) -> Result<(), UpdateError> + where + F: FnOnce(Option) -> Option, + { + let new_value = f(self.get(id).map_err(UpdateError::Decode)?); + match new_value { + Some(value) => self.set(id, &value).map_err(UpdateError::Set), + None => self + .remove(id) + .map_err(|_| UpdateError::Set(SetError::NotFound)), + } + } + /// Remove the value associated with the given ID. /// /// This operation leaves behind an empty slot in the column. The ID is not reused. @@ -445,19 +481,27 @@ impl From for PushError { } #[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] -pub enum UpdateError { +pub enum SetError { #[error("not found")] NotFound, #[error("{0}")] EncodingError(E), } -impl From for UpdateError { +impl From for SetError { fn from(e: E) -> Self { - UpdateError::EncodingError(e) + SetError::EncodingError(e) } } +#[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] +pub enum UpdateError { + #[error("decode error: {0}")] + Decode(D), + #[error("set error: {0}")] + Set(SetError), +} + #[derive(Debug, PartialEq, Eq, Clone, Copy, Error)] pub enum RemoveError { #[error("inconsistent state")] @@ -497,8 +541,8 @@ mod tests { assert_eq!(access.len().unwrap(), 2); access.remove(1).unwrap(); - assert_eq!(access.update(1, &9001), Err(UpdateError::NotFound)); - access.update(2, &9001).unwrap(); + assert_eq!(access.set(1, &9001), Err(SetError::NotFound)); + access.set(2, &9001).unwrap(); assert_eq!(access.get(1).unwrap(), None); assert_eq!(access.get(2).unwrap(), Some(9001)); @@ -535,6 +579,28 @@ mod tests { assert_eq!(access.len().unwrap(), 1); } + #[test] + fn update() { + let mut storage = TestStorage::new(); + + let column = Column::::new(0); + let mut access = column.access(&mut storage); + + access.push(&1337).unwrap(); + access.push(&42).unwrap(); + access.push(&9001).unwrap(); + access.remove(2).unwrap(); + + access.update(1, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(1).unwrap(), Some(1338)); + + access.update(2, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(2).unwrap(), None); + + access.update(3, |value| value.map(|v| v + 1)).unwrap(); + assert_eq!(access.get(3).unwrap(), Some(9002)); + } + #[test] fn iteration() { let mut storage = TestStorage::new(); diff --git a/packages/storey/src/containers/item.rs b/packages/storey/src/containers/item.rs index 4180d06..89f1ba8 100644 --- a/packages/storey/src/containers/item.rs +++ b/packages/storey/src/containers/item.rs @@ -212,7 +212,7 @@ impl ItemAccess where E: Encoding, T: EncodableWith + DecodableWith, - S: StorageMut, + S: Storage + StorageMut, { /// Set the value of the item. /// @@ -234,6 +234,39 @@ where Ok(()) } + /// Update the value of the item. + /// + /// The function `f` is called with the current value of the item, if it exists. + /// If the function returns `Some`, the item is set to the new value. + /// If the function returns `None`, the item is removed. + /// + /// # Example + /// ``` + /// # use mocks::encoding::TestEncoding; + /// # use mocks::backend::TestStorage; + /// use storey::containers::Item; + /// + /// let mut storage = TestStorage::new(); + /// let item = Item::::new(0); + /// + /// item.access(&mut storage).set(&42).unwrap(); + /// item.access(&mut storage).update(|value| value.map(|v| v + 1)).unwrap(); + /// assert_eq!(item.access(&storage).get().unwrap(), Some(43)); + /// ``` + pub fn update(&mut self, f: F) -> Result<(), UpdateError> + where + F: FnOnce(Option) -> Option, + { + let new_value = f(self.get().map_err(UpdateError::Decode)?); + match new_value { + Some(value) => self.set(&value).map_err(UpdateError::Encode), + None => { + self.remove(); + Ok(()) + } + } + } + /// Remove the value of the item. /// /// # Example @@ -254,6 +287,14 @@ where } } +#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)] +pub enum UpdateError { + #[error("decode error: {0}")] + Decode(D), + #[error("encode error: {0}")] + Encode(E), +} + #[cfg(test)] mod tests { use super::*; @@ -276,4 +317,20 @@ mod tests { assert_eq!(access1.get().unwrap(), None); assert_eq!(storage.get(&[1]), None); } + + #[test] + fn update() { + let mut storage = TestStorage::new(); + + let item = Item::::new(0); + item.access(&mut storage).set(&42).unwrap(); + + item.access(&mut storage) + .update(|value| value.map(|v| v + 1)) + .unwrap(); + assert_eq!(item.access(&storage).get().unwrap(), Some(43)); + + item.access(&mut storage).update(|_| None).unwrap(); + assert_eq!(item.access(&storage).get().unwrap(), None); + } } diff --git a/packages/storey/src/encoding.rs b/packages/storey/src/encoding.rs index 0b311c2..85e1d8c 100644 --- a/packages/storey/src/encoding.rs +++ b/packages/storey/src/encoding.rs @@ -34,15 +34,15 @@ //! struct DisplayEncoding; //! //! impl Encoding for DisplayEncoding { -//! type DecodeError = (); -//! type EncodeError = (); +//! type DecodeError = String; +//! type EncodeError = String; //! } //! //! impl EncodableWithImpl for Cover<&T,> //! where //! T: std::fmt::Display, //! { -//! fn encode_impl(self) -> Result, ()> { +//! fn encode_impl(self) -> Result, String> { //! Ok(format!("{}", self.0).into_bytes()) //! } //! } @@ -67,17 +67,18 @@ //! struct DisplayEncoding; //! //! impl Encoding for DisplayEncoding { -//! type DecodeError = (); -//! type EncodeError = (); +//! type DecodeError = String; +//! type EncodeError = String; //! } //! //! impl DecodableWithImpl for Cover //! where //! T: std::str::FromStr, //! { -//! fn decode_impl(data: &[u8]) -> Result { -//! let string = String::from_utf8(data.to_vec()).map_err(|_| ())?; -//! let value = string.parse().map_err(|_| ())?; +//! fn decode_impl(data: &[u8]) -> Result { +//! let string = +//! String::from_utf8(data.to_vec()).map_err(|_| "string isn't UTF-8".to_string())?; +//! let value = string.parse().map_err(|_| "parsing failed".to_string())?; //! Ok(Cover(value)) //! } //! }