Skip to content

Commit

Permalink
Support adding expressions for ON CONFLICT targets (#692)
Browse files Browse the repository at this point in the history
* Support adding expressions for  targets

* Fix formatting

* Format comment; Make the fmt check happy
  • Loading branch information
sumeetattree authored Dec 14, 2023
1 parent b143d21 commit c21a8e1
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/backend/mysql/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl QueryBuilder for MysqlQueryBuilder {
sql.push_param(value.clone(), self as _);
}

fn prepare_on_conflict_target(&self, _: &Option<OnConflictTarget>, _: &mut dyn SqlWriter) {
fn prepare_on_conflict_target(&self, _: &[OnConflictTarget], _: &mut dyn SqlWriter) {
// MySQL doesn't support declaring ON CONFLICT target.
}

Expand Down
34 changes: 20 additions & 14 deletions src/backend/query_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,7 +1137,7 @@ pub trait QueryBuilder:
fn prepare_on_conflict(&self, on_conflict: &Option<OnConflict>, sql: &mut dyn SqlWriter) {
if let Some(on_conflict) = on_conflict {
self.prepare_on_conflict_keywords(sql);
self.prepare_on_conflict_target(&on_conflict.target, sql);
self.prepare_on_conflict_target(&on_conflict.targets, sql);
self.prepare_on_conflict_condition(&on_conflict.target_where, sql);
self.prepare_on_conflict_action(&on_conflict.action, sql);
self.prepare_on_conflict_condition(&on_conflict.action_where, sql);
Expand All @@ -1148,24 +1148,30 @@ pub trait QueryBuilder:
/// Write ON CONFLICT target
fn prepare_on_conflict_target(
&self,
on_conflict_target: &Option<OnConflictTarget>,
on_conflict_targets: &[OnConflictTarget],
sql: &mut dyn SqlWriter,
) {
if let Some(target) = on_conflict_target {
if on_conflict_targets.is_empty() {
return;
}

write!(sql, "(").unwrap();
on_conflict_targets.iter().fold(true, |first, target| {
if !first {
write!(sql, ", ").unwrap()
}
match target {
OnConflictTarget::ConflictColumns(columns) => {
write!(sql, "(").unwrap();
columns.iter().fold(true, |first, col| {
if !first {
write!(sql, ", ").unwrap()
}
col.prepare(sql.as_writer(), self.quote());
false
});
write!(sql, ")").unwrap();
OnConflictTarget::ConflictColumn(col) => {
col.prepare(sql.as_writer(), self.quote());
}

OnConflictTarget::ConflictExpr(expr) => {
self.prepare_simple_expr(expr, sql);
}
}
}
false
});
write!(sql, ")").unwrap();
}

#[doc(hidden)]
Expand Down
85 changes: 79 additions & 6 deletions src/query/on_conflict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{ConditionHolder, DynIden, IntoCondition, IntoIden, SimpleExpr};

#[derive(Debug, Clone, Default, PartialEq)]
pub struct OnConflict {
pub(crate) target: Option<OnConflictTarget>,
pub(crate) targets: Vec<OnConflictTarget>,
pub(crate) target_where: ConditionHolder,
pub(crate) action: Option<OnConflictAction>,
pub(crate) action_where: ConditionHolder,
Expand All @@ -11,8 +11,10 @@ pub struct OnConflict {
/// Represents ON CONFLICT (upsert) targets
#[derive(Debug, Clone, PartialEq)]
pub enum OnConflictTarget {
/// A list of columns with unique constraint
ConflictColumns(Vec<DynIden>),
/// A column
ConflictColumn(DynIden),
/// An expression `(LOWER(column), ...)`
ConflictExpr(SimpleExpr),
}

/// Represents ON CONFLICT (upsert) actions
Expand Down Expand Up @@ -55,15 +57,86 @@ impl OnConflict {
I: IntoIterator<Item = C>,
{
Self {
target: Some(OnConflictTarget::ConflictColumns(
columns.into_iter().map(IntoIden::into_iden).collect(),
)),
targets: columns
.into_iter()
.map(|c| OnConflictTarget::ConflictColumn(c.into_iden()))
.collect(),
target_where: ConditionHolder::new(),
action: None,
action_where: ConditionHolder::new(),
}
}

/// Set ON CONFLICT target expression
///
/// # Examples
///
/// ```
/// use sea_query::{tests_cfg::*, *};
///
/// let query = Query::insert()
/// .into_table(Glyph::Table)
/// .columns([Glyph::Aspect, Glyph::Image])
/// .values_panic(["abcd".into(), 3.1415.into()])
/// .on_conflict(
/// OnConflict::new()
/// .expr(Expr::col(Glyph::Id))
/// .update_column(Glyph::Aspect)
/// .value(Glyph::Image, Expr::val(1).add(2))
/// .to_owned(),
/// )
/// .to_owned();
///
/// assert_eq!(
/// query.to_string(MysqlQueryBuilder),
/// [
/// r#"INSERT INTO `glyph` (`aspect`, `image`)"#,
/// r#"VALUES ('abcd', 3.1415)"#,
/// r#"ON DUPLICATE KEY UPDATE `aspect` = VALUES(`aspect`), `image` = 1 + 2"#,
/// ]
/// .join(" ")
/// );
/// assert_eq!(
/// query.to_string(PostgresQueryBuilder),
/// [
/// r#"INSERT INTO "glyph" ("aspect", "image")"#,
/// r#"VALUES ('abcd', 3.1415)"#,
/// r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect", "image" = 1 + 2"#,
/// ]
/// .join(" ")
/// );
/// assert_eq!(
/// query.to_string(SqliteQueryBuilder),
/// [
/// r#"INSERT INTO "glyph" ("aspect", "image")"#,
/// r#"VALUES ('abcd', 3.1415)"#,
/// r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect", "image" = 1 + 2"#,
/// ]
/// .join(" ")
/// );
/// ```
pub fn expr<T>(&mut self, expr: T) -> &mut Self
where
T: Into<SimpleExpr>,
{
Self::exprs(self, [expr])
}

/// Set multiple target expressions for ON CONFLICT. See [`OnConflict::expr`]
pub fn exprs<I, T>(&mut self, exprs: I) -> &mut Self
where
T: Into<SimpleExpr>,
I: IntoIterator<Item = T>,
{
self.targets.append(
&mut exprs
.into_iter()
.map(|e: T| OnConflictTarget::ConflictExpr(e.into()))
.collect(),
);
self
}

pub fn do_nothing(&mut self) -> &mut Self {
self.action = Some(OnConflictAction::DoNothing);
self
Expand Down
81 changes: 81 additions & 0 deletions tests/postgres/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,87 @@ fn insert_on_conflict_6() {
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_7() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::new()
.expr(Expr::col(Glyph::Id))
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_8() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::new()
.exprs([Expr::col(Glyph::Id), Expr::col(Glyph::Aspect)])
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id", "aspect") DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_9() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::column(Glyph::Id)
.expr(Func::lower(Expr::col(Glyph::Tokens)))
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id", LOWER("tokens")) DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_returning_all_columns() {
Expand Down
81 changes: 81 additions & 0 deletions tests/sqlite/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,87 @@ fn insert_on_conflict_6() {
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_7() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::new()
.expr(Expr::col(Glyph::Id))
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id") DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_8() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::new()
.exprs([Expr::col(Glyph::Id), Expr::col(Glyph::Aspect)])
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id", "aspect") DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_on_conflict_9() {
assert_eq!(
Query::insert()
.into_table(Glyph::Table)
.columns([Glyph::Aspect, Glyph::Image])
.values_panic([
"04108048005887010020060000204E0180400400".into(),
3.1415.into(),
])
.on_conflict(
OnConflict::column(Glyph::Id)
.expr(Func::lower(Expr::col(Glyph::Tokens)))
.update_column(Glyph::Aspect)
.to_owned()
)
.to_string(PostgresQueryBuilder),
[
r#"INSERT INTO "glyph" ("aspect", "image")"#,
r#"VALUES ('04108048005887010020060000204E0180400400', 3.1415)"#,
r#"ON CONFLICT ("id", LOWER("tokens")) DO UPDATE SET "aspect" = "excluded"."aspect""#,
]
.join(" ")
);
}

#[test]
#[allow(clippy::approx_constant)]
fn insert_returning_all_columns() {
Expand Down

0 comments on commit c21a8e1

Please sign in to comment.