diff --git a/src/backend/mysql/query.rs b/src/backend/mysql/query.rs index 58126014..31e3d922 100644 --- a/src/backend/mysql/query.rs +++ b/src/backend/mysql/query.rs @@ -147,6 +147,11 @@ impl QueryBuilder for MysqlQueryBuilder { fn insert_default_keyword(&self) -> &str { "()" } + + /// Prefix of the ELSEIF (MySQL) + fn elseif_keyword_prefix(&self) -> &str { + "ELSE" + } } impl MysqlQueryBuilder { diff --git a/src/backend/postgres/query.rs b/src/backend/postgres/query.rs index f0650110..629ea510 100644 --- a/src/backend/postgres/query.rs +++ b/src/backend/postgres/query.rs @@ -174,6 +174,11 @@ impl QueryBuilder for PostgresQueryBuilder { fn if_null_function(&self) -> &str { "COALESCE" } + + /// Prefix of the ELSIF (Postgres) + fn elseif_keyword_prefix(&self) -> &str { + "ELS" + } } fn is_pg_comparison(b: &BinOper) -> bool { diff --git a/src/backend/query_builder.rs b/src/backend/query_builder.rs index 785971b4..6676493b 100644 --- a/src/backend/query_builder.rs +++ b/src/backend/query_builder.rs @@ -393,16 +393,28 @@ pub trait QueryBuilder: } } + /// Prefix of the ELSEIF (MySQL) vs ELSIF (Postgres) keyword + fn elseif_keyword_prefix(&self) -> &str { + panic!("ELSEIF/ELSIF keyword prefix not implemented for this backend"); + } + fn prepare_if_else_statement(&self, val: &Box, sql: &mut dyn SqlWriter) { write!(sql, "IF ").unwrap(); self.prepare_simple_expr(&val.when, sql); write!(sql, " THEN\n").unwrap(); self.prepare_simple_expr(&val.then, sql); - if let Some(otherwise) = &val.otherwise { - write!(sql, "\nELSE\n").unwrap(); - self.prepare_simple_expr(otherwise, sql); + match &val.otherwise { + Some(SimpleExpr::IfElse(value)) => { + write!(sql, "\n{}", self.elseif_keyword_prefix()).unwrap(); + self.prepare_if_else_statement(value, sql); + } + Some(otherwise) => { + write!(sql, "\nELSE\n").unwrap(); + self.prepare_simple_expr(otherwise, sql); + write!(sql, "\nEND IF").unwrap(); + } + None => write!(sql, "\nEND IF").unwrap(), }; - write!(sql, "\nEND IF").unwrap(); } /// Translate [`CaseStatement`] into SQL statement. diff --git a/src/backend/sqlite/query.rs b/src/backend/sqlite/query.rs index ece86a79..60f67ddd 100644 --- a/src/backend/sqlite/query.rs +++ b/src/backend/sqlite/query.rs @@ -84,4 +84,8 @@ impl QueryBuilder for SqliteQueryBuilder { // SQLite doesn't support inserting multiple rows with default values write!(sql, "DEFAULT VALUES").unwrap() } + + fn prepare_if_else_statement(&self, _val: &Box, _sql: &mut dyn SqlWriter) { + panic!("Sqlite doesn't support if-else statements") + } } diff --git a/tests/mysql/if_else.rs b/tests/mysql/if_else.rs index bd5e94b7..a5e64034 100644 --- a/tests/mysql/if_else.rs +++ b/tests/mysql/if_else.rs @@ -43,3 +43,57 @@ fn if_with_else() { .join("\n") ) } + +#[test] +fn if_with_elseif() { + let query = Query::select().column(Asterisk).from(Glyph::Table).take(); + let then = SimpleExpr::SubQuery(None, Box::new(query.into_sub_query_statement())); + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + then, + Some(SimpleExpr::IfElse(Box::new(IfElseStatement::new( + Expr::col(Glyph::Id).eq(2), + Expr::val("42").into(), + None, + )))), + ); + assert_eq!( + if_statement.to_string(MysqlQueryBuilder), + [ + "IF `id` = 1 THEN", + "(SELECT * FROM `glyph`)", + "ELSEIF `id` = 2 THEN", + "'42'", + "END IF" + ] + .join("\n") + ) +} + +#[test] +fn if_with_elseif_and_else() { + let query = Query::select().column(Asterisk).from(Glyph::Table).take(); + let then = SimpleExpr::SubQuery(None, Box::new(query.into_sub_query_statement())); + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + then, + Some(SimpleExpr::IfElse(Box::new(IfElseStatement::new( + Expr::col(Glyph::Id).eq(2), + Expr::val("42").into(), + Some(Expr::val("9000").into()), + )))), + ); + assert_eq!( + if_statement.to_string(MysqlQueryBuilder), + [ + "IF `id` = 1 THEN", + "(SELECT * FROM `glyph`)", + "ELSEIF `id` = 2 THEN", + "'42'", + "ELSE", + "'9000'", + "END IF" + ] + .join("\n") + ); +} diff --git a/tests/postgres/if_else.rs b/tests/postgres/if_else.rs new file mode 100644 index 00000000..9f9098bd --- /dev/null +++ b/tests/postgres/if_else.rs @@ -0,0 +1,70 @@ +use super::*; +use pretty_assertions::assert_eq; + +#[test] +#[rustfmt::skip] +fn if_without_else() { + let query = Query::select().column(Asterisk).from(Glyph::Table).take(); + let then = SimpleExpr::SubQuery(None, Box::new(query.into_sub_query_statement())); + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + then, + None + ); + assert_eq!( + if_statement.to_string(MysqlQueryBuilder), + [ + "IF `id` = 1 THEN", + "(SELECT * FROM `glyph`)", + "END IF" + ].join("\n") + ) +} + +#[test] +#[rustfmt::skip] +fn if_with_else() { + let query = Query::select().column(Asterisk).from(Glyph::Table).take(); + let then = SimpleExpr::SubQuery(None, Box::new(query.into_sub_query_statement())); + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + then, + Some(Expr::val("23").into()) + ); + assert_eq!( + if_statement.to_string(PostgresQueryBuilder), + [ + "IF \"id\" = 1 THEN", + "(SELECT * FROM \"glyph\")", + "ELSE", + "'23'", + "END IF" + ].join("\n") + ) +} + +#[test] +#[rustfmt::skip] +fn if_with_elseif() { + let query = Query::select().column(Asterisk).from(Glyph::Table).take(); + let then = SimpleExpr::SubQuery(None, Box::new(query.into_sub_query_statement())); + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + then, + Some(SimpleExpr::IfElse(Box::new(IfElseStatement::new( + Expr::col(Glyph::Id).eq(2), + Expr::val("123").into(), + None + )))) + ); + assert_eq!( + if_statement.to_string(PostgresQueryBuilder), + [ + "IF \"id\" = 1 THEN", + "(SELECT * FROM \"glyph\")", + "ELSIF \"id\" = 2 THEN", + "'123'", + "END IF" + ].join("\n") + ) +} diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index 82b85df3..76390ae4 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -1,6 +1,7 @@ use sea_query::{tests_cfg::*, *}; mod foreign_key; +mod if_else; mod index; mod query; mod table; diff --git a/tests/sqlite/mod.rs b/tests/sqlite/mod.rs index fc7388cd..2c90d514 100644 --- a/tests/sqlite/mod.rs +++ b/tests/sqlite/mod.rs @@ -4,6 +4,7 @@ mod foreign_key; mod index; mod query; mod table; +mod unsupported; #[path = "../common.rs"] mod common; diff --git a/tests/sqlite/unsupported.rs b/tests/sqlite/unsupported.rs new file mode 100644 index 00000000..5bf704b9 --- /dev/null +++ b/tests/sqlite/unsupported.rs @@ -0,0 +1,13 @@ +use super::*; + +#[test] +#[should_panic] +#[rustfmt::skip] +fn if_else_statement_is_unsupported() { + let if_statement = IfElseStatement::new( + Expr::col(Glyph::Id).eq(1), + Expr::val("23").into(), + None + ); + if_statement.to_string(SqliteQueryBuilder); +}