Skip to content

[Rust] Allow verification of buffers from specified table #8575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions rust/flatbuffers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ pub use crate::push::{Push, PushAlignment};
pub use crate::table::{buffer_has_identifier, Table};
pub use crate::vector::{follow_cast_ref, Vector, VectorIter};
pub use crate::verifier::{
ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable, Verifier,
VerifierOptions,
ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable,
Verifier, VerifierOptions,
};
pub use crate::vtable::field_index_to_field_offset;
pub use bitflags;
Expand Down
18 changes: 13 additions & 5 deletions rust/flatbuffers/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ impl<'a, T: 'a> Vector<'a, T> {
let len = self.len();
&self.0[self.1 + SIZE_UOFFSET..self.1 + SIZE_UOFFSET + sz * len]
}

/// # Safety
///
/// The underlying bytes must be interpretable as a vector of the *same number* of `U`'s.
#[inline(always)]
pub unsafe fn cast<U: 'a>(&self) -> Vector<'a, U> {
Vector::new(self.0, self.1)
}
}

impl<'a, T: Follow<'a> + 'a> Vector<'a, T> {
Expand Down Expand Up @@ -123,11 +131,11 @@ impl<'a, T: Follow<'a> + 'a> Vector<'a, T> {
Ordering::Equal => return Some(value),
Ordering::Less => left = mid + 1,
Ordering::Greater => {
if mid == 0 {
return None;
}
right = mid - 1;
},
if mid == 0 {
return None;
}
right = mid - 1;
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion rust/flatbuffers/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,10 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
)?;
Ok(self)
}
_ => InvalidFlatbuffer::new_inconsistent_union(key_field_name.into(), val_field_name.into()),
_ => InvalidFlatbuffer::new_inconsistent_union(
key_field_name.into(),
val_field_name.into(),
),
}
}
pub fn finish(self) -> &'ver mut Verifier<'opts, 'buf> {
Expand Down
195 changes: 81 additions & 114 deletions rust/reflection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

mod reflection_generated;
mod reflection_verifier;
mod safe_buffer;
pub mod safe_buffer;
mod vector_of_any;
pub use vector_of_any::VectorOfAny;
mod r#struct;
pub use crate::r#struct::Struct;
pub use crate::reflection_generated::reflection;
pub use crate::reflection_verifier::verify_with_options;
pub use crate::safe_buffer::SafeBuffer;

