Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions conformance/failing_tests.txt
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
# TODO(tokio-rs/prost#2): prost doesn't preserve unknown fields.
Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput
Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput
13 changes: 13 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ impl<'b> CodeGenerator<'_, 'b> {
}
self.path.pop();
}
if let Some(unknown_fields) = &self.config().include_unknown_fields {
if let Some(field_name) = unknown_fields.get_first(&fq_message_name).cloned() {
self.append_unknown_field_set(&fq_message_name, &field_name);
}
}
self.path.pop();

self.path.push(8);
Expand Down Expand Up @@ -581,6 +586,14 @@ impl<'b> CodeGenerator<'_, 'b> {
));
}

/// Emits the struct member that stores a message's unknown fields.
///
/// Writes, in order: the `#[prost(unknown_fields)]` marker attribute, whatever
/// `append_field_attributes` emits for `fq_message_name` / `field_name`, and the
/// `pub <field_name>: ::prost::UnknownFieldList,` declaration.
fn append_unknown_field_set(&mut self, fq_message_name: &str, field_name: &str) {
    // Indent the attribute line too, so it lines up with the field declaration
    // written below instead of starting at column 0.
    self.push_indent();
    self.buf.push_str("#[prost(unknown_fields)]\n");
    self.append_field_attributes(fq_message_name, field_name);
    self.push_indent();
    self.buf
        .push_str(&format!("pub {}: ::prost::UnknownFieldList,\n", field_name));
}

fn append_oneof_field(
&mut self,
message_name: &str,
Expand Down
33 changes: 33 additions & 0 deletions prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct Config {
pub(crate) message_attributes: PathMap<String>,
pub(crate) enum_attributes: PathMap<String>,
pub(crate) field_attributes: PathMap<String>,
pub(crate) include_unknown_fields: Option<PathMap<String>>,
pub(crate) boxed: PathMap<()>,
pub(crate) prost_types: bool,
pub(crate) strip_enum_prefix: bool,
Expand Down Expand Up @@ -266,6 +267,36 @@ impl Config {
self
}

/// Preserve unknown fields for the matched message types.
///
/// # Arguments
///
/// **`path`** - path to a specific message, or to a package, which should preserve unknown
/// fields during deserialization.
///
/// **`field_name`** - the name of the field to place unknown fields in. A field with this
/// name and type `prost::UnknownFieldList` will be added to the generated struct.
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.include_unknown_fields(".my_messages.MyMessageType", "unknown_fields");
/// ```
pub fn include_unknown_fields<P, A>(&mut self, path: P, field_name: A) -> &mut Self
where
    P: AsRef<str>,
    A: AsRef<str>,
{
    // The map stays `None` until the first registration; create it on demand and
    // insert in a single expression.
    self.include_unknown_fields
        .get_or_insert_with(PathMap::default)
        .insert(path.as_ref().to_string(), field_name.as_ref().to_string());
    self
}

/// Add additional attribute to matched messages.
///
/// # Arguments
Expand Down Expand Up @@ -1202,6 +1233,7 @@ impl default::Default for Config {
message_attributes: PathMap::default(),
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
include_unknown_fields: None,
boxed: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
Expand Down Expand Up @@ -1234,6 +1266,7 @@ impl fmt::Debug for Config {
.field("bytes_type", &self.bytes_type)
.field("type_attributes", &self.type_attributes)
.field("field_attributes", &self.field_attributes)
.field("include_unknown_fields", &self.include_unknown_fields)
.field("prost_types", &self.prost_types)
.field("strip_enum_prefix", &self.strip_enum_prefix)
.field("out_dir", &self.out_dir)
Expand Down
7 changes: 7 additions & 0 deletions prost-build/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,13 @@ impl<'a> Context<'a> {
/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
// Unknown fields can potentially include an unbounded Bytes object, which
// cannot implement Copy
if let Some(unknown_fields) = &self.config().include_unknown_fields {
if unknown_fields.get_first(fq_message_name).is_some() {
return false;
}
};
self.message_graph
.get_message(fq_message_name)
.unwrap()
Expand Down
14 changes: 14 additions & 0 deletions prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod map;
mod message;
mod oneof;
mod scalar;
mod unknown;

use std::fmt;
use std::slice;
Expand All @@ -26,6 +27,8 @@ pub enum Field {
Oneof(oneof::Field),
/// A group field.
Group(group::Field),
/// A set of unknown message fields.
Unknown(unknown::Field),
}

impl Field {
Expand All @@ -48,6 +51,8 @@ impl Field {
Field::Oneof(field)
} else if let Some(field) = group::Field::new(&attrs, inferred_tag)? {
Field::Group(field)
} else if let Some(field) = unknown::Field::new(&attrs)? {
Field::Unknown(field)
} else {
bail!("no type attribute");
};
Expand Down Expand Up @@ -86,6 +91,7 @@ impl Field {
Field::Map(ref map) => vec![map.tag],
Field::Oneof(ref oneof) => oneof.tags.clone(),
Field::Group(ref group) => vec![group.tag],
Field::Unknown(_) => vec![],
}
}

Expand All @@ -97,6 +103,7 @@ impl Field {
Field::Map(ref map) => map.encode(prost_path, ident),
Field::Oneof(ref oneof) => oneof.encode(ident),
Field::Group(ref group) => group.encode(prost_path, ident),
Field::Unknown(ref unknown) => unknown.encode(ident),
}
}

Expand All @@ -109,6 +116,7 @@ impl Field {
Field::Map(ref map) => map.merge(prost_path, ident),
Field::Oneof(ref oneof) => oneof.merge(ident),
Field::Group(ref group) => group.merge(prost_path, ident),
Field::Unknown(ref unknown) => unknown.merge(ident),
}
}

Expand All @@ -120,6 +128,7 @@ impl Field {
Field::Message(ref msg) => msg.encoded_len(prost_path, ident),
Field::Oneof(ref oneof) => oneof.encoded_len(ident),
Field::Group(ref group) => group.encoded_len(prost_path, ident),
Field::Unknown(ref unknown) => unknown.encoded_len(ident),
}
}

Expand All @@ -131,6 +140,7 @@ impl Field {
Field::Map(ref map) => map.clear(ident),
Field::Oneof(ref oneof) => oneof.clear(ident),
Field::Group(ref group) => group.clear(ident),
Field::Unknown(ref unknown) => unknown.clear(ident),
}
}

Expand Down Expand Up @@ -173,6 +183,10 @@ impl Field {
_ => None,
}
}

