Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

der_derive: add DecodeValue, EncodeValue macros #1722

Merged
merged 2 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion der/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ pub use crate::{
pub use crate::{asn1::Any, document::Document};

#[cfg(feature = "derive")]
pub use der_derive::{BitString, Choice, Enumerated, Sequence, ValueOrd};
pub use der_derive::{BitString, Choice, DecodeValue, EncodeValue, Enumerated, Sequence, ValueOrd};

#[cfg(feature = "flagset")]
pub use flagset;
Expand Down
50 changes: 50 additions & 0 deletions der/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,56 @@ mod sequence {
}
}

/// Custom derive test cases for the `EncodeValue` macro.
mod encode_value {
use der::{Encode, EncodeValue, FixedTag, Tag};
use hex_literal::hex;

#[derive(EncodeValue, Default, Eq, PartialEq, Debug)]
#[asn1(tag_mode = "IMPLICIT")]
pub struct EncodeOnlyCheck<'a> {
#[asn1(type = "OCTET STRING", context_specific = "5")]
pub field: &'a [u8],
}
impl FixedTag for EncodeOnlyCheck<'_> {
const TAG: Tag = Tag::Sequence;
}

#[test]
fn sequence_encode_only_to_der() {
let obj = EncodeOnlyCheck {
field: &[0x33, 0x44],
};

let der_encoded = obj.to_der().unwrap();

assert_eq!(der_encoded, hex!("30 04 85 02 33 44"));
}
}

/// Custom derive test cases for the `DecodeValue` macro.
mod decode_value {
use der::{Decode, DecodeValue, FixedTag, Tag};
use hex_literal::hex;

#[derive(DecodeValue, Default, Eq, PartialEq, Debug)]
#[asn1(tag_mode = "IMPLICIT")]
pub struct DecodeOnlyCheck<'a> {
#[asn1(type = "OCTET STRING", context_specific = "5")]
pub field: &'a [u8],
}
impl FixedTag for DecodeOnlyCheck<'_> {
const TAG: Tag = Tag::Sequence;
}

#[test]
fn sequence_decode_only_from_der() {
let obj = DecodeOnlyCheck::from_der(&hex!("30 04 85 02 33 44")).unwrap();

assert_eq!(obj.field, &[0x33, 0x44]);
}
}

/// Custom derive test cases for the `BitString` macro.
#[cfg(feature = "std")]
mod bitstring {
Expand Down
36 changes: 31 additions & 5 deletions der_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ pub fn derive_enumerated(input: TokenStream) -> TokenStream {
}
}

/// Derive the [`Sequence`][1] trait on a `struct`.
/// Derive the [`DecodeValue`][1], [`EncodeValue`][2], [`Sequence`][3] traits on a `struct`.
///
/// This custom derive macro can be used to automatically impl the
/// `Sequence` trait for any struct which can be decoded/encoded as an
Expand Down Expand Up @@ -289,16 +289,42 @@ pub fn derive_enumerated(input: TokenStream) -> TokenStream {
///
/// # `#[asn1(type = "...")]` attribute
///
/// See [toplevel documentation for the `der_derive` crate][2] for more
/// See [toplevel documentation for the `der_derive` crate][4] for more
/// information about the `#[asn1]` attribute.
///
/// [1]: https://docs.rs/der/latest/der/trait.Sequence.html
/// [2]: https://docs.rs/der_derive/
/// [1]: https://docs.rs/der/latest/der/trait.DecodeValue.html
/// [2]: https://docs.rs/der/latest/der/trait.EncodeValue.html
/// [3]: https://docs.rs/der/latest/der/trait.Sequence.html
/// [4]: https://docs.rs/der_derive/
#[proc_macro_derive(Sequence, attributes(asn1))]
pub fn derive_sequence(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match DeriveSequence::new(input) {
Ok(t) => t.to_tokens().into(),
Ok(t) => t.to_tokens_all().into(),
Err(e) => e.to_compile_error().into(),
}
}

/// Derive the [`EncodeValue`][1] trait on a `struct`.
///
/// [1]: https://docs.rs/der/latest/der/trait.EncodeValue.html
#[proc_macro_derive(EncodeValue, attributes(asn1))]
pub fn derive_sequence_encode(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match DeriveSequence::new(input) {
Ok(t) => t.to_tokens_encode().into(),
Err(e) => e.to_compile_error().into(),
}
}

/// Derive the [`DecodeValue`][1] trait on a `struct`.
///
/// [1]: https://docs.rs/der/latest/der/trait.DecodeValue.html
#[proc_macro_derive(DecodeValue, attributes(asn1))]
pub fn derive_sequence_decode(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match DeriveSequence::new(input) {
Ok(t) => t.to_tokens_decode().into(),
Err(e) => e.to_compile_error().into(),
}
}
Expand Down
79 changes: 63 additions & 16 deletions der_derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{ErrorType, TypeAttrs, default_lifetime};
use field::SequenceField;
use proc_macro2::TokenStream;
use quote::{ToTokens, quote};
use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
use syn::{DeriveInput, GenericParam, Generics, Ident, Lifetime, LifetimeParam};

/// Derive the `Sequence` trait for a struct
pub(crate) struct DeriveSequence {
Expand Down Expand Up @@ -51,13 +51,10 @@ impl DeriveSequence {
})
}