use flatbuffers::{
Expand Down Expand Up @@ -50,8 +53,12 @@ pub enum FlatbufferError {
SetStringPolluted,
#[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")]
InvalidSchema,
#[error("Type not supported: {0}")]
TypeNotSupported(String),
#[error("Unsupported table field type: {0:?}")]
UnsupportedTableFieldType(BaseType),
#[error("Unsupported vector element type: {0:?}")]
UnsupportedVectorElementType(BaseType),
#[error("Unsupported union element type: {0:?}")]
UnsupportedUnionElementType(BaseType),
#[error("No type or invalid type found in union enum")]
InvalidUnionEnum,
#[error("Table or Struct doesn't belong to the buffer")]
Expand Down Expand Up @@ -80,21 +87,17 @@ pub unsafe fn get_any_root(data: &[u8]) -> Table {
pub unsafe fn get_field_integer<T: for<'a> Follow<'a, Inner = T> + PrimInt + FromPrimitive>(
table: &Table,
field: &Field,
) -> FlatbufferResult<Option<T>> {
if size_of::<T>() != get_type_size(field.type_().base_type()) {
return Err(FlatbufferError::FieldTypeMismatch(
std::any::type_name::<T>().to_string(),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}
) -> Option<T> {
debug_assert_eq!(
size_of::<T>(),
get_type_size(field.type_().base_type()),
"Type size mismatch: {} vs {}",
std::any::type_name::<T>(),
field.type_().base_type().variant_name().unwrap_or_default()
);

let default = T::from_i64(field.default_integer());
Ok(table.get::<T>(field.offset(), default))
table.get::<T>(field.offset(), default)
}

/// Gets a floating point table field given its exact type. Returns default float value if the field is not set. Returns [None] if no default value is found. Returns error if the type doesn't match.
Expand All @@ -105,121 +108,98 @@ pub unsafe fn get_field_integer<T: for<'a> Follow<'a, Inner = T> + PrimInt + Fro
pub unsafe fn get_field_float<T: for<'a> Follow<'a, Inner = T> + Float>(
table: &Table,
field: &Field,
) -> FlatbufferResult<Option<T>> {
if size_of::<T>() != get_type_size(field.type_().base_type()) {
return Err(FlatbufferError::FieldTypeMismatch(
std::any::type_name::<T>().to_string(),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}
) -> Option<T> {
debug_assert_eq!(
size_of::<T>(),
get_type_size(field.type_().base_type()),
"Type size mismatch: {} vs {}",
std::any::type_name::<T>(),
field.type_().base_type().variant_name().unwrap_or_default()
);

let default = T::from(field.default_real());
Ok(table.get::<T>(field.offset(), default))
table.get::<T>(field.offset(), default)
}

/// Gets a String table field given its exact type. Returns empty string if the field is not set. Returns [None] if no default value is found. Returns error if the type size doesn't match.
///
/// # Safety
///
/// The value of the corresponding slot must have type String
pub unsafe fn get_field_string<'a>(
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<&'a str>> {
if field.type_().base_type() != BaseType::String {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("String"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}

Ok(table.get::<ForwardsUOffset<&'a str>>(field.offset(), Some("")))
pub unsafe fn get_field_string<'a>(table: &Table<'a>, field: &Field) -> &'a str {
debug_assert_eq!(field.type_().base_type(), BaseType::String);
table
.get::<ForwardsUOffset<&'a str>>(field.offset(), Some(""))
.unwrap()
}

/// Gets a [Struct] table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match.
/// Gets a [Struct] table field given its exact type. Returns [None] if the field is not set.
///
/// # Safety
///
/// The value of the corresponding slot must have type Struct
pub unsafe fn get_field_struct<'a>(
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Struct<'a>>> {
pub unsafe fn get_field_struct<'a>(table: &Table<'a>, field: &Field) -> Option<Struct<'a>> {
// TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
// access to the schema to check the is_struct flag.
if field.type_().base_type() != BaseType::Obj {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("Obj"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}
debug_assert_eq!(field.type_().base_type(), BaseType::Obj);

Ok(table.get::<Struct>(field.offset(), None))
table.get::<Struct>(field.offset(), None)
}

/// Gets a Vector table field given its exact type. Returns empty vector if the field is not set. Returns error if the type doesn't match.
/// Get a vector table field, whose elements have type `T`.
///
/// Returns an empty vector if the field is not set.
///
/// Non-scalar values such as tables or strings are not stored inline. In such cases, you should use
/// `ForwardsUOffset`. So for example, use `T = ForwardsUOffset<Table<'a>>` for a vector of tables,
/// or `T = ForwardsUOffset<&'a str>` for a vector of strings.
///
/// # Safety
///
/// The value of the corresponding slot must have type Vector
pub unsafe fn get_field_vector<'a, T: Follow<'a, Inner = T>>(
/// The value of the corresponding slot must be a vector of elements of type `T` which are stored
/// inline.
pub unsafe fn get_field_vector<'a, T: Follow<'a>>(
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Vector<'a, T>>> {
if field.type_().base_type() != BaseType::Vector
|| core::mem::size_of::<T>() != get_type_size(field.type_().element())
{
return Err(FlatbufferError::FieldTypeMismatch(
std::any::type_name::<T>().to_string(),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
) -> Vector<'a, T> {
debug_assert_eq!(field.type_().base_type(), BaseType::Vector);
if field.type_().element() != BaseType::Obj {
// Skip this check in the case that it is a vector of structs, because the struct element's
// size cannot be checked on older schema versions without access to the schema.
debug_assert_eq!(
core::mem::size_of::<T>(),
get_type_size(field.type_().element())
);
}

Ok(table.get::<ForwardsUOffset<Vector<'a, T>>>(field.offset(), Some(Vector::<T>::default())))
table
.get::<ForwardsUOffset<Vector<'a, T>>>(field.offset(), Some(Vector::<T>::default()))
.unwrap()
}

/// Gets a Table table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match.
/// Get a vector table field, whose elements have unknown type.
///
/// Returns an empty vector if the field is not set.
///
/// # Safety
///
/// The value of the corresponding slot must be a vector of elements of type `T`.
pub unsafe fn get_field_vector_of_any<'a>(table: &Table<'a>, field: &Field) -> VectorOfAny<'a> {
debug_assert_eq!(field.type_().base_type(), BaseType::Vector);
table
.get::<ForwardsUOffset<VectorOfAny<'a>>>(field.offset(), Some(VectorOfAny::default()))
.unwrap()
}

/// Gets a Table table field given its exact type. Returns [None] if the field is not set.
///
/// # Safety
///
/// The value of the corresponding slot must have type Table
pub unsafe fn get_field_table<'a>(
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Table<'a>>> {
if field.type_().base_type() != BaseType::Obj {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("Obj"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}
pub unsafe fn get_field_table<'a>(table: &Table<'a>, field: &Field) -> Option<Table<'a>> {
debug_assert_eq!(field.type_().base_type(), BaseType::Obj);

Ok(table.get::<ForwardsUOffset<Table<'a>>>(field.offset(), None))
table.get::<ForwardsUOffset<Table<'a>>>(field.offset(), None)
}

/// Returns the value of any table field as a 64-bit int, regardless of what type it is. Returns default integer if the field is not set or error if the value cannot be parsed as integer.
Expand Down Expand Up @@ -273,25 +253,12 @@ pub unsafe fn get_any_field_string(table: &Table, field: &Field, schema: &Schema
/// # Safety
///
/// The value of the corresponding slot must have type Struct.
pub unsafe fn get_field_struct_in_struct<'a>(
st: &Struct<'a>,
field: &Field,
) -> FlatbufferResult<Struct<'a>> {
pub unsafe fn get_field_struct_in_struct<'a>(st: &Struct<'a>, field: &Field) -> Struct<'a> {
// TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
// access to the schema to check the is_struct flag.
if field.type_().base_type() != BaseType::Obj {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("Obj"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}
debug_assert_eq!(field.type_().base_type(), BaseType::Obj);

Ok(st.get::<Struct>(field.offset() as usize))
st.get::<Struct>(field.offset() as usize)
}

/// Returns the value of any struct field as a 64-bit int, regardless of what type it is. Returns error if the value cannot be parsed as integer.
Expand Down Expand Up @@ -576,7 +543,7 @@ pub unsafe fn set_string(
}

/// Returns the size of a scalar type in the `BaseType` enum. In the case of structs, returns the size of their offset (`UOffsetT`) in the buffer.
fn get_type_size(base_type: BaseType) -> usize {
pub fn get_type_size(base_type: BaseType) -> usize {
match base_type {
BaseType::UType | BaseType::Bool | BaseType::Byte | BaseType::UByte => 1,
BaseType::Short | BaseType::UShort => 2,
Expand Down
Loading