pub fn is_unknown(&self) -> bool {
matches!(self, Field::Unknown(_))
}
}

#[derive(Clone, Copy, PartialEq, Eq)]
Expand Down
66 changes: 66 additions & 0 deletions prost-derive/src/field/unknown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
use syn::Meta;

use crate::field::{set_bool, word_attr};

/// Marker for a struct member annotated with `#[prost(unknown_fields)]`.
///
/// Carries no data: every expression it generates is a plain method call on the
/// member it is attached to.
#[derive(Clone)]
pub struct Field {}

impl Field {
    /// Recognizes the `unknown_fields` attribute among a member's `prost` attributes.
    ///
    /// Returns `Ok(None)` when `unknown_fields` is not present and `Ok(Some(_))` when it is
    /// the only attribute. Any other attribute alongside it is rejected with an error, and an
    /// error returned by `set_bool` (which is handed a duplicate-attribute message) is
    /// propagated.
    pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
        let mut unknown = false;
        let mut unknown_attrs = Vec::new();

        for attr in attrs {
            if word_attr("unknown_fields", attr) {
                // Name the attribute actually being parsed here, not `message`.
                set_bool(&mut unknown, "duplicate unknown_fields attribute")?;
            } else {
                unknown_attrs.push(attr);
            }
        }

        if !unknown {
            return Ok(None);
        }

        // An unknown-field set accepts no further attributes; report the leftovers.
        match unknown_attrs.as_slice() {
            [] => Ok(Some(Field {})),
            [extra] => bail!("unknown attribute for unknown field set: {:?}", extra),
            extras => bail!("unknown attributes for unknown field set: {:?}", extras),
        }
    }

    /// Generates `<ident>.encode_raw(buf)`; `buf` must be in scope where the tokens are spliced.
    pub fn encode(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.encode_raw(buf)
        }
    }

    /// Generates `<ident>.merge_field(tag, wire_type, buf, ctx)`; the four argument names must
    /// be in scope where the tokens are spliced.
    pub fn merge(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.merge_field(tag, wire_type, buf, ctx)
        }
    }

    /// Generates `<ident>.encoded_len()`.
    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.encoded_len()
        }
    }

    /// Generates `<ident>.clear()`.
    pub fn clear(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.clear()
        }
    }
}
32 changes: 29 additions & 3 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,17 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
// We want Debug to be in declaration order
let unsorted_fields = fields.clone();

// Sort the fields by tag number so that fields will be encoded in tag order.
// Sort the fields by tag number so that fields will be encoded in tag order,
// and unknown fields are encoded last.
// TODO: This encodes oneof fields in the position of their lowest tag,
// regardless of the currently occupied variant, is that consequential?
// See: https://protobuf.dev/programming-guides/encoding/#order
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
fields.sort_by_key(|(_, field)| {
(
field.is_unknown(),
field.tags().into_iter().min().unwrap_or(0),
)
});
let fields = fields;

if let Some(duplicate_tag) = fields
Expand All @@ -113,6 +119,9 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
.map(|(field_ident, field)| field.encode(&prost_path, quote!(self.#field_ident)));

let merge = fields.iter().map(|(field_ident, field)| {
if field.is_unknown() {
return quote!();
}
let merge = field.merge(&prost_path, quote!(value));
let tags = field.tags().into_iter().map(|tag| quote!(#tag));
let tags = Itertools::intersperse(tags, quote!(|));
Expand All @@ -127,6 +136,23 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
},
}
});
let merge_fallback = match fields.iter().find(|&(_, f)| f.is_unknown()) {
Some((field_ident, field)) => {
let merge = field.merge(&prost_path, quote!(value));
quote! {
_ => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
}
}
None => quote! {
_ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx),
},
};

let struct_name = if fields.is_empty() {
quote!()
Expand Down Expand Up @@ -192,7 +218,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#struct_name
match tag {
#(#merge)*
_ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx),
#merge_fallback
}
}

Expand Down
2 changes: 2 additions & 0 deletions prost/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod error;
mod message;
mod name;
mod types;
mod unknown;

#[doc(hidden)]
pub mod encoding;
Expand All @@ -23,6 +24,7 @@ pub use crate::encoding::length_delimiter::{
pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue};
pub use crate::message::Message;
pub use crate::name::Name;
pub use crate::unknown::{UnknownField, UnknownFieldIter, UnknownFieldList};

// See `encoding::DecodeContext` for more info.
// 100 is the default recursion limit in the C++ implementation.
Expand Down
Loading
Loading