Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: drop Transactional and ConnectionOrTransaction #1069

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 33 additions & 96 deletions common/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,97 +11,13 @@ pub use func::*;
use anyhow::{ensure, Context};
use migration::{Migrator, MigratorTrait};
use sea_orm::{
prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
DbBackend, DbErr, ExecResult, QueryResult, RuntimeErr, Statement,
prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr,
ExecResult, QueryResult, RuntimeErr, Statement,
};
use sqlx::error::ErrorKind;
use std::ops::{Deref, DerefMut};
use tracing::instrument;

pub enum Transactional {
None,
Some(DatabaseTransaction),
}

impl Transactional {
/// Commit the database transaction.
///
/// If there's no underlying database transaction, then this becomes a no-op.
#[instrument(skip_all, fields(transactional=matches!(self, Transactional::Some(_))), ret)]
pub async fn commit(self) -> Result<(), DbErr> {
match self {
Transactional::None => {}
Transactional::Some(inner) => {
inner.commit().await?;
}
}

Ok(())
}
}

impl AsRef<Transactional> for Transactional {
fn as_ref(&self) -> &Transactional {
self
}
}

impl AsRef<Transactional> for () {
fn as_ref(&self) -> &Transactional {
&Transactional::None
}
}

#[derive(Clone)]
pub enum ConnectionOrTransaction<'db> {
Connection(&'db DatabaseConnection),
Transaction(&'db DatabaseTransaction),
}

impl<'db> From<&'db DatabaseTransaction> for ConnectionOrTransaction<'db> {
fn from(value: &'db DatabaseTransaction) -> Self {
Self::Transaction(value)
}
}

#[async_trait::async_trait]
impl ConnectionTrait for ConnectionOrTransaction<'_> {
fn get_database_backend(&self) -> DbBackend {
match self {
ConnectionOrTransaction::Connection(inner) => inner.get_database_backend(),
ConnectionOrTransaction::Transaction(inner) => inner.get_database_backend(),
}
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.execute(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.execute(stmt).await,
}
}

async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.execute_unprepared(sql).await,
ConnectionOrTransaction::Transaction(inner) => inner.execute_unprepared(sql).await,
}
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.query_one(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.query_one(stmt).await,
}
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.query_all(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.query_all(stmt).await,
}
}
}

#[derive(Clone, Debug)]
pub struct Database {
/// the database connection
Expand All @@ -111,16 +27,6 @@ pub struct Database {
}

impl Database {
pub fn connection<'db, TX: AsRef<Transactional>>(
&'db self,
tx: &'db TX,
) -> ConnectionOrTransaction<'db> {
match tx.as_ref() {
Transactional::None => ConnectionOrTransaction::Connection(&self.db),
Transactional::Some(tx) => ConnectionOrTransaction::Transaction(tx),
}
}

#[instrument(err)]
pub async fn new(database: &crate::config::Database) -> Result<Self, anyhow::Error> {
let url = database.to_url();
Expand Down Expand Up @@ -260,6 +166,37 @@ impl ConnectionTrait for Database {
}
}

/// Implementation of the connection trait for our database struct.
///
/// **NOTE**: We lack the implementations for the `mock` feature. However, the mock feature would
/// require us to have the `Database` struct to be non-clone, which we don't support anyway.
#[async_trait::async_trait]
impl ConnectionTrait for &Database {
fn get_database_backend(&self) -> DbBackend {
self.db.get_database_backend()
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
self.db.execute(stmt).await
}

async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
self.db.execute_unprepared(sql).await
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
self.db.query_one(stmt).await
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
self.db.query_all(stmt).await
}

fn support_returning(&self) -> bool {
self.db.support_returning()
}
}

/// A trait to help working with database errors
pub trait DatabaseErrors {
/// return `true` if the error is a duplicate key error
Expand Down
4 changes: 2 additions & 2 deletions entity/src/advisory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Model {
let db = ctx.data::<Arc<db::Database>>()?;
if let Some(found) = self
.find_related(organization::Entity)
.one(&db.connection(&db::Transactional::None))
.one(db.as_ref())
.await?
{
Ok(found)
Expand All @@ -51,7 +51,7 @@ impl Model {
let db = ctx.data::<Arc<db::Database>>()?;
Ok(self
.find_related(vulnerability::Entity)
.all(&db.connection(&db::Transactional::None))
.all(db.as_ref())
.await?)
}
}
Expand Down
30 changes: 22 additions & 8 deletions modules/analysis/src/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use trustify_auth::{
use trustify_common::{db::query::Query, db::Database, model::Paginated, purl::Purl};

pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) {
let analysis = AnalysisService::new(db);
let analysis = AnalysisService::new();

config
.app_data(web::Data::new(analysis))
.app_data(web::Data::new(db))
.service(search_component_root_components)
.service(get_component_root_components)
.service(analysis_status)
Expand All @@ -34,12 +35,13 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d
#[get("/v1/analysis/status")]
pub async fn analysis_status(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
user: UserInformation,
authorizer: web::Data<Authorizer>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
authorizer.require(&user, Permission::ReadSbom)?;
Ok(HttpResponse::Ok().json(service.status(()).await?))
Ok(HttpResponse::Ok().json(service.status(db.as_ref()).await?))
}

#[utoipa::path(
Expand All @@ -56,13 +58,14 @@ pub async fn analysis_status(
#[get("/v1/analysis/root-component")]
pub async fn search_component_root_components(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
web::Query(search): web::Query<Query>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components(search, paginated, ())
.retrieve_root_components(search, paginated, db.as_ref())
.await?,
))
}
Expand All @@ -80,6 +83,7 @@ pub async fn search_component_root_components(
#[get("/v1/analysis/root-component/{key}")]
pub async fn get_component_root_components(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
key: web::Path<String>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
Expand All @@ -88,13 +92,13 @@ pub async fn get_component_root_components(
let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?;
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components_by_purl(purl, paginated, ())
.retrieve_root_components_by_purl(purl, paginated, db.as_ref())
.await?,
))
} else {
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components_by_name(key.to_string(), paginated, ())
.retrieve_root_components_by_name(key.to_string(), paginated, db.as_ref())
.await?,
))
}
Expand All @@ -114,11 +118,16 @@ pub async fn get_component_root_components(
#[get("/v1/analysis/dep")]
pub async fn search_component_deps(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
web::Query(search): web::Query<Query>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
Ok(HttpResponse::Ok().json(service.retrieve_deps(search, paginated, ()).await?))
Ok(HttpResponse::Ok().json(
service
.retrieve_deps(search, paginated, db.as_ref())
.await?,
))
}

#[utoipa::path(
Expand All @@ -134,17 +143,22 @@ pub async fn search_component_deps(
#[get("/v1/analysis/dep/{key}")]
pub async fn get_component_deps(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
key: web::Path<String>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
if key.starts_with("pkg:") {
let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?;
Ok(HttpResponse::Ok().json(service.retrieve_deps_by_purl(purl, paginated, ()).await?))
Ok(HttpResponse::Ok().json(
service
.retrieve_deps_by_purl(purl, paginated, db.as_ref())
.await?,
))
} else {
Ok(HttpResponse::Ok().json(
service
.retrieve_deps_by_name(key.to_string(), paginated, ())
.retrieve_deps_by_name(key.to_string(), paginated, db.as_ref())
.await?,
))
}
Expand Down
Loading
Loading