Skip to content

Commit 37616f3

Browse files
committed
add #[default_error(<type>)] attribute to initializer macros
The `#[default_error(<type>)]` attribute macro can be used to supply a default type as the error used for the `[pin_]init!` macros. This way one can easily define custom `try_[pin_]init!` variants that default to your project specific error type. Just write the following declarative macro: macro_rules! try_init { ($($args:tt)*) => { ::pin_init::init!( #[default_error(YourCustomErrorType)] $($args)* ) } } Signed-off-by: Benno Lossin <[email protected]>
1 parent 49cea2d commit 37616f3

File tree

2 files changed

+68
-7
lines changed

2 files changed

+68
-7
lines changed

internal/src/init.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use syn::{
66
parse_quote,
77
punctuated::Punctuated,
88
spanned::Spanned,
9-
token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
9+
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
1010
};
1111

1212
pub struct Initializer {
13+
attrs: Vec<InitializerAttribute>,
1314
this: Option<This>,
1415
path: Path,
1516
brace_token: token::Brace,
@@ -50,8 +51,17 @@ impl InitializerField {
5051
}
5152
}
5253

54+
enum InitializerAttribute {
55+
DefaultError(DefaultErrorAttribute),
56+
}
57+
58+
struct DefaultErrorAttribute {
59+
ty: Type,
60+
}
61+
5362
pub(crate) fn expand(
5463
Initializer {
64+
attrs,
5565
this,
5666
path,
5767
brace_token,
@@ -66,12 +76,24 @@ pub(crate) fn expand(
6676
if let Some(default_error) = default_error {
6777
error.get_or_insert((Default::default(), syn::parse_str(default_error).unwrap()));
6878
}
69-
let error = error.map(|(_, err)| err).unwrap_or_else(|| {
70-
errors.extend(quote_spanned!(brace_token.span.close()=>
71-
::core::compile_error!("expected `? <type>` after `}`");
72-
));
73-
parse_quote!(::core::convert::Infallible)
74-
});
79+
80+
let error = attrs
81+
.iter()
82+
.fold(None, |acc, attr| {
83+
#[expect(irrefutable_let_patterns)]
84+
if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
85+
Some(ty.clone())
86+
} else {
87+
acc
88+
}
89+
})
90+
.or(error.map(|(_, err)| err))
91+
.unwrap_or_else(|| {
92+
errors.extend(quote_spanned!(brace_token.span.close()=>
93+
::core::compile_error!("expected `? <type>` after `}`");
94+
));
95+
parse_quote!(::core::convert::Infallible)
96+
});
7597
let slot = format_ident!("slot");
7698
let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
7799
(
@@ -346,6 +368,7 @@ fn make_field_check(
346368

347369
impl Parse for Initializer {
348370
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
371+
let attrs = input.call(Attribute::parse_outer)?;
349372
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
350373
let path = input.parse()?;
351374
let content;
@@ -377,7 +400,19 @@ impl Parse for Initializer {
377400
.peek(Token![?])
378401
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
379402
.transpose()?;
403+
let attrs = attrs
404+
.into_iter()
405+
.map(|a| {
406+
if a.path().is_ident("default_error") {
407+
a.parse_args::<DefaultErrorAttribute>()
408+
.map(InitializerAttribute::DefaultError)
409+
} else {
410+
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
411+
}
412+
})
413+
.collect::<Result<Vec<_>, _>>()?;
380414
Ok(Self {
415+
attrs,
381416
this,
382417
path,
383418
brace_token,
@@ -388,6 +423,16 @@ impl Parse for Initializer {
388423
}
389424
}
390425

426+
impl Parse for DefaultErrorAttribute {
427+
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
428+
let ty = input.parse()?;
429+
if !input.peek(End) {
430+
return Err(input.error("expected end of input"));
431+
}
432+
Ok(Self { ty })
433+
}
434+
}
435+
391436
impl Parse for This {
392437
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
393438
Ok(Self {

tests/default_error.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#![allow(dead_code)]
2+
3+
use pin_init::{init, Init};
4+
5+
struct Foo {}
6+
7+
struct Error;
8+
9+
impl Foo {
10+
fn new() -> impl Init<Foo, Error> {
11+
init!(
12+
#[default_error(Error)]
13+
Foo {}
14+
)
15+
}
16+
}

0 commit comments

Comments
 (0)