Skip to content

Commit

Permalink
Some backend refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Sergio Castaño Arteaga <[email protected]>
  • Loading branch information
tegioz committed Jan 25, 2022
1 parent 655524e commit f4635e2
Show file tree
Hide file tree
Showing 7 changed files with 537 additions and 508 deletions.
76 changes: 76 additions & 0 deletions remonitor-apiserver/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use axum::{
extract,
extract::Extension,
http::{
header::{HeaderMap, HeaderName, HeaderValue},
StatusCode,
},
response,
};
use deadpool_postgres::Pool;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio_postgres::types::Json;
use uuid::Uuid;

/// Header that indicates the number of items available for pagination purposes.
const PAGINATION_TOTAL_COUNT: &str = "pagination-total-count";

/// Query input used when searching for projects.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct SearchProjectsInput {
limit: Option<usize>,
offset: Option<usize>,
text: Option<String>,
category: Option<Vec<usize>>,
maturity: Option<Vec<usize>>,
rating: Option<Vec<char>>,
}

/// Handler that allows searching for projects.
pub(crate) async fn search_projects(
Extension(db_pool): Extension<Pool>,
extract::Json(input): extract::Json<SearchProjectsInput>,
) -> Result<(HeaderMap, response::Json<Value>), (StatusCode, String)> {
// Search projects in database
let db = db_pool.get().await.map_err(internal_error)?;
let row = db
.query_one("select * from search_projects($1::jsonb)", &[&Json(input)])
.await
.map_err(internal_error)?;
let Json(projects): Json<Value> = row.get("projects");
let total_count: i64 = row.get("total_count");

// Prepare response headers
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static(PAGINATION_TOTAL_COUNT),
HeaderValue::from_str(&total_count.to_string()).unwrap(),
);

Ok((headers, response::Json(projects)))
}

/// Handler that returns the requested project.
pub(crate) async fn get_project(
Extension(db_pool): Extension<Pool>,
extract::Path(project_id): extract::Path<Uuid>,
) -> Result<response::Json<Value>, (StatusCode, String)> {
// Get project from database
let db = db_pool.get().await.map_err(internal_error)?;
let row = db
.query_one("select get_project($1::uuid)", &[&project_id])
.await
.map_err(internal_error)?;
let Json(project): Json<Value> = row.get(0);

Ok(response::Json(project))
}

/// Helper for mapping any error into a `500 Internal Server Error` response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
E: std::error::Error,
{
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
125 changes: 5 additions & 120 deletions remonitor-apiserver/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,12 @@
mod handlers;
mod router;

use anyhow::Error;
use axum::{
extract,
extract::Extension,
http::{
header::{HeaderMap, HeaderName, HeaderValue},
StatusCode,
},
response,
routing::{get, get_service, post},
AddExtensionLayer, Router,
};
use config::{Config, File};
use deadpool_postgres::{Config as DbConfig, Pool, Runtime};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use deadpool_postgres::{Config as DbConfig, Runtime};
use std::net::SocketAddr;
use std::path::Path;
use tokio_postgres::types::Json;
use tokio_postgres::NoTls;
use tower::ServiceBuilder;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tracing::info;
use uuid::Uuid;

/// Header that indicates the number of items available for pagination purposes.
const PAGINATION_TOTAL_COUNT: &str = "pagination-total-count";

#[tokio::main]
async fn main() -> Result<(), Error> {
Expand All @@ -51,7 +30,7 @@ async fn main() -> Result<(), Error> {
let db_pool = db_cfg.create_pool(Some(Runtime::Tokio1), NoTls)?;

// Setup and launch HTTP server
let router = setup_router(&cfg, db_pool)?;
let router = router::setup(&cfg, db_pool)?;
let addr: SocketAddr = cfg.get_str("apiserver.addr")?.parse()?;
info!("listening on {}", addr);
axum::Server::bind(&addr)
Expand All @@ -61,97 +40,3 @@ async fn main() -> Result<(), Error> {

Ok(())
}

/// Setup API server router.
fn setup_router(cfg: &Config, db_pool: Pool) -> Result<Router, Error> {
// Setup some paths
let static_path = cfg.get_str("apiserver.staticPath")?;
let index_path = Path::new(&static_path).join("index.html");

// Setup error handler
let error_handler = |err: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("internal error: {}", err),
)
};

// Setup router
let router = Router::new()
.route("/api/projects/search", post(search_projects))
.route("/api/projects/:project_id", get(get_project))
.route(
"/",
get_service(ServeFile::new(index_path)).handle_error(error_handler),
)
.nest(
"/static",
get_service(ServeDir::new(static_path)).handle_error(error_handler),
)
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db_pool)),
);

