@@ -7,7 +7,7 @@ use syn::{
7
7
parse:: { Parse , ParseStream } ,
8
8
parse_macro_input,
9
9
spanned:: Spanned ,
10
- Data , DeriveInput , Fields , Ident , Token , Type ,
10
+ Data , DeriveInput , Fields , Ident , Macro , Token , Type ,
11
11
} ;
12
12
13
13
const RPC : & str = "rpc" ;
@@ -52,10 +52,32 @@ fn generate_channels_impl(
52
52
Ok ( res)
53
53
}
54
54
55
+ // Parse arguments in the format (MessageEnumName, ServiceType)
56
+ struct MacroArgs {
57
+ message_enum_name : Ident ,
58
+ service_name : Ident ,
59
+ }
60
+
61
+ impl Parse for MacroArgs {
62
+ fn parse ( input : ParseStream ) -> syn:: Result < Self > {
63
+ let service_name: Ident = input. parse ( ) ?;
64
+ let _: Token ! [ , ] = input. parse ( ) ?;
65
+ let message_enum_name: Ident = input. parse ( ) ?;
66
+
67
+ Ok ( MacroArgs {
68
+ service_name,
69
+ message_enum_name,
70
+ } )
71
+ }
72
+ }
73
+
55
74
#[ proc_macro_attribute]
56
75
pub fn rpc_requests ( attr : TokenStream , item : TokenStream ) -> TokenStream {
57
76
let mut input = parse_macro_input ! ( item as DeriveInput ) ;
58
- let service_name = parse_macro_input ! ( attr as Ident ) ;
77
+ let MacroArgs {
78
+ service_name,
79
+ message_enum_name,
80
+ } = parse_macro_input ! ( attr as MacroArgs ) ;
59
81
60
82
let input_span = input. span ( ) ;
61
83
let data_enum = match & mut input. data {
@@ -122,9 +144,46 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream {
122
144
}
123
145
}
124
146
147
+ let message_variants = data_enum
148
+ . variants
149
+ . iter ( )
150
+ . map ( |variant| {
151
+ let variant_name = & variant. ident ;
152
+
153
+ // Extract the inner type - we know it's already valid
154
+ let inner_type = match & variant. fields {
155
+ Fields :: Unnamed ( fields) => {
156
+ if let Type :: Path ( type_path) = & fields. unnamed . first ( ) . unwrap ( ) . ty {
157
+ if let Some ( last_segment) = type_path. path . segments . last ( ) {
158
+ & last_segment. ident
159
+ } else {
160
+ & type_path. path . segments . first ( ) . unwrap ( ) . ident
161
+ }
162
+ } else {
163
+ panic ! ( "Unexpected type" ) ; // Should never happen due to prior validation
164
+ }
165
+ }
166
+ _ => panic ! ( "Unexpected field type" ) , // Should never happen due to prior validation
167
+ } ;
168
+
169
+ quote ! {
170
+ #variant_name( :: quic_rpc:: WithChannels <#inner_type, #service_name>)
171
+ }
172
+ } )
173
+ . collect :: < Vec < _ > > ( ) ;
174
+
175
+ let message_enum = quote ! {
176
+ #[ derive( derive_more:: From ) ]
177
+ enum #message_enum_name {
178
+ #( #message_variants) , *
179
+ }
180
+ } ;
181
+
125
182
let output = quote ! {
126
183
#input
127
184
185
+ #message_enum
186
+
128
187
#( #additional_items) *
129
188
} ;
130
189
0 commit comments