diff --git a/rust/flatbuffers/src/lib.rs b/rust/flatbuffers/src/lib.rs index 9ed308024fb..600d6d9c194 100644 --- a/rust/flatbuffers/src/lib.rs +++ b/rust/flatbuffers/src/lib.rs @@ -56,8 +56,8 @@ pub use crate::push::{Push, PushAlignment}; pub use crate::table::{buffer_has_identifier, Table}; pub use crate::vector::{follow_cast_ref, Vector, VectorIter}; pub use crate::verifier::{ - ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable, Verifier, - VerifierOptions, + ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable, + Verifier, VerifierOptions, }; pub use crate::vtable::field_index_to_field_offset; pub use bitflags; diff --git a/rust/flatbuffers/src/vector.rs b/rust/flatbuffers/src/vector.rs index e1569510c85..22b68db19cd 100644 --- a/rust/flatbuffers/src/vector.rs +++ b/rust/flatbuffers/src/vector.rs @@ -90,6 +90,14 @@ impl<'a, T: 'a> Vector<'a, T> { let len = self.len(); &self.0[self.1 + SIZE_UOFFSET..self.1 + SIZE_UOFFSET + sz * len] } + + /// # Safety + /// + /// The underlying bytes must be interpretable as a vector of the *same number* of `U`'s. + #[inline(always)] + pub unsafe fn cast(&self) -> Vector<'a, U> { + Vector::new(self.0, self.1) + } } impl<'a, T: Follow<'a> + 'a> Vector<'a, T> { @@ -123,11 +131,11 @@ impl<'a, T: Follow<'a> + 'a> Vector<'a, T> { Ordering::Equal => return Some(value), Ordering::Less => left = mid + 1, Ordering::Greater => { - if mid == 0 { - return None; - } - right = mid - 1; - }, + if mid == 0 { + return None; + } + right = mid - 1; + } } } diff --git a/rust/flatbuffers/src/verifier.rs b/rust/flatbuffers/src/verifier.rs index c4c55f587d0..31425f11061 100644 --- a/rust/flatbuffers/src/verifier.rs +++ b/rust/flatbuffers/src/verifier.rs @@ -560,7 +560,10 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> { )?; Ok(self) } - _ => InvalidFlatbuffer::new_inconsistent_union(key_field_name.into(), val_field_name.into()), + _ => InvalidFlatbuffer::new_inconsistent_union( + key_field_name.into(), + val_field_name.into(), + ), } } pub fn finish(self) -> &'ver mut Verifier<'opts, 'buf> { diff --git a/rust/reflection/src/lib.rs b/rust/reflection/src/lib.rs index f2ef413ae71..e1fbd860cdc 100644 --- a/rust/reflection/src/lib.rs +++ b/rust/reflection/src/lib.rs @@ -16,10 +16,13 @@ mod reflection_generated; mod reflection_verifier; -mod safe_buffer; +pub mod safe_buffer; +mod vector_of_any; +pub use vector_of_any::VectorOfAny; mod r#struct; pub use crate::r#struct::Struct; pub use crate::reflection_generated::reflection; +pub use crate::reflection_verifier::verify_with_options; pub use crate::safe_buffer::SafeBuffer; use flatbuffers::{ @@ -50,8 +53,12 @@ pub enum FlatbufferError { SetStringPolluted, #[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")] InvalidSchema, - #[error("Type not supported: {0}")] - TypeNotSupported(String), + #[error("Unsupported table field type: {0:?}")] + UnsupportedTableFieldType(BaseType), + #[error("Unsupported vector element type: {0:?}")] + UnsupportedVectorElementType(BaseType), + #[error("Unsupported union element type: {0:?}")] + UnsupportedUnionElementType(BaseType), #[error("No type or invalid type found in union enum")] InvalidUnionEnum, #[error("Table or Struct doesn't belong to the buffer")] @@ -80,21 +87,17 @@ pub unsafe fn get_any_root(data: &[u8]) -> Table { pub unsafe fn get_field_integer Follow<'a, Inner = T> + PrimInt + FromPrimitive>( table: &Table, field: &Field, -) -> FlatbufferResult> { - if size_of::() != get_type_size(field.type_().base_type()) { - return Err(FlatbufferError::FieldTypeMismatch( - std::any::type_name::().to_string(), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } +) -> Option { + debug_assert_eq!( + size_of::(), + get_type_size(field.type_().base_type()), + "Type size mismatch: {} vs {}", + std::any::type_name::(), + field.type_().base_type().variant_name().unwrap_or_default() + ); let default = T::from_i64(field.default_integer()); - Ok(table.get::(field.offset(), default)) + table.get::(field.offset(), default) } /// Gets a floating point table field given its exact type. Returns default float value if the field is not set. Returns [None] if no default value is found. Returns error if the type doesn't match. @@ -105,21 +108,17 @@ pub unsafe fn get_field_integer Follow<'a, Inner = T> + PrimInt + Fro pub unsafe fn get_field_float Follow<'a, Inner = T> + Float>( table: &Table, field: &Field, -) -> FlatbufferResult> { - if size_of::() != get_type_size(field.type_().base_type()) { - return Err(FlatbufferError::FieldTypeMismatch( - std::any::type_name::().to_string(), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } +) -> Option { + debug_assert_eq!( + size_of::(), + get_type_size(field.type_().base_type()), + "Type size mismatch: {} vs {}", + std::any::type_name::(), + field.type_().base_type().variant_name().unwrap_or_default() + ); let default = T::from(field.default_real()); - Ok(table.get::(field.offset(), default)) + table.get::(field.offset(), default) } /// Gets a String table field given its exact type. Returns empty string if the field is not set. Returns [None] if no default value is found. Returns error if the type size doesn't match. @@ -127,99 +126,80 @@ pub unsafe fn get_field_float Follow<'a, Inner = T> + Float>( /// # Safety /// /// The value of the corresponding slot must have type String -pub unsafe fn get_field_string<'a>( - table: &Table<'a>, - field: &Field, -) -> FlatbufferResult> { - if field.type_().base_type() != BaseType::String { - return Err(FlatbufferError::FieldTypeMismatch( - String::from("String"), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } - - Ok(table.get::>(field.offset(), Some(""))) +pub unsafe fn get_field_string<'a>(table: &Table<'a>, field: &Field) -> &'a str { + debug_assert_eq!(field.type_().base_type(), BaseType::String); + table + .get::>(field.offset(), Some("")) + .unwrap() } -/// Gets a [Struct] table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match. +/// Gets a [Struct] table field given its exact type. Returns [None] if the field is not set. /// /// # Safety /// /// The value of the corresponding slot must have type Struct -pub unsafe fn get_field_struct<'a>( - table: &Table<'a>, - field: &Field, -) -> FlatbufferResult>> { +pub unsafe fn get_field_struct<'a>(table: &Table<'a>, field: &Field) -> Option> { // TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need // access to the schema to check the is_struct flag. - if field.type_().base_type() != BaseType::Obj { - return Err(FlatbufferError::FieldTypeMismatch( - String::from("Obj"), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } + debug_assert_eq!(field.type_().base_type(), BaseType::Obj); - Ok(table.get::(field.offset(), None)) + table.get::(field.offset(), None) } -/// Gets a Vector table field given its exact type. Returns empty vector if the field is not set. Returns error if the type doesn't match. +/// Get a vector table field, whose elements have type `T`. +/// +/// Returns an empty vector if the field is not set. +/// +/// Non-scalar values such as tables or strings are not stored inline. In such cases, you should use +/// `ForwardsUOffset`. So for example, use `T = ForwardsUOffset>` for a vector of tables, +/// or `T = ForwardsUOffset<&'a str>` for a vector of strings. /// /// # Safety /// -/// The value of the corresponding slot must have type Vector -pub unsafe fn get_field_vector<'a, T: Follow<'a, Inner = T>>( +/// The value of the corresponding slot must be a vector of elements of type `T` which are stored +/// inline. +pub unsafe fn get_field_vector<'a, T: Follow<'a>>( table: &Table<'a>, field: &Field, -) -> FlatbufferResult>> { - if field.type_().base_type() != BaseType::Vector - || core::mem::size_of::() != get_type_size(field.type_().element()) - { - return Err(FlatbufferError::FieldTypeMismatch( - std::any::type_name::().to_string(), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); +) -> Vector<'a, T> { + debug_assert_eq!(field.type_().base_type(), BaseType::Vector); + if field.type_().element() != BaseType::Obj { + // Skip this check in the case that it is a vector of structs, because the struct element's + // size cannot be checked on older schema versions without access to the schema. + debug_assert_eq!( + core::mem::size_of::(), + get_type_size(field.type_().element()) + ); } - Ok(table.get::>>(field.offset(), Some(Vector::::default()))) + table + .get::>>(field.offset(), Some(Vector::::default())) + .unwrap() } -/// Gets a Table table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match. +/// Get a vector table field, whose elements have unknown type. +/// +/// Returns an empty vector if the field is not set. +/// +/// # Safety +/// +/// The value of the corresponding slot must be a vector of elements of type `T`. +pub unsafe fn get_field_vector_of_any<'a>(table: &Table<'a>, field: &Field) -> VectorOfAny<'a> { + debug_assert_eq!(field.type_().base_type(), BaseType::Vector); + table + .get::>>(field.offset(), Some(VectorOfAny::default())) + .unwrap() +} + +/// Gets a Table table field given its exact type. Returns [None] if the field is not set. /// /// # Safety /// /// The value of the corresponding slot must have type Table -pub unsafe fn get_field_table<'a>( - table: &Table<'a>, - field: &Field, -) -> FlatbufferResult>> { - if field.type_().base_type() != BaseType::Obj { - return Err(FlatbufferError::FieldTypeMismatch( - String::from("Obj"), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } +pub unsafe fn get_field_table<'a>(table: &Table<'a>, field: &Field) -> Option> { + debug_assert_eq!(field.type_().base_type(), BaseType::Obj); - Ok(table.get::>>(field.offset(), None)) + table.get::>>(field.offset(), None) } /// Returns the value of any table field as a 64-bit int, regardless of what type it is. Returns default integer if the field is not set or error if the value cannot be parsed as integer. @@ -273,25 +253,12 @@ pub unsafe fn get_any_field_string(table: &Table, field: &Field, schema: &Schema /// # Safety /// /// The value of the corresponding slot must have type Struct. -pub unsafe fn get_field_struct_in_struct<'a>( - st: &Struct<'a>, - field: &Field, -) -> FlatbufferResult> { +pub unsafe fn get_field_struct_in_struct<'a>(st: &Struct<'a>, field: &Field) -> Struct<'a> { // TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need // access to the schema to check the is_struct flag. - if field.type_().base_type() != BaseType::Obj { - return Err(FlatbufferError::FieldTypeMismatch( - String::from("Obj"), - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); - } + debug_assert_eq!(field.type_().base_type(), BaseType::Obj); - Ok(st.get::(field.offset() as usize)) + st.get::(field.offset() as usize) } /// Returns the value of any struct field as a 64-bit int, regardless of what type it is. Returns error if the value cannot be parsed as integer. @@ -576,7 +543,7 @@ pub unsafe fn set_string( } /// Returns the size of a scalar type in the `BaseType` enum. In the case of structs, returns the size of their offset (`UOffsetT`) in the buffer. -fn get_type_size(base_type: BaseType) -> usize { +pub fn get_type_size(base_type: BaseType) -> usize { match base_type { BaseType::UType | BaseType::Bool | BaseType::Byte | BaseType::UByte => 1, BaseType::Short | BaseType::UShort => 2, diff --git a/rust/reflection/src/reflection_verifier.rs b/rust/reflection/src/reflection_verifier.rs index 43739147783..575ba5713ef 100644 --- a/rust/reflection/src/reflection_verifier.rs +++ b/rust/reflection/src/reflection_verifier.rs @@ -28,10 +28,17 @@ pub fn verify_with_options( schema: &Schema, opts: &VerifierOptions, buf_loc_to_obj_idx: &mut HashMap, + root_table_name: Option<&str>, ) -> FlatbufferResult<()> { let mut verifier = Verifier::new(opts, buffer); - if let Some(table_object) = schema.root_table() { - if let core::result::Result::Ok(table_pos) = verifier.get_uoffset(0) { + let root_table = match root_table_name { + Some(name) => schema + .objects() + .lookup_by_key(name, |o, k| o.key_compare_with_value(k)), + None => schema.root_table(), + }; + if let Some(table_object) = root_table { + if let Ok(table_pos) = verifier.get_uoffset(0) { // Inserts -1 as object index for root table buf_loc_to_obj_idx.insert(table_pos.try_into()?, -1); let mut verified = vec![false; buffer.len()]; @@ -62,7 +69,8 @@ fn verify_table( let mut table_verifier = verifier.visit_table(table_pos)?; - for field in &table_object.fields() { + let fields = table_object.fields(); + for field in &fields { let field_name = field.name().to_owned(); table_verifier = match field.type_().base_type() { BaseType::UType | BaseType::UByte => { @@ -154,15 +162,8 @@ fn verify_table( table_verifier } } - _ => { - return Err(FlatbufferError::TypeNotSupported( - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )); + other => { + return Err(FlatbufferError::UnsupportedTableFieldType(other)); } }; } @@ -202,91 +203,30 @@ fn verify_vector<'a, 'b, 'c>( buf_loc_to_obj_idx: &mut HashMap, ) -> FlatbufferResult> { let field_name = field.name().to_owned(); + macro_rules! visit_vec { + ($t:ty) => { + table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError) + }; + } match field.type_().element() { - BaseType::UType | BaseType::UByte => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Bool => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Byte => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Short => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::UShort => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Int => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::UInt => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Long => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::ULong => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Float => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::Double => table_verifier - .visit_field::>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), - BaseType::String => table_verifier - .visit_field::>>>( - field_name, - field.offset(), - field.required(), - ) - .map_err(FlatbufferError::VerificationError), + BaseType::UType | BaseType::UByte => visit_vec!(u8), + BaseType::Bool => visit_vec!(bool), + BaseType::Byte => visit_vec!(i8), + BaseType::Short => visit_vec!(i16), + BaseType::UShort => visit_vec!(u16), + BaseType::Int => visit_vec!(i32), + BaseType::UInt => visit_vec!(u32), + BaseType::Long => visit_vec!(i64), + BaseType::ULong => visit_vec!(u64), + BaseType::Float => visit_vec!(f32), + BaseType::Double => visit_vec!(f64), + BaseType::String => visit_vec!(ForwardsUOffset<&str>), BaseType::Obj => { if let Some(field_pos) = table_verifier.deref(field.offset())? { let verifier = table_verifier.verifier(); @@ -340,17 +280,154 @@ fn verify_vector<'a, 'b, 'c>( } Ok(table_verifier) } + BaseType::Union => { + verify_vector_of_unions(table_verifier, field, schema, verified, buf_loc_to_obj_idx) + } + other => { + return Err(FlatbufferError::UnsupportedVectorElementType(other)); + } + } +} + +fn verify_vector_of_unions<'a, 'b, 'c>( + mut table_verifier: TableVerifier<'a, 'b, 'c>, + field: &Field, + schema: &Schema, + verified: &mut [bool], + buf_loc_to_obj_idx: &mut HashMap, +) -> FlatbufferResult> { + let field_type = field.type_(); + // If the schema is valid, none of these asserts can fail. + debug_assert_eq!(field_type.base_type(), BaseType::Vector); + debug_assert_eq!(field_type.element(), BaseType::Union); + let child_enum_idx = field_type.index(); + let child_enum = schema.enums().get(child_enum_idx.try_into()?); + debug_assert!(!child_enum.values().is_empty()); + + // Assuming the schema is valid, the previous field must be the enum vector, which consists of + // of 1-byte enums. + let enum_field_offset = field + .offset() + .checked_sub(u16::try_from(SIZE_VOFFSET).unwrap()) + .ok_or(FlatbufferError::InvalidUnionEnum)?; + + // Either both vectors must be present, or both must be absent. + let (value_field_pos, enum_field_pos) = match ( + table_verifier.deref(field.offset())?, + table_verifier.deref(enum_field_offset)?, + ) { + (Some(value_field_pos), Some(enum_field_pos)) => (value_field_pos, enum_field_pos), + (None, None) => { + if field.required() { + return InvalidFlatbuffer::new_missing_required(field.name().to_owned())?; + } else { + return Ok(table_verifier); + } + } _ => { - return Err(FlatbufferError::TypeNotSupported( - field - .type_() - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )) + return InvalidFlatbuffer::new_inconsistent_union( + format!("{}_type", field.name()), + field.name().to_owned(), + )?; } + }; + + let verifier = table_verifier.verifier(); + let enum_vector_offset = verifier.get_uoffset(enum_field_pos)?; + let enum_vector_pos = enum_field_pos.saturating_add(enum_vector_offset.try_into()?); + let enum_vector_len = verifier.get_uoffset(enum_vector_pos)?; + let enum_vector_start = enum_vector_pos.saturating_add(SIZE_UOFFSET); + + let value_vector_offset = verifier.get_uoffset(value_field_pos)?; + let value_vector_pos = value_field_pos.saturating_add(value_vector_offset.try_into()?); + let value_vector_len = verifier.get_uoffset(value_vector_pos)?; + let value_vector_start = value_vector_pos.saturating_add(SIZE_UOFFSET); + + // Both vectors should have the same length. + // The C++ verifier instead assumes that the length of the value vector is the length of the enum vector: + // https://github.com/google/flatbuffers/blob/bd1b2d0bafb8be6059a29487db9e5ace5c32914d/src/reflection.cpp#L147-L162 + // This has been reported at https://github.com/google/flatbuffers/issues/8567 + if enum_vector_len != value_vector_len { + return InvalidFlatbuffer::new_inconsistent_union( + format!("{}_type", field.name()), + field.name().to_owned(), + )?; } + + // Regardless of its contents, the value vector in a vector of unions must be a vector of + // offsets. Source: https://github.com/dvidelabs/flatcc/blob/master/doc/binary-format.md#unions + verifier.is_aligned::(value_vector_start)?; + let value_vector_size = value_vector_len.saturating_mul(SIZE_UOFFSET.try_into()?); + verifier.range_in_buffer(value_vector_start, value_vector_size.try_into()?)?; + let value_vector_range = core::ops::Range { + start: value_vector_start, + end: value_vector_start.saturating_add(value_vector_size.try_into()?), + }; + + // The enums must have a size of 1 byte, so we just use the length of the vector. + let enum_vector_size = enum_vector_len; + verifier.range_in_buffer(enum_vector_start, enum_vector_size.try_into()?)?; + let enum_vector_range = core::ops::Range { + start: enum_vector_start, + end: enum_vector_start.saturating_add(enum_vector_size.try_into()?), + }; + + let enum_values = child_enum.values(); + + for (enum_pos, union_offset_pos) in + enum_vector_range.zip(value_vector_range.step_by(SIZE_UOFFSET)) + { + let enum_value = verifier.get_u8(enum_pos)?; + if enum_value == 0 { + // Discriminator is NONE. This should never happen: the C++ implementation forbids it. + // For example, the C++ JSON parser forbids the type entry to be NONE for a vector of + // unions: in + // https://github.com/google/flatbuffers/blob/bd1b2d0bafb8be6059a29487db9e5ace5c32914d/src/idl_parser.cpp#L1383 + // the second argument is 'true', meaning that the NONE entry is skipped for the reverse + // lookup. + // + // This should possibly be an error, but we can be forgiving and just ignore the + // corresponding union entry entirely. + continue; + } + + let enum_value: usize = enum_value.into(); + let enum_type = if enum_value < enum_values.len() { + enum_values.get(enum_value).union_type().expect( + "Schema verification should have checked that every union enum value has a type", + ) + } else { + return Err(FlatbufferError::InvalidUnionEnum); + }; + let union_pos = + union_offset_pos.saturating_add(verifier.get_uoffset(union_offset_pos)?.try_into()?); + verifier.in_buffer::(union_pos)?; + + match enum_type.base_type() { + BaseType::String => <&str>::run_verifier(verifier, union_pos)?, + BaseType::Obj => { + let child_obj = schema.objects().get(enum_type.index().try_into()?); + buf_loc_to_obj_idx.insert(union_pos, enum_type.index()); + if child_obj.is_struct() { + verify_struct(verifier, &child_obj, union_pos, schema, buf_loc_to_obj_idx)? + } else { + verify_table( + verifier, + &child_obj, + union_pos, + schema, + verified, + buf_loc_to_obj_idx, + )?; + } + } + other => { + return Err(FlatbufferError::UnsupportedUnionElementType(other)); + } + } + verified[union_pos] = true; + } + Ok(table_verifier) } fn verify_union<'a, 'b, 'c>( @@ -399,14 +476,8 @@ fn verify_union<'a, 'b, 'c>( )?; } } - _ => { - return Err(FlatbufferError::TypeNotSupported( - enum_type - .base_type() - .variant_name() - .unwrap_or_default() - .to_string(), - )) + other => { + return Err(FlatbufferError::UnsupportedUnionElementType(other)); } } } else { diff --git a/rust/reflection/src/safe_buffer.rs b/rust/reflection/src/safe_buffer.rs index ba61e87301f..66abe0c43fd 100644 --- a/rust/reflection/src/safe_buffer.rs +++ b/rust/reflection/src/safe_buffer.rs @@ -15,13 +15,14 @@ */ use crate::r#struct::Struct; +use crate::reflection::BaseType; use crate::reflection_generated::reflection::{Field, Schema}; use crate::reflection_verifier::verify_with_options; use crate::{ get_any_field_float, get_any_field_float_in_struct, get_any_field_integer, get_any_field_integer_in_struct, get_any_field_string, get_any_field_string_in_struct, get_any_root, get_field_float, get_field_integer, get_field_string, get_field_struct, - get_field_struct_in_struct, get_field_table, get_field_vector, FlatbufferError, + get_field_struct_in_struct, get_field_table, get_field_vector, get_type_size, FlatbufferError, FlatbufferResult, ForwardsUOffset, }; use flatbuffers::{Follow, Table, Vector, VerifierOptions}; @@ -48,7 +49,7 @@ impl<'a> SafeBuffer<'a> { opts: &VerifierOptions, ) -> FlatbufferResult { let mut buf_loc_to_obj_idx = HashMap::new(); - verify_with_options(&buf, schema, opts, &mut buf_loc_to_obj_idx)?; + verify_with_options(&buf, schema, opts, &mut buf_loc_to_obj_idx, None)?; Ok(SafeBuffer { buf, schema, @@ -102,15 +103,18 @@ pub struct SafeTable<'a> { impl<'a> SafeTable<'a> { /// Gets an integer table field given its exact type. Returns default integer value if the field is not set. Returns [None] if no default value is found. Returns error if /// the table doesn't match the buffer or - /// the [field_name] doesn't match the table or - /// the field type doesn't match. + /// the [field_name] doesn't match the table. pub fn get_field_integer Follow<'b, Inner = T> + PrimInt + FromPrimitive>( &self, field_name: &str, ) -> FlatbufferResult> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { // SAFETY: the buffer was verified during construction. - unsafe { get_field_integer::(&Table::new(&self.safe_buf.buf, self.loc), &field) } + Ok( + unsafe { + get_field_integer::(&Table::new(&self.safe_buf.buf, self.loc), &field) + }, + ) } else { Err(FlatbufferError::FieldNotFound) } @@ -118,15 +122,14 @@ impl<'a> SafeTable<'a> { /// Gets a floating point table field given its exact type. Returns default float value if the field is not set. Returns [None] if no default value is found. Returns error if /// the table doesn't match the buffer or - /// the [field_name] doesn't match the table or - /// the field type doesn't match. + /// the [field_name] doesn't match the table. pub fn get_field_float Follow<'b, Inner = T> + Float>( &self, field_name: &str, ) -> FlatbufferResult> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { // SAFETY: the buffer was verified during construction. - unsafe { get_field_float::(&Table::new(&self.safe_buf.buf, self.loc), &field) } + Ok(unsafe { get_field_float::(&Table::new(&self.safe_buf.buf, self.loc), &field) }) } else { Err(FlatbufferError::FieldNotFound) } @@ -134,12 +137,13 @@ impl<'a> SafeTable<'a> { /// Gets a String table field given its exact type. Returns empty string if the field is not set. Returns [None] if no default value is found. Returns error if /// the table doesn't match the buffer or - /// the [field_name] doesn't match the table or - /// the field type doesn't match. + /// the [field_name] doesn't match the table. pub fn get_field_string(&self, field_name: &str) -> FlatbufferResult> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { // SAFETY: the buffer was verified during construction. - unsafe { get_field_string(&Table::new(&self.safe_buf.buf, self.loc), &field) } + Ok(Some(unsafe { + get_field_string(&Table::new(&self.safe_buf.buf, self.loc), &field) + })) } else { Err(FlatbufferError::FieldNotFound) } @@ -147,13 +151,12 @@ impl<'a> SafeTable<'a> { /// Gets a [SafeStruct] table field given its exact type. Returns [None] if the field is not set. Returns error if /// the table doesn't match the buffer or - /// the [field_name] doesn't match the table or - /// the field type doesn't match. + /// the [field_name] doesn't match the table. pub fn get_field_struct(&self, field_name: &str) -> FlatbufferResult>> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { // SAFETY: the buffer was verified during construction. let optional_st = - unsafe { get_field_struct(&Table::new(&self.safe_buf.buf, self.loc), &field)? }; + unsafe { get_field_struct(&Table::new(&self.safe_buf.buf, self.loc), &field) }; Ok(optional_st.map(|st| SafeStruct { safe_buf: self.safe_buf, loc: st.loc(), @@ -167,13 +170,26 @@ impl<'a> SafeTable<'a> { /// the table doesn't match the buffer or /// the [field_name] doesn't match the table or /// the field type doesn't match. - pub fn get_field_vector>( + pub fn get_field_vector>( &self, field_name: &str, - ) -> FlatbufferResult>> { + ) -> FlatbufferResult> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { + if field.type_().base_type() != BaseType::Vector + || core::mem::size_of::() != get_type_size(field.type_().element()) + { + return Err(FlatbufferError::FieldTypeMismatch( + "Vector".to_string(), + field + .type_() + .base_type() + .variant_name() + .unwrap_or_default() + .to_string(), + )); + } // SAFETY: the buffer was verified during construction. - unsafe { get_field_vector(&Table::new(&self.safe_buf.buf, self.loc), &field) } + Ok(unsafe { get_field_vector(&Table::new(&self.safe_buf.buf, self.loc), &field) }) } else { Err(FlatbufferError::FieldNotFound) } @@ -181,13 +197,31 @@ impl<'a> SafeTable<'a> { /// Gets a [SafeTable] table field given its exact type. Returns [None] if the field is not set. Returns error if /// the table doesn't match the buffer or - /// the [field_name] doesn't match the table or - /// the field type doesn't match. + /// the [field_name] doesn't match the table. pub fn get_field_table(&self, field_name: &str) -> FlatbufferResult>> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { + if field.type_().base_type() != BaseType::Obj + || self + .safe_buf + .schema + .objects() + .get(field.type_().index() as usize) + .is_struct() + { + return Err(FlatbufferError::FieldTypeMismatch( + "Table".to_string(), + field + .type_() + .base_type() + .variant_name() + .unwrap_or_default() + .to_string(), + )); + } + // SAFETY: the buffer was verified during construction. let optional_table = - unsafe { get_field_table(&Table::new(&self.safe_buf.buf, self.loc), &field)? }; + unsafe { get_field_table(&Table::new(&self.safe_buf.buf, self.loc), &field) }; Ok(optional_table.map(|t| SafeTable { safe_buf: self.safe_buf, loc: t.loc(), @@ -252,13 +286,12 @@ pub struct SafeStruct<'a> { impl<'a> SafeStruct<'a> { /// Gets a [SafeStruct] struct field given its exact type. Returns error if /// the struct doesn't match the buffer or - /// the [field_name] doesn't match the struct or - /// the field type doesn't match. + /// the [field_name] doesn't match the struct. pub fn get_field_struct(&self, field_name: &str) -> FlatbufferResult> { if let Some(field) = self.safe_buf.find_field_by_name(self.loc, field_name)? { // SAFETY: the buffer was verified during construction. let st = unsafe { - get_field_struct_in_struct(&Struct::new(&self.safe_buf.buf, self.loc), &field)? + get_field_struct_in_struct(&Struct::new(&self.safe_buf.buf, self.loc), &field) }; Ok(SafeStruct { safe_buf: self.safe_buf, diff --git a/rust/reflection/src/struct.rs b/rust/reflection/src/struct.rs index 0952fafd4fa..3c6aab12b91 100644 --- a/rust/reflection/src/struct.rs +++ b/rust/reflection/src/struct.rs @@ -33,6 +33,11 @@ impl<'a> Struct<'a> { self.loc } + #[inline] + pub fn bytes(&self) -> &'a [u8] { + &self.buf[self.loc..] + } + /// # Safety /// /// [buf] must contain a valid struct at [loc] diff --git a/rust/reflection/src/vector_of_any.rs b/rust/reflection/src/vector_of_any.rs new file mode 100644 index 00000000000..b8e8eacf19d --- /dev/null +++ b/rust/reflection/src/vector_of_any.rs @@ -0,0 +1,76 @@ +use flatbuffers::{read_scalar_at, Follow, UOffsetT, Vector, SIZE_UOFFSET}; + +use crate::reflection::Object; + +pub struct VectorOfAny<'a> { + buf: &'a [u8], + loc: usize, +} + +impl<'a> VectorOfAny<'a> { + /// # Safety + /// + /// [buf] must contain a valid vector at [loc] + #[inline] + pub unsafe fn new(buf: &'a [u8], loc: usize) -> Self { + Self { buf, loc } + } + + #[inline(always)] + pub fn len(&self) -> usize { + // Safety: + // Valid vector at time of construction starting with UOffsetT element count + unsafe { read_scalar_at::(self.buf, self.loc) as usize } + } + + /// Get a slice of all the bytes from the start of the vector to the *end of the buffer*. + /// + /// We don't know the size of the elements in the vector, so we can't return just its contents. + #[inline(always)] + pub fn bytes(&self) -> &'a [u8] { + &self.buf[self.loc + SIZE_UOFFSET..] + } + + /// Get a slice of bytes corresponding to the struct at [index], assuming this is a vector of + /// structs of type [obj]. + /// + /// # Panics + /// + /// Panics if [index] is out of bounds or if [obj] is not a struct. + #[inline(always)] + pub fn get_struct(&self, index: usize, obj: Object<'_>) -> &'a [u8] { + assert!(obj.is_struct()); + assert!(index < self.len()); + let bytesize = obj.bytesize() as usize; + let start = self.loc + SIZE_UOFFSET + index * bytesize; + let end = start + bytesize; + assert!(end <= self.buf.len()); + &self.buf[start..end] + } + + /// # Safety + /// + /// [buf] must contain a valid vector at [loc] + pub unsafe fn as_vector>(&self) -> Vector { + Vector::new(self.buf, self.loc) + } +} + +impl<'a> Follow<'a> for VectorOfAny<'a> { + type Inner = VectorOfAny<'a>; + + #[inline(always)] + unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { + Self { buf, loc } + } +} + +impl Default for VectorOfAny<'_> { + fn default() -> Self { + // Need to include length prefix, even if it is just zero. + Self { + buf: &[0; core::mem::size_of::()], + loc: 0, + } + } +}