Ok(router)
}

/// Handler that allows searching for projects.
async fn search_projects(
Extension(db_pool): Extension<Pool>,
extract::Json(input): extract::Json<SearchProjectInput>,
) -> Result<(HeaderMap, response::Json<Value>), (StatusCode, String)> {
// Search projects in database
let db = db_pool.get().await.map_err(internal_error)?;
let row = db
.query_one("select * from search_projects($1::jsonb)", &[&Json(input)])
.await
.map_err(internal_error)?;
let Json(projects): Json<Value> = row.get("projects");
let total_count: i64 = row.get("total_count");

// Prepare response headers
let mut headers = HeaderMap::new();
headers.insert(
HeaderName::from_static(PAGINATION_TOTAL_COUNT),
HeaderValue::from_str(&total_count.to_string()).unwrap(),
);

Ok((headers, response::Json(projects)))
}

/// Handler that returns the requested project.
async fn get_project(
Extension(db_pool): Extension<Pool>,
extract::Path(project_id): extract::Path<Uuid>,
) -> Result<response::Json<Value>, (StatusCode, String)> {
// Get project from database
let db = db_pool.get().await.map_err(internal_error)?;
let row = db
.query_one("select get_project($1::uuid)", &[&project_id])
.await
.map_err(internal_error)?;
let Json(project): Json<Value> = row.get(0);

Ok(response::Json(project))
}

/// Helper for mapping any error into a `500 Internal Server Error` response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
E: std::error::Error,
{
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}

/// Query input used when searching for projects.
#[derive(Debug, Serialize, Deserialize)]
struct SearchProjectInput {
limit: Option<usize>,
offset: Option<usize>,
text: Option<String>,
category: Option<Vec<usize>>,
maturity: Option<Vec<usize>>,
rating: Option<Vec<char>>,
}
50 changes: 50 additions & 0 deletions remonitor-apiserver/src/router.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use crate::handlers::*;
use anyhow::Error;
use axum::{
http::StatusCode,
routing::{get, get_service, post},
AddExtensionLayer, Router,
};
use config::Config;
use deadpool_postgres::Pool;
use std::path::Path;
use tower::ServiceBuilder;
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};

/// Setup API server router.
pub(crate) fn setup(cfg: &Config, db_pool: Pool) -> Result<Router, Error> {
// Setup some paths
let static_path = cfg.get_str("apiserver.staticPath")?;
let index_path = Path::new(&static_path).join("index.html");

// Setup error handler
let error_handler = |err: std::io::Error| async move {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("internal error: {}", err),
)
};

// Setup router
let router = Router::new()
.route("/api/projects/search", post(search_projects))
.route("/api/projects/:project_id", get(get_project))
.route(
"/",
get_service(ServeFile::new(index_path)).handle_error(error_handler),
)
.nest(
"/static",
get_service(ServeDir::new(static_path)).handle_error(error_handler),
)
.layer(
ServiceBuilder::new()
.layer(TraceLayer::new_for_http())
.layer(AddExtensionLayer::new(db_pool)),
);

Ok(router)
}
4 changes: 2 additions & 2 deletions remonitor-linter/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
mod check;

use crate::check::Globs;
use anyhow::Error;
use serde::{Deserialize, Serialize};
use std::path::Path;

mod check;

/// A linter report.
#[derive(Debug, Serialize, Deserialize)]
pub struct Report {
Expand Down
Loading

0 comments on commit f4635e2

Please sign in to comment.