Skip to content

Commit

Permalink
refactor(Backend): 🔥 rewrite request duplication to use Any and unite…
Browse files Browse the repository at this point in the history
…d store
  • Loading branch information
Eason0729 committed Feb 8, 2024
1 parent 4675b99 commit 52a3318
Show file tree
Hide file tree
Showing 17 changed files with 151 additions and 154 deletions.
38 changes: 11 additions & 27 deletions backend/src/controller/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::Span;

use crate::init::config::GlobalConfig;
use base64::{engine::general_purpose::URL_SAFE, Engine};
use blake2::{Blake2b512, Digest};

type Result<T> = std::result::Result<T, Error>;
Expand All @@ -28,21 +29,7 @@ impl From<Error> for tonic::Status {
}
}

#[derive(PartialEq, Eq)]
pub struct HashValue(Vec<u8>);

impl From<Vec<u8>> for HashValue {
fn from(v: Vec<u8>) -> Self {
Self(v)
}
}

impl From<HashValue> for Vec<u8> {
fn from(v: HashValue) -> Self {
v.0
}
}

/// signed object
#[derive(Serialize, Deserialize)]
struct Signed {
data: Vec<u8>,
Expand All @@ -58,9 +45,7 @@ impl CryptoController {
#[tracing::instrument(parent=span,name="crypto_construct",level = "info",skip_all)]
pub fn new(config: &GlobalConfig, span: &Span) -> Self {
let salt = config.database.salt.as_bytes().to_vec();

let signing_key = SigningKey::random(&mut OsRng);

let verifying_key = *signing_key.verifying_key();

Self {
Expand All @@ -69,24 +54,26 @@ impl CryptoController {
verifying_key,
}
}
/// hash `src` and compare hash value with `hashed`
#[tracing::instrument(name = "crypto_hasheq_controller", level = "debug", skip_all)]
pub fn hash_eq(&self, src: &str, tar: &[u8]) -> bool {
let hashed: Vec<u8> = self.hash(src).into();
pub fn hash_eq(&self, src: &str, hashed: &[u8]) -> bool {
let src_hashed: Vec<u8> = self.hash(src);
let mut result = true;
for (a, b) in hashed.iter().zip(tar.iter()) {
for (a, b) in src_hashed.iter().zip(hashed.iter()) {
if *a != *b {
result = false;
}
}
result
}
/// get BLAKE2b-512 hashed bytes with salt
#[tracing::instrument(name = "crypto_hash_controller", level = "debug", skip_all)]
pub fn hash(&self, src: &str) -> HashValue {
pub fn hash(&self, src: &str) -> Vec<u8> {
let mut hasher = Blake2b512::new();
hasher.update(&[src.as_bytes(), self.salt.as_slice()].concat());

let hashed = hasher.finalize();
HashValue(hashed.to_vec())
hashed.to_vec()
}
/// serialize and sign the object with blake2b512, append the signature and return
#[tracing::instrument(level = "debug", skip_all)]
Expand All @@ -99,10 +86,7 @@ impl CryptoController {
data: raw,
signature,
};
Ok(base64::Engine::encode(
&base64::engine::general_purpose::STANDARD_NO_PAD,
bincode::serialize(&signed)?,
))
Ok(URL_SAFE.encode(bincode::serialize(&signed)?))
}
/// extract signature and object of encoded bytes(serde will handle it)
///
Expand All @@ -111,7 +95,7 @@ impl CryptoController {
/// Error if signature invaild
#[tracing::instrument(level = "debug", skip_all)]
pub fn decode<M: DeserializeOwned>(&self, raw: String) -> Result<M> {
let raw = base64::Engine::decode(&base64::engine::general_purpose::STANDARD_NO_PAD, raw)?;
let raw = URL_SAFE.decode(raw)?;
let raw: Signed = bincode::deserialize(&raw)?;
let signature = raw.signature;

Expand Down
71 changes: 38 additions & 33 deletions backend/src/controller/duplicate.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,53 @@
use std::{
any::{Any, TypeId},
ops::Deref,
};

use quick_cache::sync::Cache;
use std::sync::Arc;
use tracing::Span;
use uuid::Uuid;

#[derive(Eq, Hash, PartialEq)]
struct DupKey {
user_id: i32,
request_id: Uuid,
type_id: TypeId,
}

pub struct DupController {
dup_i32: Cache<(i32, Uuid), i32>,
dup_str: Cache<(i32, Uuid), String>,
store: Cache<DupKey, Arc<dyn Any + 'static + Send + Sync>>,
}

impl DupController {
#[tracing::instrument(parent=span, name="duplicate_construct",level = "info",skip_all)]
pub fn new(span: &Span) -> Self {
Self {
dup_i32: Cache::new(150),
dup_str: Cache::new(150),
store: Cache::new(50),
}
}
/// store a request_id with result i32
pub fn store_i32(&self, spliter: i32, uuid: Uuid, result: i32) {
tracing::trace!(request_id=?uuid);
self.dup_i32.insert((spliter, uuid), result);
}
/// store a request_id with result String
pub fn store_str(&self, spliter: i32, uuid: Uuid, result: String) {
tracing::trace!(request_id=?uuid);
self.dup_str.insert((spliter, uuid), result);
pub fn store<T>(&self, user_id: i32, request_id: Uuid, result: T)
where
T: 'static + Send + Sync + Clone,
{
let key = DupKey {
user_id,
request_id,
type_id: result.type_id(),
};
self.store.insert(key, Arc::new(result));
}
/// attempt to get result of i32
#[tracing::instrument(level = "debug", skip(self))]
pub fn check_i32(&self, spliter: i32, uuid: &Uuid) -> Option<i32> {
tracing::trace!(request_id=?uuid);
if let Some(x) = self.dup_i32.get(&(spliter, *uuid)) {
log::debug!("duplicated request_id: {}, result: {}", uuid, x);
return Some(x);
}
None
}
/// attempt to get result of String
#[tracing::instrument(level = "debug", skip(self))]
pub fn check_str(&self, spliter: i32, uuid: &Uuid) -> Option<String> {
tracing::trace!(request_id=?uuid);
if let Some(x) = self.dup_str.get(&(spliter, *uuid)) {
log::debug!("duplicated request_id: {}, result: {}", uuid, x);
return Some(x);
}
None
pub fn check<T>(&self, user_id: i32, request_id: Uuid) -> Option<T>
where
T: 'static + Send + Sync + Clone,
{
let key = DupKey {
user_id,
request_id,
type_id: TypeId::of::<T>(),
};
self.store
.peek(&key)
.map(|x| x.deref().downcast_ref::<T>().unwrap().clone())
}
}
}
12 changes: 6 additions & 6 deletions backend/src/controller/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ trait LimitPolicy {
struct LoginPolicy;

impl LimitPolicy for LoginPolicy {
const BURST: NonZeroU32 = NonZeroU32!(200);
const RATE: NonZeroU32 = NonZeroU32!(55);
const BURST: NonZeroU32 = NonZeroU32!(400);
const RATE: NonZeroU32 = NonZeroU32!(150);
}

/// policy for [`TrafficType::Guest`]
struct GuestPolicy;

impl LimitPolicy for GuestPolicy {
const BURST: NonZeroU32 = NonZeroU32!(80);
const RATE: NonZeroU32 = NonZeroU32!(35);
const BURST: NonZeroU32 = NonZeroU32!(150);
const RATE: NonZeroU32 = NonZeroU32!(80);
}

/// policy for [`TrafficType::Blacklist`]
Expand All @@ -55,8 +55,8 @@ impl LimitPolicy for GuestPolicy {
struct BlacklistPolicy;

impl LimitPolicy for BlacklistPolicy {
const BURST: NonZeroU32 = NonZeroU32!(30);
const RATE: NonZeroU32 = NonZeroU32!(10);
const BURST: NonZeroU32 = NonZeroU32!(60);
const RATE: NonZeroU32 = NonZeroU32!(30);
}

pub struct RateLimitController {
Expand Down
19 changes: 10 additions & 9 deletions backend/src/endpoint/announcement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ impl AnnouncementSet for Arc<Server> {
check_length!(LONG_ART_SIZE, req.info, content);

let uuid = Uuid::parse_str(&req.request_id).map_err(Error::InvaildUUID)?;
if let Some(x) = self.dup.check_i32(user_id, &uuid) {
return Ok(Response::new(x.into()));
if let Some(x) = self.dup.check::<AnnouncementId>(user_id, uuid) {
return Ok(Response::new(x));
};

if perm.super_user() {
Expand All @@ -172,11 +172,12 @@ impl AnnouncementSet for Arc<Server> {
.await
.map_err(Into::<Error>::into)?;

self.dup.store_i32(user_id, uuid, model.id.clone().unwrap());
let id: AnnouncementId = model.id.clone().unwrap().into();

tracing::debug!(id = model.id.clone().unwrap(), "announcement_created");
self.dup.store(user_id, uuid, id.clone());
tracing::debug!(id = id.id, "announcement_created");

Ok(Response::new(model.id.unwrap().into()))
Ok(Response::new(id))
}
#[instrument(skip_all, level = "debug")]
async fn update(
Expand All @@ -190,8 +191,8 @@ impl AnnouncementSet for Arc<Server> {
check_exist_length!(LONG_ART_SIZE, req.info, content);

let uuid = Uuid::parse_str(&req.request_id).map_err(Error::InvaildUUID)?;
if self.dup.check_i32(user_id, &uuid).is_some() {
return Ok(Response::new(()));
if let Some(x) = self.dup.check::<()>(user_id, uuid) {
return Ok(Response::new(x));
};

tracing::trace!(id = req.id.id);
Expand All @@ -205,12 +206,12 @@ impl AnnouncementSet for Arc<Server> {

fill_exist_active_model!(model, req.info, title, content);

let model = model
model
.update(self.db.deref())
.await
.map_err(Into::<Error>::into)?;

self.dup.store_i32(user_id, uuid, model.id);
self.dup.store(user_id, uuid, ());

Ok(Response::new(()))
}
Expand Down
11 changes: 6 additions & 5 deletions backend/src/endpoint/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ impl ChatSet for Arc<Server> {
check_length!(LONG_ART_SIZE, req, message);

let uuid = Uuid::parse_str(&req.request_id).map_err(Error::InvaildUUID)?;
if let Some(x) = self.dup.check_i32(user_id, &uuid) {
return Ok(Response::new(x.into()));
if let Some(x) = self.dup.check::<ChatId>(user_id, uuid) {
return Ok(Response::new(x));
};

let mut model: ActiveModel = Default::default();
Expand All @@ -53,12 +53,13 @@ impl ChatSet for Arc<Server> {
.await
.map_err(Into::<Error>::into)?;

self.dup.store_i32(user_id, uuid, model.id.clone().unwrap());
let id: ChatId = model.id.clone().unwrap().into();
self.dup.store(user_id, uuid, id.clone());

tracing::debug!(id = model.id.clone().unwrap());
tracing::debug!(id = id.id, "chat_created");
self.metrics.chat.add(1, &[]);

Ok(Response::new(model.id.unwrap().into()))
Ok(Response::new(id))
}

async fn remove(&self, req: Request<ChatId>) -> Result<Response<()>, Status> {
Expand Down
22 changes: 11 additions & 11 deletions backend/src/endpoint/contest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ impl ContestSet for Arc<Server> {
check_length!(LONG_ART_SIZE, req.info, content);

let uuid = Uuid::parse_str(&req.request_id).map_err(Error::InvaildUUID)?;
if let Some(x) = self.dup.check_i32(user_id, &uuid) {
return Ok(Response::new(x.into()));
if let Some(x) = self.dup.check::<ContestId>(user_id, uuid) {
return Ok(Response::new(x));
};

if !perm.super_user() {
Expand Down Expand Up @@ -179,12 +179,13 @@ impl ContestSet for Arc<Server> {
.await
.map_err(Into::<Error>::into)?;

self.dup.store_i32(user_id, uuid, model.id.clone().unwrap());
let id: ContestId = model.id.clone().unwrap().into();
self.dup.store(user_id, uuid, id.clone());

tracing::debug!(id = id.id, "contest_created");
self.metrics.contest.add(1, &[]);
tracing::debug!(id = model.id.clone().unwrap());

Ok(Response::new(model.id.unwrap().into()))
Ok(Response::new(id))
}
#[instrument(skip_all, level = "debug")]
async fn update(&self, req: Request<UpdateContestRequest>) -> Result<Response<()>, Status> {
Expand All @@ -195,10 +196,9 @@ impl ContestSet for Arc<Server> {
check_exist_length!(LONG_ART_SIZE, req.info, content);

let uuid = Uuid::parse_str(&req.request_id).map_err(Error::InvaildUUID)?;
if self.dup.check_i32(user_id, &uuid).is_some() {
return Ok(Response::new(()));
if let Some(x) = self.dup.check::<()>(user_id, uuid) {
return Ok(Response::new(x));
};

if !perm.super_user() {
return Err(Error::RequirePermission(RoleLv::Super).into());
}
Expand All @@ -213,7 +213,7 @@ impl ContestSet for Arc<Server> {
if let Some(src) = req.info.password {
if let Some(tar) = model.password.as_ref() {
if perm.root() || self.crypto.hash_eq(&src, tar) {
let hash = self.crypto.hash(&src).into();
let hash = self.crypto.hash(&src);
model.password = Some(hash);
} else {
return Err(Error::PermissionDeny(
Expand All @@ -234,12 +234,12 @@ impl ContestSet for Arc<Server> {
model.end = ActiveValue::Set(into_chrono(x));
}

let model = model
model
.update(self.db.deref())
.await
.map_err(Into::<Error>::into)?;

self.dup.store_i32(user_id, uuid, model.id);
self.dup.store(user_id, uuid, ());

Ok(Response::new(()))
}
Expand Down
Loading

0 comments on commit 52a3318

Please sign in to comment.