Skip to content

Commit

Permalink
refactor: drop Transactional and ConnectionOrTransaction
Browse files Browse the repository at this point in the history
Directly use the `ConnectionTrait` from sea_orm.
  • Loading branch information
ctron committed Dec 4, 2024
1 parent 88e4a52 commit 398b96c
Show file tree
Hide file tree
Showing 122 changed files with 2,090 additions and 2,215 deletions.
129 changes: 33 additions & 96 deletions common/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,97 +11,13 @@ pub use func::*;
use anyhow::{ensure, Context};
use migration::{Migrator, MigratorTrait};
use sea_orm::{
prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
DbBackend, DbErr, ExecResult, QueryResult, RuntimeErr, Statement,
prelude::async_trait, ConnectOptions, ConnectionTrait, DatabaseConnection, DbBackend, DbErr,
ExecResult, QueryResult, RuntimeErr, Statement,
};
use sqlx::error::ErrorKind;
use std::ops::{Deref, DerefMut};
use tracing::instrument;

pub enum Transactional {
None,
Some(DatabaseTransaction),
}

impl Transactional {
/// Commit the database transaction.
///
/// If there's no underlying database transaction, then this becomes a no-op.
#[instrument(skip_all, fields(transactional=matches!(self, Transactional::Some(_))), ret)]
pub async fn commit(self) -> Result<(), DbErr> {
match self {
Transactional::None => {}
Transactional::Some(inner) => {
inner.commit().await?;
}
}

Ok(())
}
}

impl AsRef<Transactional> for Transactional {
fn as_ref(&self) -> &Transactional {
self
}
}

impl AsRef<Transactional> for () {
fn as_ref(&self) -> &Transactional {
&Transactional::None
}
}

#[derive(Clone)]
pub enum ConnectionOrTransaction<'db> {
Connection(&'db DatabaseConnection),
Transaction(&'db DatabaseTransaction),
}

impl<'db> From<&'db DatabaseTransaction> for ConnectionOrTransaction<'db> {
fn from(value: &'db DatabaseTransaction) -> Self {
Self::Transaction(value)
}
}

#[async_trait::async_trait]
impl ConnectionTrait for ConnectionOrTransaction<'_> {
fn get_database_backend(&self) -> DbBackend {
match self {
ConnectionOrTransaction::Connection(inner) => inner.get_database_backend(),
ConnectionOrTransaction::Transaction(inner) => inner.get_database_backend(),
}
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.execute(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.execute(stmt).await,
}
}

async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.execute_unprepared(sql).await,
ConnectionOrTransaction::Transaction(inner) => inner.execute_unprepared(sql).await,
}
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.query_one(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.query_one(stmt).await,
}
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
ConnectionOrTransaction::Connection(inner) => inner.query_all(stmt).await,
ConnectionOrTransaction::Transaction(inner) => inner.query_all(stmt).await,
}
}
}

#[derive(Clone, Debug)]
pub struct Database {
/// the database connection
Expand All @@ -111,16 +27,6 @@ pub struct Database {
}

impl Database {
pub fn connection<'db, TX: AsRef<Transactional>>(
&'db self,
tx: &'db TX,
) -> ConnectionOrTransaction<'db> {
match tx.as_ref() {
Transactional::None => ConnectionOrTransaction::Connection(&self.db),
Transactional::Some(tx) => ConnectionOrTransaction::Transaction(tx),
}
}

#[instrument(err)]
pub async fn new(database: &crate::config::Database) -> Result<Self, anyhow::Error> {
let url = database.to_url();
Expand Down Expand Up @@ -260,6 +166,37 @@ impl ConnectionTrait for Database {
}
}

/// Implementation of the connection trait for our database struct.
///
/// **NOTE**: We lack the implementations for the `mock` feature. However, the mock feature would
/// require us to have the `Database` struct to be non-clone, which we don't support anyway.
#[async_trait::async_trait]
impl ConnectionTrait for &Database {
fn get_database_backend(&self) -> DbBackend {
self.db.get_database_backend()
}

async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
self.db.execute(stmt).await
}

async fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
self.db.execute_unprepared(sql).await
}

async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
self.db.query_one(stmt).await
}

async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
self.db.query_all(stmt).await
}

fn support_returning(&self) -> bool {
self.db.support_returning()
}
}

