From 81dc683fe148047b3c1fa3113ae7c1e0c51ed972 Mon Sep 17 00:00:00 2001 From: Simon Warta Date: Fri, 6 Dec 2024 00:12:48 +0100 Subject: [PATCH] Add Bufany::repeated_message --- src/bufany.rs | 173 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/src/bufany.rs b/src/bufany.rs index cb51e28..3c14411 100644 --- a/src/bufany.rs +++ b/src/bufany.rs @@ -61,6 +61,14 @@ pub enum RepeatedStringError { InvalidUtf8, } +#[derive(Debug, PartialEq)] +pub enum RepeatedMessageError { + /// Found a value of the wrong wire type + TypeMismatch, + /// Found a value that cannot be decoded + DecodingError(BufanyError), +} + impl core::fmt::Display for BufanyError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.write_str("Error decoding protobuf: ")?; @@ -92,12 +100,29 @@ impl core::fmt::Display for RepeatedStringError { } } +impl core::fmt::Display for RepeatedMessageError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str("Error decoding repeated message: ")?; + match self { + RepeatedMessageError::TypeMismatch => { + f.write_str("Found a value of the wrong wire type") + } + RepeatedMessageError::DecodingError(_) => { + f.write_str("Could not decode nested message") + } + } + } +} + #[cfg(feature = "std")] impl std::error::Error for BufanyError {} #[cfg(feature = "std")] impl std::error::Error for RepeatedStringError {} +#[cfg(feature = "std")] +impl std::error::Error for RepeatedMessageError {} + impl Bufany<'_> { /// Creates an empty instance with the given lifetime. /// @@ -747,6 +772,78 @@ impl<'a> Bufany<'a> { Ok(out) } + /// Gets repeated message from the given field number. + /// + /// Returns an error in case a wrong wire type was found + /// or the message cannot be decoded. + /// + /// ## Example + /// + /// ``` + /// use anybuf::{Anybuf, Bufany, RepeatedMessageError}; + /// + /// let serialized = Anybuf::new() + /// .append_message( + /// 1, + /// &Anybuf::new() + /// .append_bool(1, true) + /// .append_string(2, "foo") + /// .append_sint64(3, -37648762834), + /// ) + /// .append_message( + /// 1, + /// &Anybuf::new() + /// .append_bool(1, true) + /// .append_string(2, "bar") + /// .append_sint64(3, -37648762834), + /// ) + /// .append_message( + /// 1, + /// &Anybuf::new() + /// .append_bool(1, true) + /// .append_string(2, "baz") + /// .append_sint64(3, -37648762834), + /// ) + /// .append_sint32(2, 150) + /// .append_bytes(3, b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0") + /// .into_vec(); + /// let decoded = Bufany::deserialize(&serialized).unwrap(); + /// + /// let nested = decoded.repeated_message(1).unwrap(); + /// assert_eq!(nested.len(), 3); + /// assert_eq!(nested[0].bool(1), Some(true)); + /// assert_eq!(nested[0].string(2), Some("foo".to_string())); + /// assert_eq!(nested[0].sint64(3), Some(-37648762834)); + /// assert_eq!(nested[1].bool(1), Some(true)); + /// assert_eq!(nested[1].string(2), Some("bar".to_string())); + /// assert_eq!(nested[1].sint64(3), Some(-37648762834)); + /// assert_eq!(nested[2].bool(1), Some(true)); + /// assert_eq!(nested[2].string(2), Some("baz".to_string())); + /// assert_eq!(nested[2].sint64(3), Some(-37648762834)); + /// + /// assert!(matches!(decoded.repeated_message(2).unwrap_err(), RepeatedMessageError::TypeMismatch)); // wrong type + /// assert!(matches!(decoded.repeated_message(3).unwrap_err(), RepeatedMessageError::DecodingError(_))); // not a valid proto message + /// ``` + pub fn repeated_message( + &'a self, + field_number: u32, + ) -> Result>, RepeatedMessageError> { + let values = self.repeated_value_ref(field_number); + let mut out = Vec::with_capacity(values.len()); + + for value in values { + let data = match value { + Value::VariableLength(data) => *data, + _ => return Err(RepeatedMessageError::TypeMismatch), + }; + match Bufany::deserialize(data) { + Ok(m) => out.push(m), + Err(err) => return Err(RepeatedMessageError::DecodingError(err)), + } + } + Ok(out) + } + /// Gets the value of the given field number. This returns None if /// the field number does not exist pub fn value(&self, field_number: u32) -> Option { @@ -1309,4 +1406,80 @@ mod tests { // not serialized => default assert_eq!(decoded.repeated_string(85).unwrap(), Vec::::new()); } + + #[test] + fn repeated_message_works() { + let serialized = Anybuf::new() + .append_message( + 1, + &Anybuf::new() + .append_bool(1, true) + .append_string(2, "foo") + .append_sint64(3, -37648762834), + ) + .append_message( + 1, + &Anybuf::new() + .append_bool(1, true) + .append_string(2, "bar") + .append_sint64(3, -37648762834), + ) + .append_message( + 1, + &Anybuf::new() + .append_bool(1, true) + .append_string(2, "baz") + .append_sint64(3, -37648762834), + ) + .append_message(2, &Anybuf::new().append_uint32(1, 42)) + .append_message(3, &Anybuf::new()) + .append_sint32(10, 150) + .append_message( + 11, + &Anybuf::new() + .append_bool(1, true) + .append_string(2, "baz") + .append_sint64(3, -37648762834), + ) + .append_uint32(11, 22) + .append_bytes(12, b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0") + .into_vec(); + let decoded = Bufany::deserialize(&serialized).unwrap(); + + let nested = decoded.repeated_message(1).unwrap(); + assert_eq!(nested.len(), 3); + assert_eq!(nested[0].bool(1), Some(true)); + assert_eq!(nested[0].string(2), Some("foo".to_string())); + assert_eq!(nested[0].sint64(3), Some(-37648762834)); + assert_eq!(nested[1].bool(1), Some(true)); + assert_eq!(nested[1].string(2), Some("bar".to_string())); + assert_eq!(nested[1].sint64(3), Some(-37648762834)); + assert_eq!(nested[2].bool(1), Some(true)); + assert_eq!(nested[2].string(2), Some("baz".to_string())); + assert_eq!(nested[2].sint64(3), Some(-37648762834)); + + let nested = decoded.repeated_message(2).unwrap(); + assert_eq!(nested.len(), 1); + assert_eq!(nested[0].uint32(1), Some(42)); + + // An empty message is non existent + let nested = decoded.repeated_message(3).unwrap(); + assert_eq!(nested.len(), 0); + + // int + assert!(matches!( + decoded.repeated_message(10).unwrap_err(), + RepeatedMessageError::TypeMismatch + )); + // mixed type string and int + assert!(matches!( + decoded.repeated_message(11).unwrap_err(), + RepeatedMessageError::TypeMismatch + )); + // invalid data in variable length field + assert!(matches!( + decoded.repeated_message(12).unwrap_err(), + RepeatedMessageError::DecodingError(_) + )); + } } diff --git a/src/lib.rs b/src/lib.rs index fb24809..842b54d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,4 +36,4 @@ mod slice_reader; mod varint; pub use crate::anybuf::Anybuf; -pub use crate::bufany::{Bufany, BufanyError, RepeatedStringError}; +pub use crate::bufany::{Bufany, BufanyError, RepeatedMessageError, RepeatedStringError};