From 580e5a6b9e77a284785359557a074d7bb63b9c02 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Wed, 5 Jul 2023 11:46:17 -0400 Subject: [PATCH 1/5] wip --- src/search.rs | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 src/search.rs diff --git a/src/search.rs b/src/search.rs new file mode 100644 index 0000000..618b15a --- /dev/null +++ b/src/search.rs @@ -0,0 +1,81 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::{collections::{HashMap}}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + +pub enum Operation { + #[serde(rename = "and")] + And, + #[serde(rename = "or")] + Or, + #[serde(rename = "eq")] + Eq, + #[serde(rename = "ne")] + Ne, + #[serde(rename = "gt")] + Gt, + #[serde(rename = "gte")] + Gte, + #[serde(rename = "lt")] + Lt, + #[serde(rename = "lte")] + Lte, +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct Comparator { + /// Dimension of the vectors in the collection + pub attribute: String, + /// Distance metric used for querying + pub op: Operation, + /// Embeddings in the collection + pub value: String, +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("Filter incorrectly formatted")] + InvalidFilter, +} + + + +pub fn parse(input: Option>) -> Option { + match input { + Some(input) => match input { + "$and" => {}, + "$or" => {}, + "$eq" => {}, + "$ne" => {}, + "$gt" => {}, + "$gte" => {}, + "$lt" => {}, + "$lte" => {}, + _ => {}, + }, + None => None, + } +} + +fn get_operation_fn(op: Operation) -> impl Fn(String | bool, String | bool) -> bool { + match op { + Operation::And => and, + Operation::Or => or, + Operation::Eq => eq, + Operation::Ne => ne, + Operation::Gt => gt, + Operation::Gte => gte, + Operation::Lt => lt, + Operation::Lte => lte, + } +} + +fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs } +fn or(lhs: bool, rhs: bool) -> bool { lhs || rhs } +fn eq(lhs: String, rhs: String) -> bool { lhs == rhs } +fn ne(lhs: String, rhs: String) -> bool { lhs != rhs } +fn gt(lhs: String, rhs: String) -> bool { lhs > rhs } +fn gte(lhs: String, rhs: String) -> bool { lhs >= rhs } +fn lt(lhs: String, rhs: String) -> bool { lhs < rhs } +fn lte(lhs: String, rhs: String) -> bool { lhs <= rhs } From ae017e3b7116d1daa96fa29ad2d67630f43a2da9 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Mon, 10 Jul 2023 05:25:27 -0400 Subject: [PATCH 2/5] wip --- .gitignore | 1 + src/db.rs | 3 ++- src/main.rs | 1 + src/routes/collection.rs | 15 +++++++++++++-- src/search.rs | 38 +++++++++++++++++++------------------- 5 files changed, 36 insertions(+), 22 deletions(-) diff --git a/.gitignore b/.gitignore index 401319c..2bc5772 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /storage /target +.DS_Store \ No newline at end of file diff --git a/src/db.rs b/src/db.rs index 423ece3..4a4125d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -55,7 +55,7 @@ pub struct Collection { } impl Collection { - pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { + pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { let memo_attr = get_cache_attr(self.distance, query); let distance_fn = get_distance_fn(self.distance); @@ -63,6 +63,7 @@ impl Collection { .embeddings .par_iter() .enumerate() + // .filter() .map(|(index, embedding)| { let score = distance_fn(&embedding.vector, query, memo_attr); ScoreIndex { score, index } diff --git a/src/main.rs b/src/main.rs index 07fd97b..6ba2e21 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ use tracing_subscriber::{ mod db; mod errors; mod routes; +mod search; mod server; mod shutdown; mod similarity; diff --git a/src/routes/collection.rs b/src/routes/collection.rs index c0f8f47..e9e772b 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -5,11 +5,13 @@ use aide::axum::{ use axum::{extract::Path, http::StatusCode, Extension}; use axum_jsonschema::Json; use schemars::JsonSchema; +use std::collections::HashMap; use std::time::Instant; use crate::{ db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult}, errors::HTTPError, + search::Error as SearchError, similarity::Distance, }; @@ -56,6 +58,8 @@ struct QueryCollectionQuery { query: Vec, /// Number of results to return k: Option, + /// Filter results by metadata + filter: Option>, } /// Query a collection @@ -76,8 +80,16 @@ async fn query_collection( return Err(HTTPError::new("Query dimension mismatch").with_status(StatusCode::BAD_REQUEST)); } + // if (req.filter.is_some()) + // match { + // Ok(_) => Ok(Json()), + // Err(_) => Err(HTTPError::new( + // "Metadata filter is incorrectly formatted" + // ).with_status(StatusCode::BAD_REQUEST)), + // } + let instant = Instant::now(); - let results = collection.get_similarity(&req.query, req.k.unwrap_or(1)); + let results = collection.get_similarity(&req.query, req.k.unwrap_or(1), req.filter); drop(db); tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed()); @@ -125,7 +137,6 @@ async fn delete_collection( tracing::trace!("Deleting collection {collection_name}"); let mut db = db.write().await; - let delete_result = db.delete_collection(&collection_name); drop(db); diff --git a/src/search.rs b/src/search.rs index 618b15a..1bc6796 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,6 +1,6 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::{collections::{HashMap}}; +use std::collections::HashMap; #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -9,28 +9,30 @@ pub enum Operation { And, #[serde(rename = "or")] Or, - #[serde(rename = "eq")] + #[serde(rename = "eq")] Eq, #[serde(rename = "ne")] Ne, #[serde(rename = "gt")] Gt, - #[serde(rename = "gte")] + #[serde(rename = "gte")] Gte, - #[serde(rename = "lt")] + #[serde(rename = "lt")] Lt, - #[serde(rename = "lte")] + #[serde(rename = "lte")] Lte, } +enum OpSide { + String(String), + Comparator(&Comparator), +} + #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct Comparator { - /// Dimension of the vectors in the collection - pub attribute: String, - /// Distance metric used for querying + pub lhs: OpSide, pub op: Operation, - /// Embeddings in the collection - pub value: String, + pub rhs: OpSide, } #[derive(Debug, thiserror::Error)] @@ -39,10 +41,8 @@ pub enum Error { InvalidFilter, } - - pub fn parse(input: Option>) -> Option { - match input { + match input { Some(input) => match input { "$and" => {}, "$or" => {}, @@ -58,17 +58,17 @@ pub fn parse(input: Option>) -> Option { } } -fn get_operation_fn(op: Operation) -> impl Fn(String | bool, String | bool) -> bool { +fn get_operation_fn(op: Operation) -> impl Fn(String, String) -> bool { match op { Operation::And => and, Operation::Or => or, Operation::Eq => eq, Operation::Ne => ne, - Operation::Gt => gt, - Operation::Gte => gte, - Operation::Lt => lt, - Operation::Lte => lte, - } + Operation::Gt => gt, + Operation::Gte => gte, + Operation::Lt => lt, + Operation::Lte => lte, + } } fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs } From 2926cd71d7ee7e3ff564b887f2dc843659a2ef82 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Wed, 12 Jul 2023 01:07:33 -0400 Subject: [PATCH 3/5] wip --- src/db.rs | 15 +++- src/routes/collection.rs | 5 +- src/search.rs | 185 +++++++++++++++++++++++++++++---------- 3 files changed, 153 insertions(+), 52 deletions(-) diff --git a/src/db.rs b/src/db.rs index 4a4125d..7819d0e 100644 --- a/src/db.rs +++ b/src/db.rs @@ -12,6 +12,7 @@ use std::{ use tokio::sync::RwLock; use crate::similarity::{get_cache_attr, get_distance_fn, normalize, Distance, ScoreIndex}; +use crate::search::Filter; lazy_static! { pub static ref STORE_PATH: PathBuf = PathBuf::from("./storage/db"); @@ -55,7 +56,7 @@ pub struct Collection { } impl Collection { - pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { + pub fn get_similarity(&self, query: &[f32], k: usize, comparate: Option) -> Vec { let memo_attr = get_cache_attr(self.distance, query); let distance_fn = get_distance_fn(self.distance); @@ -63,7 +64,17 @@ impl Collection { .embeddings .par_iter() .enumerate() - // .filter() + .filter(|(_, embedding)| { + if let Some(comparate) = &comparate { + if let Some(metadata) = &embedding.metadata { + if let Some(value) = metadata.get(&comparate.key) { + return comparate.filter(value); + } + } + } + + true + }) .map(|(index, embedding)| { let score = distance_fn(&embedding.vector, query, memo_attr); ScoreIndex { score, index } diff --git a/src/routes/collection.rs b/src/routes/collection.rs index e9e772b..ba44aa9 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -5,13 +5,12 @@ use aide::axum::{ use axum::{extract::Path, http::StatusCode, Extension}; use axum_jsonschema::Json; use schemars::JsonSchema; -use std::collections::HashMap; use std::time::Instant; use crate::{ db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult}, errors::HTTPError, - search::Error as SearchError, + search::{Error as SearchError, Comparator}, similarity::Distance, }; @@ -59,7 +58,7 @@ struct QueryCollectionQuery { /// Number of results to return k: Option, /// Filter results by metadata - filter: Option>, + comparator: Option, } /// Query a collection diff --git a/src/search.rs b/src/search.rs index 1bc6796..680db9d 100644 --- a/src/search.rs +++ b/src/search.rs @@ -3,12 +3,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] - -pub enum Operation { - #[serde(rename = "and")] - And, - #[serde(rename = "or")] - Or, +pub enum EqualityCompOp { #[serde(rename = "eq")] Eq, #[serde(rename = "ne")] @@ -23,59 +18,155 @@ pub enum Operation { Lte, } -enum OpSide { - String(String), - Comparator(&Comparator), +fn eq(lhs: String, rhs: String) -> bool { lhs == rhs } +fn ne(lhs: String, rhs: String) -> bool { lhs != rhs } +fn gt(lhs: String, rhs: String) -> bool { lhs > rhs } +fn gte(lhs: String, rhs: String) -> bool { lhs >= rhs } +fn lt(lhs: String, rhs: String) -> bool { lhs < rhs } +fn lte(lhs: String, rhs: String) -> bool { lhs <= rhs } + +fn get_equality_comp_op_fn(op: EqualityCompOp) -> impl Fn(String, String) -> bool { + match op { + EqualityCompOp::Eq => eq, + EqualityCompOp::Ne => ne, + EqualityCompOp::Gt => gt, + EqualityCompOp::Gte => gte, + EqualityCompOp::Lt => lt, + EqualityCompOp::Lte => lte, + } } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + pub struct Comparator { - pub lhs: OpSide, - pub op: Operation, - pub rhs: OpSide, + pub op: EqualityCompOp, + pub val: String, +} + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] + +pub enum LogicalCompOp { + #[serde(rename = "and")] + And, + #[serde(rename = "or")] + Or, +} + +fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs } +fn or(lhs: bool, rhs: bool) -> bool { lhs || rhs } + +fn get_logical_comp_op_fn(op: LogicalCompOp) -> impl Fn(bool, bool) -> bool { + match op { + LogicalCompOp::And => and, + LogicalCompOp::Or => or, + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct Logic { + pub lhs: Comparator, + pub op: LogicalCompOp, + pub rhs: Comparator, +} + +pub enum Filter { + Comparator, + Logic, } -#[derive(Debug, thiserror::Error)] +/*** + +{ + "where": { + "$and": [ + ...opt1, + ...opt2, + ], + "$or": [ + ...opt1, + ...opt2, + ], + "$eq": { + "field": "value" + }, + "$ne": { + "field": "value" + }, + "$gt": { + "field": "value" + }, + "$gte": { + "field": "value" + }, + "$lt": { + "field": "value" + }, + "$lte": { + "field": "value" + }, + } +} + + ***/ + + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Filter incorrectly formatted")] InvalidFilter, } -pub fn parse(input: Option>) -> Option { - match input { - Some(input) => match input { - "$and" => {}, - "$or" => {}, - "$eq" => {}, - "$ne" => {}, - "$gt" => {}, - "$gte" => {}, - "$lt" => {}, - "$lte" => {}, - _ => {}, - }, - None => None, +pub fn parse(input: Option>) -> Result, Error> { + if input.is_none() { + return Ok(None); } -} -fn get_operation_fn(op: Operation) -> impl Fn(String, String) -> bool { - match op { - Operation::And => and, - Operation::Or => or, - Operation::Eq => eq, - Operation::Ne => ne, - Operation::Gt => gt, - Operation::Gte => gte, - Operation::Lt => lt, - Operation::Lte => lte, + if input.expect("").keys().len() != 1 { + return Err(Error::InvalidFilter); } -} -fn and(lhs: bool, rhs: bool) -> bool { lhs && rhs } -fn or(lhs: bool, rhs: bool) -> bool { lhs || rhs } -fn eq(lhs: String, rhs: String) -> bool { lhs == rhs } -fn ne(lhs: String, rhs: String) -> bool { lhs != rhs } -fn gt(lhs: String, rhs: String) -> bool { lhs > rhs } -fn gte(lhs: String, rhs: String) -> bool { lhs >= rhs } -fn lt(lhs: String, rhs: String) -> bool { lhs < rhs } -fn lte(lhs: String, rhs: String) -> bool { lhs <= rhs } + match input { + "$and" => { + let lhs = parse(input.expect("").get("lhs")); + let rhs = parse(input.expect("").get("rhs")); + let op = get_logical_comp_op_fn(LogicalCompOp::And); + Some(Filter::Logic(Logic { lhs, op, rhs })) + }, + "$or" => { + let lhs = parse(input.get("lhs")); + let rhs = parse(input.get("rhs")); + let op = get_logical_comp_op_fn(LogicalCompOp::Or); + Some(Filter::Logic(Logic { lhs, op, rhs })) + }, + "$eq" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Eq); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + "$ne" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Ne); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + "$gt" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Gt); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + "$gte" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Gte); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + "$lt" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Lt); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + "$lte" => { + let op = get_equality_comp_op_fn(EqualityCompOp::Lte); + let val = input.get("val").unwrap().to_string(); + Some(Filter::Comparator(Comparator { op, val })) + }, + _ => Err(Error::InvalidFilter), + } +} From 2047c695f4164e428684efc3f2750f120ac601d4 Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Wed, 12 Jul 2023 01:44:02 -0400 Subject: [PATCH 4/5] wip --- src/db.rs | 22 ++++++++++----------- src/routes/collection.rs | 11 ++--------- src/search.rs | 42 ++++++++++++++++++++-------------------- 3 files changed, 34 insertions(+), 41 deletions(-) diff --git a/src/db.rs b/src/db.rs index 7819d0e..8dd5205 100644 --- a/src/db.rs +++ b/src/db.rs @@ -64,17 +64,17 @@ impl Collection { .embeddings .par_iter() .enumerate() - .filter(|(_, embedding)| { - if let Some(comparate) = &comparate { - if let Some(metadata) = &embedding.metadata { - if let Some(value) = metadata.get(&comparate.key) { - return comparate.filter(value); - } - } - } - - true - }) + // .filter(|(_, embedding)| { + // if let Some(comparate) = &comparate { + // if let Some(metadata) = &embedding.metadata { + // if let Some(value) = metadata.get(&comparate.key) { + // return comparate.filter(value); + // } + // } + // } + + // true + // }) .map(|(index, embedding)| { let score = distance_fn(&embedding.vector, query, memo_attr); ScoreIndex { score, index } diff --git a/src/routes/collection.rs b/src/routes/collection.rs index ba44aa9..5cead7e 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -10,7 +10,7 @@ use std::time::Instant; use crate::{ db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult}, errors::HTTPError, - search::{Error as SearchError, Comparator}, + search::Filter, similarity::Distance, }; @@ -58,7 +58,7 @@ struct QueryCollectionQuery { /// Number of results to return k: Option, /// Filter results by metadata - comparator: Option, + filter: Option, } /// Query a collection @@ -79,13 +79,6 @@ async fn query_collection( return Err(HTTPError::new("Query dimension mismatch").with_status(StatusCode::BAD_REQUEST)); } - // if (req.filter.is_some()) - // match { - // Ok(_) => Ok(Json()), - // Err(_) => Err(HTTPError::new( - // "Metadata filter is incorrectly formatted" - // ).with_status(StatusCode::BAD_REQUEST)), - // } let instant = Instant::now(); let results = collection.get_similarity(&req.query, req.k.unwrap_or(1), req.filter); diff --git a/src/search.rs b/src/search.rs index 680db9d..cece695 100644 --- a/src/search.rs +++ b/src/search.rs @@ -69,15 +69,20 @@ pub struct Logic { pub rhs: Comparator, } -pub enum Filter { +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub enum FilterOp { Comparator, Logic, } +pub struct Filter { + pub val: FilterType, +} + /*** { - "where": { + "filter": { "$and": [ ...opt1, ...opt2, @@ -115,57 +120,52 @@ pub enum Error { InvalidFilter, } -pub fn parse(input: Option>) -> Result, Error> { - if input.is_none() { - return Ok(None); - } - - if input.expect("").keys().len() != 1 { +pub fn parse(input: HashMap) -> Result { + if input.keys().len() != 1 { return Err(Error::InvalidFilter); } - match input { + match input.keys().next().unwrap().as_str() { "$and" => { - let lhs = parse(input.expect("").get("lhs")); - let rhs = parse(input.expect("").get("rhs")); - let op = get_logical_comp_op_fn(LogicalCompOp::And); - Some(Filter::Logic(Logic { lhs, op, rhs })) + let lhs = parse(input.get("lhs")); + let rhs = parse(input.get("rhs")); + Ok() // Logic { lhs, op: LogicalCompOp::And, rhs } }, "$or" => { let lhs = parse(input.get("lhs")); let rhs = parse(input.get("rhs")); - let op = get_logical_comp_op_fn(LogicalCompOp::Or); - Some(Filter::Logic(Logic { lhs, op, rhs })) + let op = FilterOp::Logic(Logic { lhs, op: LogicalCompOp::Or, rhs }); + Filter { op } }, "$eq" => { let op = get_equality_comp_op_fn(EqualityCompOp::Eq); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Eq, val }) }, "$ne" => { let op = get_equality_comp_op_fn(EqualityCompOp::Ne); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Ne, val }) }, "$gt" => { let op = get_equality_comp_op_fn(EqualityCompOp::Gt); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Gt, val }) }, "$gte" => { let op = get_equality_comp_op_fn(EqualityCompOp::Gte); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Gte, val }) }, "$lt" => { let op = get_equality_comp_op_fn(EqualityCompOp::Lt); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Lt, val }) }, "$lte" => { let op = get_equality_comp_op_fn(EqualityCompOp::Lte); let val = input.get("val").unwrap().to_string(); - Some(Filter::Comparator(Comparator { op, val })) + Filter::Comparator(Comparator { op: EqualityCompOp::Lte, val }) }, _ => Err(Error::InvalidFilter), } From 47c9072b320eddec5876dd1d5d7daa26bcce948f Mon Sep 17 00:00:00 2001 From: Ben Harris Date: Fri, 14 Jul 2023 01:32:01 -0400 Subject: [PATCH 5/5] wip --- src/db.rs | 17 +++---- src/search.rs | 127 +++++++++++++++++++++++++++++++------------------- 2 files changed, 84 insertions(+), 60 deletions(-) diff --git a/src/db.rs b/src/db.rs index 8dd5205..9c52ea8 100644 --- a/src/db.rs +++ b/src/db.rs @@ -64,17 +64,12 @@ impl Collection { .embeddings .par_iter() .enumerate() - // .filter(|(_, embedding)| { - // if let Some(comparate) = &comparate { - // if let Some(metadata) = &embedding.metadata { - // if let Some(value) = metadata.get(&comparate.key) { - // return comparate.filter(value); - // } - // } - // } - - // true - // }) + .filter(|(_, embedding)| { + match comparate { + Some(ref comparate) => (*comparate).compare(embedding), + _ => true, + } + }) .map(|(index, embedding)| { let score = distance_fn(&embedding.vector, query, memo_attr); ScoreIndex { score, index } diff --git a/src/search.rs b/src/search.rs index cece695..fd65be6 100644 --- a/src/search.rs +++ b/src/search.rs @@ -2,6 +2,10 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +trait Compare { + fn compare(&self, metadata: &HashMap) -> bool; +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub enum EqualityCompOp { #[serde(rename = "eq")] @@ -39,8 +43,17 @@ fn get_equality_comp_op_fn(op: EqualityCompOp) -> impl Fn(String, String) -> boo #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct Comparator { + pub metadata_field: String, pub op: EqualityCompOp, - pub val: String, + pub comp_value: String, +} + +impl Compare for Comparator { + fn compare(&self, metadata: &HashMap) -> bool { + let metadata_value = metadata.get(&self.metadata_field).unwrap_or(&"".to_string()); + let op = get_equality_comp_op_fn(self.op); + op(*metadata_value, self.comp_value) + } } #[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] @@ -64,19 +77,33 @@ fn get_logical_comp_op_fn(op: LogicalCompOp) -> impl Fn(bool, bool) -> bool { #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct Logic { - pub lhs: Comparator, + pub lhs: Box, pub op: LogicalCompOp, - pub rhs: Comparator, + pub rhs: Box, +} + +impl Compare for Logic { + fn compare(&self, metadata: &HashMap) -> bool { + let lhs = self.lhs.compare(metadata); + let rhs = self.rhs.compare(metadata); + let op = get_logical_comp_op_fn(self.op); + op(lhs, rhs) + } } #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] -pub enum FilterOp { - Comparator, - Logic, +pub enum Filter { + Comparator(Comparator), + Logic(Logic), } -pub struct Filter { - pub val: FilterType, +impl Compare for Filter { + fn compare(&self, metadata: &HashMap) -> bool { + match self { + Filter::Comparator(c) => c.compare(metadata), + Filter::Logic(l) => l.compare(metadata), + } + } } /*** @@ -120,53 +147,55 @@ pub enum Error { InvalidFilter, } +fn parse_logic_helper(input: HashMap, key: &str) -> Result { + match input.get(key) { + Some(s) => { + match parse(s) { + Ok(f) => { Ok(f) }, + Err(_) => { Err(Error::InvalidFilter) } + } + }, + None => { Err(Error::InvalidFilter) }, + } +} + + + +fn parse_logic(input: HashMap, op: LogicalCompOp) -> Result { + let lhs = parse_logic_helper(input, "lhs"); + let rhs = parse_logic_helper(input, "rhs"); + Ok(Filter::Logic(Logic { lhs, op, rhs })) +} + +fn parse_comparator(input: HashMap, op: EqualityCompOp) -> Result { + fn parse_field(input: HashMap, key: &str) -> Result { + match input.get(key) { + Some(s) => Ok(s.to_string()), + None => return Err(Error::InvalidFilter), + } + } + + let metadata_field = match input.keys { + Some(s) => Ok(s.to_string()), + None => return Err(Error::InvalidFilter), + }; + Ok(Filter::Comparator(Comparator { metadata_field, op: EqualityCompOp::Eq, comp_value })) +} + pub fn parse(input: HashMap) -> Result { if input.keys().len() != 1 { return Err(Error::InvalidFilter); } match input.keys().next().unwrap().as_str() { - "$and" => { - let lhs = parse(input.get("lhs")); - let rhs = parse(input.get("rhs")); - Ok() // Logic { lhs, op: LogicalCompOp::And, rhs } - }, - "$or" => { - let lhs = parse(input.get("lhs")); - let rhs = parse(input.get("rhs")); - let op = FilterOp::Logic(Logic { lhs, op: LogicalCompOp::Or, rhs }); - Filter { op } - }, - "$eq" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Eq); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Eq, val }) - }, - "$ne" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Ne); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Ne, val }) - }, - "$gt" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Gt); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Gt, val }) - }, - "$gte" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Gte); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Gte, val }) - }, - "$lt" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Lt); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Lt, val }) - }, - "$lte" => { - let op = get_equality_comp_op_fn(EqualityCompOp::Lte); - let val = input.get("val").unwrap().to_string(); - Filter::Comparator(Comparator { op: EqualityCompOp::Lte, val }) - }, + "$and" => parse_logic(input.get("$and").unwrap().to_string(), LogicalCompOp::And), + "$or" => parse_logic(input.get("$or").unwrap().to_string(), LogicalCompOp::Or), + "$eq" => parse_comparator(input.get("$eq").unwrap().to_string(), EqualityCompOp::Eq), + "$ne" => parse_comparator(input.get("$ne").unwrap().to_string(), EqualityCompOp::Ne), + "$gt" => parse_comparator(input.get("$gt").unwrap().to_string(), EqualityCompOp::Gt), + "$gte" => parse_comparator(input.get("$gte").unwrap().to_string(), EqualityCompOp::Gte), + "$lt" => parse_comparator(input.get("$lt").unwrap().to_string(), EqualityCompOp::Lt), + "$lte" => parse_comparator(input.get("$lte").unwrap().to_string(), EqualityCompOp::Lte), _ => Err(Error::InvalidFilter), } }