/// A trait to help working with database errors
pub trait DatabaseErrors {
/// return `true` if the error is a duplicate key error
Expand Down
4 changes: 2 additions & 2 deletions entity/src/advisory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Model {
let db = ctx.data::<Arc<db::Database>>()?;
if let Some(found) = self
.find_related(organization::Entity)
.one(&db.connection(&db::Transactional::None))
.one(db.as_ref())
.await?
{
Ok(found)
Expand All @@ -51,7 +51,7 @@ impl Model {
let db = ctx.data::<Arc<db::Database>>()?;
Ok(self
.find_related(vulnerability::Entity)
.all(&db.connection(&db::Transactional::None))
.all(db.as_ref())
.await?)
}
}
Expand Down
30 changes: 22 additions & 8 deletions modules/analysis/src/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use trustify_auth::{
use trustify_common::{db::query::Query, db::Database, model::Paginated, purl::Purl};

pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, db: Database) {
let analysis = AnalysisService::new(db);
let analysis = AnalysisService::new();

config
.app_data(web::Data::new(analysis))
.app_data(web::Data::new(db))
.service(search_component_root_components)
.service(get_component_root_components)
.service(analysis_status)
Expand All @@ -34,12 +35,13 @@ pub fn configure(config: &mut utoipa_actix_web::service_config::ServiceConfig, d
#[get("/v1/analysis/status")]
pub async fn analysis_status(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
user: UserInformation,
authorizer: web::Data<Authorizer>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
authorizer.require(&user, Permission::ReadSbom)?;
Ok(HttpResponse::Ok().json(service.status(()).await?))
Ok(HttpResponse::Ok().json(service.status(db.as_ref()).await?))
}

#[utoipa::path(
Expand All @@ -56,13 +58,14 @@ pub async fn analysis_status(
#[get("/v1/analysis/root-component")]
pub async fn search_component_root_components(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
web::Query(search): web::Query<Query>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components(search, paginated, ())
.retrieve_root_components(search, paginated, db.as_ref())
.await?,
))
}
Expand All @@ -80,6 +83,7 @@ pub async fn search_component_root_components(
#[get("/v1/analysis/root-component/{key}")]
pub async fn get_component_root_components(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
key: web::Path<String>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
Expand All @@ -88,13 +92,13 @@ pub async fn get_component_root_components(
let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?;
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components_by_purl(purl, paginated, ())
.retrieve_root_components_by_purl(purl, paginated, db.as_ref())
.await?,
))
} else {
Ok(HttpResponse::Ok().json(
service
.retrieve_root_components_by_name(key.to_string(), paginated, ())
.retrieve_root_components_by_name(key.to_string(), paginated, db.as_ref())
.await?,
))
}
Expand All @@ -114,11 +118,16 @@ pub async fn get_component_root_components(
#[get("/v1/analysis/dep")]
pub async fn search_component_deps(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
web::Query(search): web::Query<Query>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
Ok(HttpResponse::Ok().json(service.retrieve_deps(search, paginated, ()).await?))
Ok(HttpResponse::Ok().json(
service
.retrieve_deps(search, paginated, db.as_ref())
.await?,
))
}

#[utoipa::path(
Expand All @@ -134,17 +143,22 @@ pub async fn search_component_deps(
#[get("/v1/analysis/dep/{key}")]
pub async fn get_component_deps(
service: web::Data<AnalysisService>,
db: web::Data<Database>,
key: web::Path<String>,
web::Query(paginated): web::Query<Paginated>,
_: Require<ReadSbom>,
) -> actix_web::Result<impl Responder> {
if key.starts_with("pkg:") {
let purl: Purl = Purl::from_str(&key).map_err(Error::Purl)?;
Ok(HttpResponse::Ok().json(service.retrieve_deps_by_purl(purl, paginated, ()).await?))
Ok(HttpResponse::Ok().json(
service
.retrieve_deps_by_purl(purl, paginated, db.as_ref())
.await?,
))
} else {
Ok(HttpResponse::Ok().json(
service
.retrieve_deps_by_name(key.to_string(), paginated, ())
.retrieve_deps_by_name(key.to_string(), paginated, db.as_ref())
.await?,
))
}
Expand Down
Loading

0 comments on commit 398b96c

Please sign in to comment.