From c21a8e1ca56940a8fba6ec0998826460e6ea5263 Mon Sep 17 00:00:00 2001 From: Sumeet Attree Date: Thu, 14 Dec 2023 18:07:00 +0530 Subject: [PATCH] Support adding expressions for ON CONFLICT targets (#692) * Support adding expressions for targets * Fix formatting * Format comment; Make the fmt check happy --- src/backend/mysql/query.rs | 2 +- src/backend/query_builder.rs | 34 +++++++++------ src/query/on_conflict.rs | 85 +++++++++++++++++++++++++++++++++--- tests/postgres/query.rs | 81 ++++++++++++++++++++++++++++++++++ tests/sqlite/query.rs | 81 ++++++++++++++++++++++++++++++++++ 5 files changed, 262 insertions(+), 21 deletions(-) diff --git a/src/backend/mysql/query.rs b/src/backend/mysql/query.rs index 75e71bd92..975ae06ff 100644 --- a/src/backend/mysql/query.rs +++ b/src/backend/mysql/query.rs @@ -104,7 +104,7 @@ impl QueryBuilder for MysqlQueryBuilder { sql.push_param(value.clone(), self as _); } - fn prepare_on_conflict_target(&self, _: &Option, _: &mut dyn SqlWriter) { + fn prepare_on_conflict_target(&self, _: &[OnConflictTarget], _: &mut dyn SqlWriter) { // MySQL doesn't support declaring ON CONFLICT target. } diff --git a/src/backend/query_builder.rs b/src/backend/query_builder.rs index a2a8caabe..fe8ddfbc0 100644 --- a/src/backend/query_builder.rs +++ b/src/backend/query_builder.rs @@ -1137,7 +1137,7 @@ pub trait QueryBuilder: fn prepare_on_conflict(&self, on_conflict: &Option, sql: &mut dyn SqlWriter) { if let Some(on_conflict) = on_conflict { self.prepare_on_conflict_keywords(sql); - self.prepare_on_conflict_target(&on_conflict.target, sql); + self.prepare_on_conflict_target(&on_conflict.targets, sql); self.prepare_on_conflict_condition(&on_conflict.target_where, sql); self.prepare_on_conflict_action(&on_conflict.action, sql); self.prepare_on_conflict_condition(&on_conflict.action_where, sql); @@ -1148,24 +1148,30 @@ pub trait QueryBuilder: /// Write ON CONFLICT target fn prepare_on_conflict_target( &self, - on_conflict_target: &Option, + on_conflict_targets: &[OnConflictTarget], sql: &mut dyn SqlWriter, ) { - if let Some(target) = on_conflict_target { + if on_conflict_targets.is_empty() { + return; + } + + write!(sql, "(").unwrap(); + on_conflict_targets.iter().fold(true, |first, target| { + if !first { + write!(sql, ", ").unwrap() + } match target { - OnConflictTarget::ConflictColumns(columns) => { - write!(sql, "(").unwrap(); - columns.iter().fold(true, |first, col| { - if !first { - write!(sql, ", ").unwrap() - } - col.prepare(sql.as_writer(), self.quote()); - false - }); - write!(sql, ")").unwrap(); + OnConflictTarget::ConflictColumn(col) => { + col.prepare(sql.as_writer(), self.quote()); + } + + OnConflictTarget::ConflictExpr(expr) => { + self.prepare_simple_expr(expr, sql); } } - } + false + }); + write!(sql, ")").unwrap(); } #[doc(hidden)] diff --git a/src/query/on_conflict.rs b/src/query/on_conflict.rs index 280a8491f..47d45c0f6 100644 --- a/src/query/on_conflict.rs +++ b/src/query/on_conflict.rs @@ -2,7 +2,7 @@ use crate::{ConditionHolder, DynIden, IntoCondition, IntoIden, SimpleExpr}; #[derive(Debug, Clone, Default, PartialEq)] pub struct OnConflict { - pub(crate) target: Option, + pub(crate) targets: Vec, pub(crate) target_where: ConditionHolder, pub(crate) action: Option, pub(crate) action_where: ConditionHolder, @@ -11,8 +11,10 @@ pub struct OnConflict { /// Represents ON CONFLICT (upsert) targets #[derive(Debug, Clone, PartialEq)] pub enum OnConflictTarget { - /// A list of columns with unique constraint - ConflictColumns(Vec), + /// A column + ConflictColumn(DynIden), + /// An expression `(LOWER(column), ...)` + ConflictExpr(SimpleExpr), } /// Represents ON CONFLICT (upsert) actions @@ -55,15 +57,86 @@ impl OnConflict { I: IntoIterator, { Self { - target: Some(OnConflictTarget::ConflictColumns( - columns.into_iter().map(IntoIden::into_iden).collect(), - )), + targets: columns + .into_iter() + .map(|c| OnConflictTarget::ConflictColumn(c.into_iden())) + .collect(), target_where: ConditionHolder::new(), action: None, action_where: ConditionHolder::new(), } } + /// Set ON CONFLICT target expression + /// + /// # Examples + /// + /// ``` + /// use sea_query::{tests_cfg::*, *}; + /// + /// let query = Query::insert() + /// .into_table(Glyph::Table) + /// .columns([Glyph::Aspect, Glyph::Image]) + /// .values_panic(["abcd".into(), 3.1415.into()]) + /// .on_conflict( + /// OnConflict::new() + /// .expr(Expr::col(Glyph::Id)) + /// .update_column(Glyph::Aspect) + /// .value(Glyph::Image, Expr::val(1).add(2)) + /// .to_owned(), + /// ) + /// .to_owned(); + /// + /// assert_eq!( + /// query.to_string(MysqlQueryBuilder), + /// [ + /// r#"INSERT INTO `glyph` (`aspect`, `image`)"#, + /// r#"VALUES ('abcd', 3.1415)"#, + /// r#"ON DUPLICATE KEY UPDATE `aspect` = VALUES(`aspect`), `image` = 1 + 2"#, + /// ] + /// .join(" ") + /// ); + /// assert_eq!( + /// query.to_string(PostgresQueryBuilder), + /// [ + /// r#"INSERT INTO "glyph" ("aspect", "image")"#, + /// r#"VALUES ('abcd', 3.1415)"#, + /// r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect", "image" = 1 + 2"#, + /// ] + /// .join(" ") + /// ); + /// assert_eq!( + /// query.to_string(SqliteQueryBuilder), + /// [ + /// r#"INSERT INTO "glyph" ("aspect", "image")"#, + /// r#"VALUES ('abcd', 3.1415)"#, + /// r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect", "image" = 1 + 2"#, + /// ] + /// .join(" ") + /// ); + /// ``` + pub fn expr(&mut self, expr: T) -> &mut Self + where + T: Into, + { + Self::exprs(self, [expr]) + } + + /// Set multiple target expressions for ON CONFLICT. See [`OnConflict::expr`] + pub fn exprs(&mut self, exprs: I) -> &mut Self + where + T: Into, + I: IntoIterator, + { + self.targets.append( + &mut exprs + .into_iter() + .map(|e: T| OnConflictTarget::ConflictExpr(e.into())) + .collect(), + ); + self + } + pub fn do_nothing(&mut self) -> &mut Self { self.action = Some(OnConflictAction::DoNothing); self diff --git a/tests/postgres/query.rs b/tests/postgres/query.rs index 1a0d87425..5df9ed281 100644 --- a/tests/postgres/query.rs +++ b/tests/postgres/query.rs @@ -1436,6 +1436,87 @@ fn insert_on_conflict_6() { ); } +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_7() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::new() + .expr(Expr::col(Glyph::Id)) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_8() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::new() + .exprs([Expr::col(Glyph::Id), Expr::col(Glyph::Aspect)]) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id", "aspect") DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_9() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::column(Glyph::Id) + .expr(Func::lower(Expr::col(Glyph::Tokens))) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id", LOWER("tokens")) DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + #[test] #[allow(clippy::approx_constant)] fn insert_returning_all_columns() { diff --git a/tests/sqlite/query.rs b/tests/sqlite/query.rs index 7b1aa409b..132ff98e1 100644 --- a/tests/sqlite/query.rs +++ b/tests/sqlite/query.rs @@ -1390,6 +1390,87 @@ fn insert_on_conflict_6() { ); } +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_7() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::new() + .expr(Expr::col(Glyph::Id)) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_8() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::new() + .exprs([Expr::col(Glyph::Id), Expr::col(Glyph::Aspect)]) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id", "aspect") DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + +#[test] +#[allow(clippy::approx_constant)] +fn insert_on_conflict_9() { + assert_eq!( + Query::insert() + .into_table(Glyph::Table) + .columns([Glyph::Aspect, Glyph::Image]) + .values_panic([ + "04108048005887010020060000204E0180400400".into(), + 3.1415.into(), + ]) + .on_conflict( + OnConflict::column(Glyph::Id) + .expr(Func::lower(Expr::col(Glyph::Tokens))) + .update_column(Glyph::Aspect) + .to_owned() + ) + .to_string(PostgresQueryBuilder), + [ + r#"INSERT INTO "glyph" ("aspect", "image")"#, + r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#, + r#"ON CONFLICT ("id", LOWER("tokens")) DO UPDATE SET "aspect" = "excluded"."aspect""#, + ] + .join(" ") + ); +} + #[test] #[allow(clippy::approx_constant)] fn insert_returning_all_columns() {