/// Lower the derived output into a [`TokenStream`].
pub fn to_tokens(&self) -> TokenStream {
let ident = &self.ident;
/// Use the first lifetime parameter as lifetime for Decode/Encode lifetime
/// if none found, add one.
fn calc_lifetime(&self) -> (Generics, Lifetime) {
let mut generics = self.generics.clone();

// Use the first lifetime parameter as lifetime for Decode/Encode lifetime
// if none found, add one.
let lifetime = generics
.lifetimes()
.next()
Expand All @@ -69,23 +66,39 @@ impl DeriveSequence {
.insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
lt
});

// We may or may not have inserted a lifetime.
(generics, lifetime)
}

/// Lower the derived output into a [`TokenStream`] for Sequence trait impl.
pub fn to_tokens_sequence_trait(&self) -> TokenStream {
let ident = &self.ident;

let (der_generics, lifetime) = self.calc_lifetime();

let (_, ty_generics, where_clause) = self.generics.split_for_impl();
let (impl_generics, _, _) = generics.split_for_impl();
let (impl_generics, _, _) = der_generics.split_for_impl();

quote! {
impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {}
}
}

/// Lower the derived output into a [`TokenStream`] for DecodeValue trait impl.
pub fn to_tokens_decode(&self) -> TokenStream {
let ident = &self.ident;

let (der_generics, lifetime) = self.calc_lifetime();

let (_, ty_generics, where_clause) = self.generics.split_for_impl();
let (impl_generics, _, _) = der_generics.split_for_impl();

let mut decode_body = Vec::new();
let mut decode_result = Vec::new();
let mut encoded_lengths = Vec::new();
let mut encode_fields = Vec::new();

for field in &self.fields {
decode_body.push(field.to_decode_tokens());
decode_result.push(&field.ident);

let field = field.to_encode_tokens();
encoded_lengths.push(quote!(#field.encoded_len()?));
encode_fields.push(quote!(#field.encode(writer)?;));
}

let error = self.error.to_token_stream();
Expand All @@ -109,6 +122,26 @@ impl DeriveSequence {
})
}
}
}
}

/// Lower the derived output into a [`TokenStream`] for EncodeValue trait impl.
pub fn to_tokens_encode(&self) -> TokenStream {
let ident = &self.ident;

let (_, ty_generics, where_clause) = self.generics.split_for_impl();
let (impl_generics, _, _) = self.generics.split_for_impl();

let mut encoded_lengths = Vec::new();
let mut encode_fields = Vec::new();

for field in &self.fields {
let field = field.to_encode_tokens();
encoded_lengths.push(quote!(#field.encoded_len()?));
encode_fields.push(quote!(#field.encode(writer)?;));
}

quote! {

impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
fn value_len(&self) -> ::der::Result<::der::Length> {
Expand All @@ -127,8 +160,22 @@ impl DeriveSequence {
Ok(())
}
}
}
}

impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {}
/// Lower the derived output into a [`TokenStream`] for trait impls:
/// - EncodeValue
/// - DecodeValue
/// - Sequence
pub fn to_tokens_all(&self) -> TokenStream {
let decode_tokens = self.to_tokens_decode();
let encode_tokens = self.to_tokens_encode();
let sequence_trait_tokens = self.to_tokens_sequence_trait();

quote! {
#decode_tokens
#encode_tokens
#sequence_trait_tokens
}
}
}
Expand Down