diff --git a/Cargo.toml b/Cargo.toml index 74b74d39..12736003 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,10 +53,10 @@ sqlx = { version = "0.8.1", features = ["runtime-tokio", "tls-rustls", "postgres chrono = "0.4" itertools = "0.13.0" +derive_builder = "0.20.0" [dev-dependencies] insta = "1.26" -derive_builder = "0.20.0" wiremock = "0.6.0" base64 = "0.22.1" tracing-test = "0.2.4" diff --git a/src/bin/bors.rs b/src/bin/bors.rs index 28637670..5531aaa4 100644 --- a/src/bin/bors.rs +++ b/src/bin/bors.rs @@ -6,8 +6,9 @@ use std::time::Duration; use anyhow::Context; use bors::{ - create_app, create_bors_process, create_github_client, load_repositories, BorsContext, - BorsGlobalEvent, CommandParser, PgDbClient, ServerState, TeamApiClient, WebhookSecret, + create_app, create_bors_process, create_github_client, create_github_client_from_access_token, + load_repositories, BorsContextBuilder, BorsGlobalEvent, CommandParser, PgDbClient, ServerState, + TeamApiClient, WebhookSecret, }; use clap::Parser; use sqlx::postgres::PgConnectOptions; @@ -18,7 +19,9 @@ use tracing_subscriber::filter::EnvFilter; /// How often should the bot check DB state, e.g. for handling timeouts. const PERIODIC_REFRESH: Duration = Duration::from_secs(120); -#[derive(clap::Parser)] +const GITHUB_API_URL: &str = "https://api.github.com"; + +#[derive(Parser)] struct Opts { /// Github App ID. #[arg(long, env = "APP_ID")] @@ -39,6 +42,10 @@ struct Opts { /// Prefix used for bot commands in PR comments. #[arg(long, env = "CMD_PREFIX", default_value = "@bors")] cmd_prefix: String, + + /// Prefix used for bot commands in PR comments. + #[arg(long, env = "CI_ACCESS_TOKEN")] + ci_access_token: Option, } /// Starts a server that receives GitHub webhooks and generates events into a queue @@ -81,15 +88,30 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { let db = runtime .block_on(initialize_db(&opts.db)) .context("Cannot initialize database")?; - let team_api = TeamApiClient::default(); - let (client, loaded_repos) = runtime.block_on(async { - let client = create_github_client( + let team_api_client = TeamApiClient::default(); + let gh_app_client = runtime.block_on(async { + create_github_client( opts.app_id.into(), - "https://api.github.com".to_string(), + GITHUB_API_URL.to_string(), opts.private_key.into(), - )?; - let repos = load_repositories(&client, &team_api).await?; - Ok::<_, anyhow::Error>((client, repos)) + ) + })?; + let ci_client = match opts.ci_access_token { + Some(access_token) => { + let client = runtime.block_on(async { + tracing::warn!("creating client ci"); + create_github_client_from_access_token( + GITHUB_API_URL.to_string(), + access_token.into(), + ) + })?; + Some(client) + } + None => None, + }; + let loaded_repos = runtime.block_on(async { + let repos = load_repositories(&gh_app_client, &team_api_client).await?; + Ok::<_, anyhow::Error>(repos) })?; let mut repos = HashMap::default(); @@ -108,8 +130,16 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { repos.insert(name, Arc::new(repo)); } - let ctx = BorsContext::new(CommandParser::new(opts.cmd_prefix), Arc::new(db), repos); - let (repository_tx, global_tx, bors_process) = create_bors_process(ctx, client, team_api); + let ctx = BorsContextBuilder::default() + .parser(CommandParser::new(opts.cmd_prefix)) + .db(Arc::new(db)) + .repositories(repos) + .gh_app_client(gh_app_client) + .ci_client(ci_client) + .team_api_client(team_api_client) + .build() + .unwrap(); + let (repository_tx, global_tx, bors_process) = create_bors_process(ctx); let refresh_tx = global_tx.clone(); let refresh_process = async move { diff --git a/src/bors/command/parser.rs b/src/bors/command/parser.rs index e922f46d..ea37ff4d 100644 --- a/src/bors/command/parser.rs +++ b/src/bors/command/parser.rs @@ -22,6 +22,7 @@ enum CommandPart<'a> { KeyValue { key: &'a str, value: &'a str }, } +#[derive(Clone)] pub struct CommandParser { prefix: String, } diff --git a/src/bors/context.rs b/src/bors/context.rs index b3a12382..7eb80d15 100644 --- a/src/bors/context.rs +++ b/src/bors/context.rs @@ -3,27 +3,24 @@ use std::{ sync::{Arc, RwLock}, }; -use crate::{bors::command::CommandParser, github::GithubRepoName, PgDbClient}; +use derive_builder::Builder; +use octocrab::Octocrab; + +use crate::{bors::command::CommandParser, github::GithubRepoName, PgDbClient, TeamApiClient}; use super::RepositoryState; +#[derive(Builder)] pub struct BorsContext { pub parser: CommandParser, pub db: Arc, + #[builder(field( + ty = "HashMap>", + build = "RwLock::new(self.repositories.clone())" + ))] pub repositories: RwLock>>, -} - -impl BorsContext { - pub fn new( - parser: CommandParser, - db: Arc, - repositories: HashMap>, - ) -> Self { - let repositories = RwLock::new(repositories); - Self { - parser, - db, - repositories, - } - } + pub gh_app_client: Octocrab, + #[builder(default)] + pub ci_client: Option, + pub team_api_client: TeamApiClient, } diff --git a/src/bors/handlers/mod.rs b/src/bors/handlers/mod.rs index 7b76da53..048d6c5f 100644 --- a/src/bors/handlers/mod.rs +++ b/src/bors/handlers/mod.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use anyhow::Context; -use octocrab::Octocrab; use tracing::Instrument; use crate::bors::command::{BorsCommand, CommandParseError}; @@ -17,7 +16,7 @@ use crate::bors::handlers::workflow::{ handle_check_suite_completed, handle_workflow_completed, handle_workflow_started, }; use crate::bors::{BorsContext, Comment, RepositoryState}; -use crate::{load_repositories, PgDbClient, TeamApiClient}; +use crate::{load_repositories, PgDbClient}; #[cfg(test)] use crate::tests::util::TestSyncMarker; @@ -142,16 +141,12 @@ pub static WAIT_FOR_REFRESH: TestSyncMarker = TestSyncMarker::new(); pub async fn handle_bors_global_event( event: BorsGlobalEvent, ctx: Arc, - gh_client: &Octocrab, - team_api_client: &TeamApiClient, ) -> anyhow::Result<()> { let db = Arc::clone(&ctx.db); match event { BorsGlobalEvent::InstallationsChanged => { let span = tracing::info_span!("Installations changed"); - reload_repos(ctx, gh_client, team_api_client) - .instrument(span) - .await?; + reload_repos(ctx).instrument(span).await?; } BorsGlobalEvent::Refresh => { let span = tracing::info_span!("Refresh"); @@ -161,7 +156,7 @@ pub async fn handle_bors_global_event( let repo = Arc::clone(&repo); async { let subspan = tracing::info_span!("Repo", repo = repo.repository().to_string()); - refresh_repository(repo, Arc::clone(&db), team_api_client) + refresh_repository(repo, Arc::clone(&db), &ctx.team_api_client) .instrument(subspan) .await } @@ -274,12 +269,8 @@ async fn handle_comment( Ok(()) } -async fn reload_repos( - ctx: Arc, - gh_client: &Octocrab, - team_api_client: &TeamApiClient, -) -> anyhow::Result<()> { - let reloaded_repos = load_repositories(gh_client, team_api_client).await?; +async fn reload_repos(ctx: Arc) -> anyhow::Result<()> { + let reloaded_repos = load_repositories(&ctx.gh_app_client, &ctx.team_api_client).await?; let mut repositories = ctx.repositories.write().unwrap(); for repo in repositories.values() { if !reloaded_repos.contains_key(repo.repository()) { diff --git a/src/bors/mod.rs b/src/bors/mod.rs index 719a3bd1..f8050198 100644 --- a/src/bors/mod.rs +++ b/src/bors/mod.rs @@ -3,6 +3,7 @@ use arc_swap::ArcSwap; pub use command::CommandParser; pub use comment::Comment; pub use context::BorsContext; +pub use context::BorsContextBuilder; #[cfg(test)] pub use handlers::WAIT_FOR_REFRESH; pub use handlers::{handle_bors_global_event, handle_bors_repository_event}; diff --git a/src/github/api/mod.rs b/src/github/api/mod.rs index 7ef31039..c46ed8cf 100644 --- a/src/github/api/mod.rs +++ b/src/github/api/mod.rs @@ -36,6 +36,17 @@ pub fn create_github_client( .context("Could not create octocrab builder") } +pub fn create_github_client_from_access_token( + github_url: String, + access_token: SecretString, +) -> anyhow::Result { + Octocrab::builder() + .base_uri(github_url)? + .user_access_token(access_token) + .build() + .context("Could not create octocrab builder") +} + /// Loads repositories that are connected to the given GitHub App client. /// The anyhow::Result is intended, because we wanted to have /// a hard error when the repos fail to load when the bot starts, but only log diff --git a/src/github/server.rs b/src/github/server.rs index ebb157c8..13dab2d8 100644 --- a/src/github/server.rs +++ b/src/github/server.rs @@ -2,7 +2,7 @@ use crate::bors::event::BorsEvent; use crate::bors::{handle_bors_global_event, handle_bors_repository_event, BorsContext}; use crate::github::webhook::GitHubWebhook; use crate::github::webhook::WebhookSecret; -use crate::{BorsGlobalEvent, BorsRepositoryEvent, TeamApiClient}; +use crate::{BorsGlobalEvent, BorsRepositoryEvent}; use anyhow::Error; use axum::extract::State; @@ -10,7 +10,6 @@ use axum::http::StatusCode; use axum::response::IntoResponse; use axum::routing::{get, post}; use axum::Router; -use octocrab::Octocrab; use std::future::Future; use std::sync::Arc; use tokio::sync::mpsc; @@ -83,8 +82,6 @@ pub async fn github_webhook_handler( /// them. pub fn create_bors_process( ctx: BorsContext, - gh_client: Octocrab, - team_api: TeamApiClient, ) -> ( mpsc::Sender, mpsc::Sender, @@ -104,7 +101,7 @@ pub fn create_bors_process( { tokio::join!( consume_repository_events(ctx.clone(), repository_rx), - consume_global_events(ctx.clone(), global_rx, gh_client, team_api) + consume_global_events(ctx.clone(), global_rx) ); } // In real execution, the bot runs forever. If there is something that finishes @@ -116,7 +113,7 @@ pub fn create_bors_process( _ = consume_repository_events(ctx.clone(), repository_rx) => { tracing::error!("Repository event handling process has ended"); } - _ = consume_global_events(ctx.clone(), global_rx, gh_client, team_api) => { + _ = consume_global_events(ctx.clone(), global_rx) => { tracing::error!("Global event handling process has ended"); } } @@ -146,15 +143,13 @@ async fn consume_repository_events( async fn consume_global_events( ctx: Arc, mut global_rx: mpsc::Receiver, - gh_client: Octocrab, - team_api: TeamApiClient, ) { while let Some(event) = global_rx.recv().await { let ctx = ctx.clone(); let span = tracing::info_span!("GlobalEvent"); tracing::debug!("Received global event: {event:#?}"); - if let Err(error) = handle_bors_global_event(event, ctx, &gh_client, &team_api) + if let Err(error) = handle_bors_global_event(event, ctx) .instrument(span.clone()) .await { diff --git a/src/lib.rs b/src/lib.rs index 5e85551f..ec1f9e9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,10 +8,14 @@ mod github; mod permissions; mod utils; -pub use bors::{event::BorsGlobalEvent, event::BorsRepositoryEvent, BorsContext, CommandParser}; +pub use bors::{ + event::BorsGlobalEvent, event::BorsRepositoryEvent, BorsContext, BorsContextBuilder, + CommandParser, +}; pub use database::PgDbClient; pub use github::{ api::create_github_client, + api::create_github_client_from_access_token, api::load_repositories, server::{create_app, create_bors_process, ServerState}, WebhookSecret, diff --git a/src/permissions.rs b/src/permissions.rs index 4c7022fa..657dd5ba 100644 --- a/src/permissions.rs +++ b/src/permissions.rs @@ -37,6 +37,7 @@ pub(crate) struct UserPermissionsResponse { github_ids: HashSet, } +#[derive(Clone)] pub struct TeamApiClient { base_url: String, } diff --git a/src/tests/mocks/bors.rs b/src/tests/mocks/bors.rs index a0f8ad5b..d410e8a7 100644 --- a/src/tests/mocks/bors.rs +++ b/src/tests/mocks/bors.rs @@ -24,8 +24,8 @@ use crate::tests::mocks::{ }; use crate::tests::webhook::{create_webhook_request, TEST_WEBHOOK_SECRET}; use crate::{ - create_app, create_bors_process, BorsContext, BorsGlobalEvent, CommandParser, PgDbClient, - ServerState, WebhookSecret, + create_app, create_bors_process, BorsContextBuilder, BorsGlobalEvent, CommandParser, + PgDbClient, ServerState, WebhookSecret, }; use super::pull_request::{GitHubPullRequestEventPayload, PullRequestChangeEvent}; @@ -109,10 +109,16 @@ impl BorsTester { repos.insert(name, Arc::new(repo)); } - let ctx = BorsContext::new(CommandParser::new("@bors".to_string()), db.clone(), repos); + let ctx = BorsContextBuilder::default() + .parser(CommandParser::new("@bors".to_string())) + .db(db.clone()) + .repositories(repos) + .gh_app_client(mock.github_client()) + .team_api_client(mock.team_api_client()) + .build() + .unwrap(); - let (repository_tx, global_tx, bors_process) = - create_bors_process(ctx, mock.github_client(), mock.team_api_client()); + let (repository_tx, global_tx, bors_process) = create_bors_process(ctx); let state = ServerState::new( repository_tx,