Skip to content

Commit 09a77f1

Browse files
committed
works!
1 parent 37c1144 commit 09a77f1

File tree

2 files changed

+63
-4
lines changed

2 files changed

+63
-4
lines changed

quic-rpc-derive/src/lib.rs

+61-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use syn::{
77
parse::{Parse, ParseStream},
88
parse_macro_input,
99
spanned::Spanned,
10-
Data, DeriveInput, Fields, Ident, Token, Type,
10+
Data, DeriveInput, Fields, Ident, Macro, Token, Type,
1111
};
1212

1313
const RPC: &str = "rpc";
@@ -52,10 +52,32 @@ fn generate_channels_impl(
5252
Ok(res)
5353
}
5454

55+
// Parse arguments in the format (MessageEnumName, ServiceType)
56+
struct MacroArgs {
57+
message_enum_name: Ident,
58+
service_name: Ident,
59+
}
60+
61+
impl Parse for MacroArgs {
62+
fn parse(input: ParseStream) -> syn::Result<Self> {
63+
let service_name: Ident = input.parse()?;
64+
let _: Token![,] = input.parse()?;
65+
let message_enum_name: Ident = input.parse()?;
66+
67+
Ok(MacroArgs {
68+
service_name,
69+
message_enum_name,
70+
})
71+
}
72+
}
73+
5574
#[proc_macro_attribute]
5675
pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
5776
let mut input = parse_macro_input!(item as DeriveInput);
58-
let service_name = parse_macro_input!(attr as Ident);
77+
let MacroArgs {
78+
service_name,
79+
message_enum_name,
80+
} = parse_macro_input!(attr as MacroArgs);
5981

6082
let input_span = input.span();
6183
let data_enum = match &mut input.data {
@@ -122,9 +144,46 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
122144
}
123145
}
124146

147+
let message_variants = data_enum
148+
.variants
149+
.iter()
150+
.map(|variant| {
151+
let variant_name = &variant.ident;
152+
153+
// Extract the inner type - we know it's already valid
154+
let inner_type = match &variant.fields {
155+
Fields::Unnamed(fields) => {
156+
if let Type::Path(type_path) = &fields.unnamed.first().unwrap().ty {
157+
if let Some(last_segment) = type_path.path.segments.last() {
158+
&last_segment.ident
159+
} else {
160+
&type_path.path.segments.first().unwrap().ident
161+
}
162+
} else {
163+
panic!("Unexpected type"); // Should never happen due to prior validation
164+
}
165+
}
166+
_ => panic!("Unexpected field type"), // Should never happen due to prior validation
167+
};
168+
169+
quote! {
170+
#variant_name(::quic_rpc::WithChannels<#inner_type, #service_name>)
171+
}
172+
})
173+
.collect::<Vec<_>>();
174+
175+
let message_enum = quote! {
176+
#[derive(derive_more::From)]
177+
enum #message_enum_name {
178+
#(#message_variants),*
179+
}
180+
};
181+
125182
let output = quote! {
126183
#input
127184

185+
#message_enum
186+
128187
#(#additional_items)*
129188
};
130189

quic-rpc-derive/tests/smoke.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ fn simple() {
3838
#[derive(Debug, Serialize, Deserialize)]
3939
struct Response4;
4040

41-
#[rpc_requests(Service)]
41+
#[rpc_requests(Service, RequestWithChannels)]
4242
#[derive(Debug, Serialize, Deserialize, derive_more::From, derive_more::TryInto)]
4343
enum Request {
44-
#[rpc(tx=NoSender)]
44+
#[rpc(tx=oneshot::Sender<()>)]
4545
Rpc(RpcRequest),
4646
#[rpc(tx=NoSender)]
4747
ServerStreaming(ServerStreamingRequest),

0 commit comments

Comments
 (0)