Skip to content

Commit 09111e6

Browse files
committed
add #[default_error(<type>)] attribute to initializer macros
The `#[default_error(<type>)]` attribute macro can be used to supply a default type as the error used for the `[pin_]init!` macros. This way one can easily define custom `try_[pin_]init!` variants that default to your project specific error type. Just write the following declarative macro: macro_rules! try_init { ($($args:tt)*) => { ::pin_init::init!( #[default_error(YourCustomErrorType)] $($args)* ) } } Signed-off-by: Benno Lossin <[email protected]>
1 parent e0ae072 commit 09111e6

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

internal/src/init.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ use syn::{
66
parse_quote,
77
punctuated::Punctuated,
88
spanned::Spanned,
9-
token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
9+
token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
1010
};
1111

1212
pub struct Initializer {
13+
attrs: Vec<InitializerAttribute>,
1314
this: Option<This>,
1415
path: Path,
1516
brace_token: token::Brace,
@@ -50,23 +51,44 @@ impl InitializerField {
5051
}
5152
}
5253

54+
enum InitializerAttribute {
55+
DefaultError(DefaultErrorAttribute),
56+
}
57+
58+
struct DefaultErrorAttribute {
59+
ty: Type,
60+
}
61+
5362
pub(crate) fn expand(
5463
Initializer {
64+
attrs,
5565
this,
5666
path,
5767
brace_token,
5868
fields,
5969
rest,
60-
mut error,
70+
error,
6171
}: Initializer,
6272
default_error: Option<&'static str>,
6373
pinned: bool,
6474
) -> TokenStream {
6575
let mut errors = TokenStream::new();
76+
let mut error = error.map(|(_, err)| err);
77+
if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
78+
#[expect(irrefutable_let_patterns)]
79+
if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
80+
Some(ty.clone())
81+
} else {
82+
acc
83+
}
84+
}) {
85+
error.get_or_insert(default_error);
86+
}
6687
if let Some(default_error) = default_error {
67-
error.get_or_insert((Default::default(), syn::parse_str(default_error).unwrap()));
88+
error.get_or_insert(syn::parse_str(default_error).unwrap());
6889
}
69-
let error = error.map(|(_, err)| err).unwrap_or_else(|| {
90+
91+
let error = error.unwrap_or_else(|| {
7092
errors.extend(quote_spanned!(brace_token.span.close()=>
7193
::core::compile_error!("expected `? <type>` after `}`");
7294
));
@@ -346,6 +368,7 @@ fn make_field_check(
346368

347369
impl Parse for Initializer {
348370
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
371+
let attrs = input.call(Attribute::parse_outer)?;
349372
let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
350373
let path = input.parse()?;
351374
let content;
@@ -377,7 +400,19 @@ impl Parse for Initializer {
377400
.peek(Token![?])
378401
.then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
379402
.transpose()?;
403+
let attrs = attrs
404+
.into_iter()
405+
.map(|a| {
406+
if a.path().is_ident("default_error") {
407+
a.parse_args::<DefaultErrorAttribute>()
408+
.map(InitializerAttribute::DefaultError)
409+
} else {
410+
Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
411+
}
412+
})
413+
.collect::<Result<Vec<_>, _>>()?;
380414
Ok(Self {
415+
attrs,
381416
this,
382417
path,
383418
brace_token,
@@ -388,6 +423,16 @@ impl Parse for Initializer {
388423
}
389424
}
390425

426+
impl Parse for DefaultErrorAttribute {
427+
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
428+
let ty = input.parse()?;
429+
if !input.peek(End) {
430+
return Err(input.error("expected end of input"));
431+
}
432+
Ok(Self { ty })
433+
}
434+
}
435+
391436
impl Parse for This {
392437
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
393438
Ok(Self {

tests/default_error.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#![allow(dead_code)]
2+
3+
use pin_init::{init, Init};
4+
5+
struct Foo {}
6+
7+
struct Error;
8+
9+
impl Foo {
10+
fn new() -> impl Init<Foo, Error> {
11+
init!(
12+
#[default_error(Error)]
13+
Foo {}
14+
)
15+
}
16+
}

0 commit comments

Comments
 (0)