diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs index 060a24b847..d94498deea 100644 --- a/sqlx-macros-core/src/query/mod.rs +++ b/sqlx-macros-core/src/query/mod.rs @@ -1,4 +1,6 @@ +use std::cell::RefCell; use std::collections::{hash_map, HashMap}; +use std::env::VarError; use std::path::{Path, PathBuf}; use std::sync::{Arc, LazyLock, Mutex}; use std::{fs, io}; @@ -109,61 +111,64 @@ impl Metadata { } } -static METADATA: LazyLock>> = LazyLock::new(Default::default); +static METADATA: LazyLock>>> = LazyLock::new(Default::default); +static CRATE_ENV_FILE_VARS: LazyLock>>> = + LazyLock::new(Default::default); + +thread_local! { + static CURRENT_CRATE_MANIFEST_DIR: RefCell = RefCell::new(PathBuf::new()); +} // If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't // reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946 -fn init_metadata(manifest_dir: &String) -> crate::Result { - let manifest_dir: PathBuf = manifest_dir.into(); +fn init_metadata(manifest_dir: &Path) -> crate::Result> { + let config = Config::try_from_crate_or_default()?; - let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir); + load_env(manifest_dir, &config); let offline = env("SQLX_OFFLINE") - .ok() - .or(offline) .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - let offline_dir = env("SQLX_OFFLINE_DIR").ok().or(offline_dir); - - let config = Config::try_from_crate_or_default()?; + let offline_dir = env("SQLX_OFFLINE_DIR").ok(); - let database_url = env(config.common.database_url_var()).ok().or(database_url); + let database_url = env(config.common.database_url_var()).ok(); - Ok(Metadata { - manifest_dir, + Ok(Arc::new(Metadata { + manifest_dir: manifest_dir.to_path_buf(), offline, database_url, offline_dir, config, workspace_root: Arc::new(Mutex::new(None)), - }) + })) } pub fn expand_input<'a>( input: QueryMacroInput, drivers: impl IntoIterator, ) -> crate::Result { - let manifest_dir = env("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set"); + // `CARGO_MANIFEST_DIR` can only be loaded from a real environment variable due to the filtering done + // by `load_env`, so the value of `CURRENT_CRATE_MANIFEST_DIR` does not matter here. + let manifest_dir = + PathBuf::from(env("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set")); + CURRENT_CRATE_MANIFEST_DIR.set(manifest_dir.clone()); - let mut metadata_lock = METADATA - .lock() - // Just reset the metadata on error - .unwrap_or_else(|poison_err| { - let mut guard = poison_err.into_inner(); - *guard = Default::default(); - guard - }); + let mut metadata_lock = METADATA.lock().unwrap(); let metadata = match metadata_lock.entry(manifest_dir) { - hash_map::Entry::Occupied(occupied) => occupied.into_mut(), + hash_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()), hash_map::Entry::Vacant(vacant) => { let metadata = init_metadata(vacant.key())?; - vacant.insert(metadata) + vacant.insert(Arc::clone(&metadata)); + metadata } }; - let data_source = match &metadata { + // Release the lock now so other expansions in other threads of this process can proceed concurrently. + drop(metadata_lock); + + let data_source = match &*metadata { Metadata { offline: false, database_url: Some(db_url), @@ -181,7 +186,7 @@ pub fn expand_input<'a>( ]; let Some(data_file_path) = dirs .iter() - .filter_map(|path| path(metadata)) + .filter_map(|path| path(&metadata)) .map(|path| path.join(&filename)) .find(|path| path.exists()) else { @@ -415,64 +420,107 @@ where Ok(ret_tokens) } -/// Get the value of an environment variable, telling the compiler about it if applicable. +/// Get the value of an environment variable for the current crate, telling the compiler about it if applicable. +/// +/// The current crate is determined by the `CURRENT_CRATE_MANIFEST_DIR` thread-local variable, which is assumed +/// to be set to match the crate whose macro is being expanded before this function is called. It is also assumed +/// that the expansion of this macro happens on a single thread. fn env(name: &str) -> Result { #[cfg(procmacro2_semver_exempt)] - { - proc_macro::tracked_env::var(name) - } - + let tracked_value = Some(proc_macro::tracked_env::var(name)); #[cfg(not(procmacro2_semver_exempt))] - { - std::env::var(name) + let tracked_value = None; + + match tracked_value.map_or_else(|| std::env::var(name), |var| var) { + Ok(v) => Ok(v), + Err(VarError::NotPresent) => CURRENT_CRATE_MANIFEST_DIR + .with_borrow(|manifest_dir| { + CRATE_ENV_FILE_VARS + .lock() + .unwrap() + .get(manifest_dir) + .cloned() + }) + .and_then(|env_file_vars| env_file_vars.get(name).cloned()) + .ok_or(VarError::NotPresent), + Err(e) => Err(e), } } -/// Get `DATABASE_URL`, `SQLX_OFFLINE` and `SQLX_OFFLINE_DIR` from the `.env`. -fn load_dot_env(manifest_dir: &Path) -> (Option, Option, Option) { - let mut env_path = manifest_dir.join(".env"); - - // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this, - // otherwise fallback to default dotenv file. - #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))] - let env_file = if env_path.exists() { - let res = dotenvy::from_path_iter(&env_path); - match res { - Ok(iter) => Some(iter), - Err(e) => panic!("failed to load environment from {env_path:?}, {e}"), - } +/// Load configuration environment variables from a `.env` file. If applicable, the compiler is +/// about the `.env` files they may come from. +fn load_env(manifest_dir: &Path, config: &Config) { + // A whitelist of environment variables to load from a `.env` file avoids + // such files overriding internal variables they should not (e.g., `CARGO`, + // `CARGO_MANIFEST_DIR`) when using the `env` function above. + let database_url_var = config.common.database_url_var(); + let loadable_vars = if database_url_var == "DATABASE_URL" { + &["DATABASE_URL", "SQLX_OFFLINE", "SQLX_OFFLINE_DIR"][..] } else { - #[allow(unused_assignments)] - { - env_path = PathBuf::from(".env"); - } - dotenvy::dotenv_iter().ok() + &[ + "DATABASE_URL", + "SQLX_OFFLINE", + "SQLX_OFFLINE_DIR", + database_url_var, + ] }; - let mut offline = None; - let mut database_url = None; - let mut offline_dir = None; + let (found_dotenv, candidate_dotenv_paths) = find_dotenv(manifest_dir); - if let Some(env_file) = env_file { - // tell the compiler to watch the `.env` for changes. - #[cfg(procmacro2_semver_exempt)] - if let Some(env_path) = env_path.to_str() { - proc_macro::tracked_path::path(env_path); + // Tell the compiler to watch the candidate `.env` paths for changes. It's important to + // watch them all, because there are several possible locations where a `.env` file + // might be read, and we want to react to changes in any of them. + #[cfg(procmacro2_semver_exempt)] + for path in &candidate_dotenv_paths { + if let Some(path) = path.to_str() { + proc_macro::tracked_path::path(path); } + } - for item in env_file { - let Ok((key, value)) = item else { - continue; - }; + let loaded_vars = found_dotenv + .then_some(candidate_dotenv_paths) + .iter() + .flatten() + .last() + .map(|dotenv_path| { + dotenvy::from_path_iter(dotenv_path) + .ok() + .into_iter() + .flatten() + .filter_map(|dotenv_var_result| match dotenv_var_result { + Ok((key, value)) + if loadable_vars.contains(&&*key) && std::env::var(&key).is_err() => + { + Some((key, value)) + } + _ => None, + }) + }) + .into_iter() + .flatten() + .collect(); - match key.as_str() { - "DATABASE_URL" => database_url = Some(value), - "SQLX_OFFLINE" => offline = Some(value), - "SQLX_OFFLINE_DIR" => offline_dir = Some(value), - _ => {} - }; + CRATE_ENV_FILE_VARS + .lock() + .unwrap() + .insert(manifest_dir.to_path_buf(), loaded_vars); +} + +fn find_dotenv(mut dir: &Path) -> (bool, Vec) { + let mut candidate_files = vec![]; + + loop { + candidate_files.push(dir.join(".env")); + let candidate_file = candidate_files.last().unwrap(); + + if candidate_file.is_file() { + return (true, candidate_files); } - } - (database_url, offline, offline_dir) + if let Some(parent) = dir.parent() { + dir = parent; + } else { + return (false, candidate_files); + } + } }