|  | 
|  | 1 | +use std::cell::Cell; | 
|  | 2 | +use std::ptr::{self, NonNull}; | 
|  | 3 | + | 
| 1 | 4 | use rustc_ast as ast; | 
| 2 | 5 | use rustc_ast::ptr::P; | 
| 3 | 6 | use rustc_ast::tokenstream::TokenStream; | 
|  | 7 | +use rustc_data_structures::svh::Svh; | 
| 4 | 8 | use rustc_errors::ErrorGuaranteed; | 
| 5 |  | -use rustc_middle::ty; | 
|  | 9 | +use rustc_middle::ty::{self, TyCtxt}; | 
| 6 | 10 | use rustc_parse::parser::{ForceCollect, Parser}; | 
| 7 | 11 | use rustc_session::Session; | 
| 8 | 12 | use rustc_session::config::ProcMacroExecutionStrategy; | 
| 9 |  | -use rustc_span::Span; | 
| 10 | 13 | use rustc_span::profiling::SpannedEventArgRecorder; | 
|  | 14 | +use rustc_span::{LocalExpnId, Span}; | 
| 11 | 15 | 
 | 
| 12 | 16 | use crate::base::{self, *}; | 
| 13 | 17 | use crate::{errors, proc_macro_server}; | 
| @@ -154,11 +158,9 @@ impl MultiItemModifier for DeriveProcMacro { | 
| 154 | 158 |             // well)? | 
| 155 | 159 |             if tcx.sess.opts.incremental.is_some() && tcx.sess.opts.unstable_opts.cache_proc_macros | 
| 156 | 160 |             { | 
| 157 |  | -                crate::derive_macro_expansion::enter_context((ecx, self.client), move || { | 
| 158 |  | -                    tcx.derive_macro_expansion(key).cloned() | 
| 159 |  | -                }) | 
|  | 161 | +                enter_context((ecx, self.client), move || tcx.derive_macro_expansion(key).cloned()) | 
| 160 | 162 |             } else { | 
| 161 |  | -                crate::derive_macro_expansion::provide_derive_macro_expansion(tcx, key).cloned() | 
|  | 163 | +                provide_derive_macro_expansion(tcx, key).cloned() | 
| 162 | 164 |             } | 
| 163 | 165 |         }); | 
| 164 | 166 |         let Ok(output) = res else { | 
| @@ -195,3 +197,94 @@ impl MultiItemModifier for DeriveProcMacro { | 
| 195 | 197 |         ExpandResult::Ready(items) | 
| 196 | 198 |     } | 
| 197 | 199 | } | 
|  | 200 | + | 
|  | 201 | +pub(super) fn provide_derive_macro_expansion<'tcx>( | 
|  | 202 | +    tcx: TyCtxt<'tcx>, | 
|  | 203 | +    key: (LocalExpnId, Svh, &'tcx TokenStream), | 
|  | 204 | +) -> Result<&'tcx TokenStream, ()> { | 
|  | 205 | +    let (invoc_id, _macro_crate_hash, input) = key; | 
|  | 206 | + | 
|  | 207 | +    with_context(|(ecx, client)| { | 
|  | 208 | +        let invoc_expn_data = invoc_id.expn_data(); | 
|  | 209 | +        let span = invoc_expn_data.call_site; | 
|  | 210 | +        let event_arg = invoc_expn_data.kind.descr(); | 
|  | 211 | +        let _timer = tcx.sess.prof.generic_activity_with_arg_recorder( | 
|  | 212 | +            "expand_derive_proc_macro_inner", | 
|  | 213 | +            |recorder| { | 
|  | 214 | +                recorder.record_arg_with_span(tcx.sess.source_map(), event_arg.clone(), span); | 
|  | 215 | +            }, | 
|  | 216 | +        ); | 
|  | 217 | + | 
|  | 218 | +        let proc_macro_backtrace = ecx.ecfg.proc_macro_backtrace; | 
|  | 219 | +        let strategy = crate::proc_macro::exec_strategy(tcx.sess); | 
|  | 220 | +        let server = crate::proc_macro_server::Rustc::new(ecx); | 
|  | 221 | + | 
|  | 222 | +        match client.run(&strategy, server, input.clone(), proc_macro_backtrace) { | 
|  | 223 | +            Ok(stream) => Ok(tcx.arena.alloc(stream) as &TokenStream), | 
|  | 224 | +            Err(e) => { | 
|  | 225 | +                tcx.dcx().emit_err({ | 
|  | 226 | +                    errors::ProcMacroDerivePanicked { | 
|  | 227 | +                        span, | 
|  | 228 | +                        message: e.as_str().map(|message| errors::ProcMacroDerivePanickedHelp { | 
|  | 229 | +                            message: message.into(), | 
|  | 230 | +                        }), | 
|  | 231 | +                    } | 
|  | 232 | +                }); | 
|  | 233 | +                Err(()) | 
|  | 234 | +            } | 
|  | 235 | +        } | 
|  | 236 | +    }) | 
|  | 237 | +} | 
|  | 238 | + | 
|  | 239 | +type CLIENT = pm::bridge::client::Client<pm::TokenStream, pm::TokenStream>; | 
|  | 240 | + | 
|  | 241 | +// based on rust/compiler/rustc_middle/src/ty/context/tls.rs | 
|  | 242 | +thread_local! { | 
|  | 243 | +    /// A thread local variable that stores a pointer to the current `CONTEXT`. | 
|  | 244 | +    static TLV: Cell<(*mut (), Option<CLIENT>)> = const { Cell::new((ptr::null_mut(), None)) }; | 
|  | 245 | +} | 
|  | 246 | + | 
|  | 247 | +/// Sets `context` as the new current `CONTEXT` for the duration of the function `f`. | 
|  | 248 | +#[inline] | 
|  | 249 | +pub(crate) fn enter_context<'a, F, R>(context: (&mut ExtCtxt<'a>, CLIENT), f: F) -> R | 
|  | 250 | +where | 
|  | 251 | +    F: FnOnce() -> R, | 
|  | 252 | +{ | 
|  | 253 | +    let (ectx, client) = context; | 
|  | 254 | +    let erased = (ectx as *mut _ as *mut (), Some(client)); | 
|  | 255 | +    TLV.with(|tlv| { | 
|  | 256 | +        let old = tlv.replace(erased); | 
|  | 257 | +        let _reset = rustc_data_structures::defer(move || tlv.set(old)); | 
|  | 258 | +        f() | 
|  | 259 | +    }) | 
|  | 260 | +} | 
|  | 261 | + | 
|  | 262 | +/// Allows access to the current `CONTEXT`. | 
|  | 263 | +/// Panics if there is no `CONTEXT` available. | 
|  | 264 | +#[inline] | 
|  | 265 | +#[track_caller] | 
|  | 266 | +fn with_context<F, R>(f: F) -> R | 
|  | 267 | +where | 
|  | 268 | +    F: for<'a, 'b> FnOnce(&'b mut (&mut ExtCtxt<'a>, CLIENT)) -> R, | 
|  | 269 | +{ | 
|  | 270 | +    let (ectx, client_opt) = TLV.get(); | 
|  | 271 | +    let ectx = NonNull::new(ectx).expect("no CONTEXT stored in tls"); | 
|  | 272 | + | 
|  | 273 | +    // We could get an `CONTEXT` pointer from another thread. | 
|  | 274 | +    // Ensure that `CONTEXT` is `DynSync`. | 
|  | 275 | +    // FIXME(pr-time): we should not be able to? | 
|  | 276 | +    // sync::assert_dyn_sync::<CONTEXT<'_>>(); | 
|  | 277 | + | 
|  | 278 | +    // prevent double entering, as that would allow creating two `&mut ExtCtxt`s | 
|  | 279 | +    // FIXME(pr-time): probably use a RefCell instead (which checks this properly)? | 
|  | 280 | +    TLV.with(|tlv| { | 
|  | 281 | +        let old = tlv.replace((ptr::null_mut(), None)); | 
|  | 282 | +        let _reset = rustc_data_structures::defer(move || tlv.set(old)); | 
|  | 283 | +        let ectx = { | 
|  | 284 | +            let mut casted = ectx.cast::<ExtCtxt<'_>>(); | 
|  | 285 | +            unsafe { casted.as_mut() } | 
|  | 286 | +        }; | 
|  | 287 | + | 
|  | 288 | +        f(&mut (ectx, client_opt.unwrap())) | 
|  | 289 | +    }) | 
|  | 290 | +} | 
0 commit comments