Skip to content

Commit 7030a63

Browse files
alambxxchan
andauthoredDec 13, 2022
Add update_record_with_output function and tests (#130)
Co-authored-by: xxchan <xxchan22f@gmail.com>
1 parent 65b122f commit 7030a63

File tree

2 files changed

+322
-141
lines changed

2 files changed

+322
-141
lines changed
 

‎sqllogictest-bin/src/lib.rs

+9-130
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ use futures::StreamExt;
1515
use itertools::Itertools;
1616
use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
1717
use rand::seq::SliceRandom;
18-
use sqllogictest::{AsyncDB, Injected, Record, RecordOutput, Runner};
18+
use sqllogictest::{
19+
default_validator, update_record_with_output, AsyncDB, Injected, Record, Runner,
20+
};
1921

2022
#[derive(Copy, Clone, Debug, PartialEq, Eq, ArgEnum)]
2123
#[must_use]
@@ -656,137 +658,14 @@ async fn update_record<D: AsyncDB>(
656658
return Ok(());
657659
}
658660

659-
match (record.clone(), runner.apply_record(record).await) {
660-
(record, RecordOutput::Nothing) => {
661-
writeln!(outfile, "{record}")?;
662-
}
663-
(Record::Statement { sql, .. }, RecordOutput::Query { error: None, .. }) => {
664-
// statement ok
665-
// SELECT ...
666-
//
667-
// This case can be used when we want to only ensure the query succeeds,
668-
// but don't care about the output.
669-
// DuckDB has a few of these.
670-
671-
writeln!(outfile, "statement ok")?;
672-
writeln!(outfile, "{}", sql)?;
673-
writeln!(outfile)?;
674-
}
675-
(Record::Query { sql, .. }, RecordOutput::Statement { error: None, .. }) => {
676-
writeln!(outfile, "statement ok")?;
677-
writeln!(outfile, "{}", sql)?;
678-
writeln!(outfile)?;
661+
let record_output = runner.apply_record(record.clone()).await;
662+
match update_record_with_output(&record, &record_output, "\t", default_validator) {
663+
Some(new_record) => {
664+
writeln!(outfile, "{new_record}")?;
679665
}
680-
(
681-
Record::Statement {
682-
loc: _,
683-
conditions: _,
684-
expected_error,
685-
sql,
686-
expected_count,
687-
},
688-
RecordOutput::Statement { count, error },
689-
) => match (error, expected_error) {
690-
(None, _) => {
691-
if expected_count.is_some() {
692-
writeln!(outfile, "statement count {count}")?;
693-
writeln!(outfile, "{}", sql)?;
694-
} else {
695-
writeln!(outfile, "statement ok")?;
696-
writeln!(outfile, "{}", sql)?;
697-
}
698-
writeln!(outfile)?;
699-
}
700-
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
701-
if expected_error.as_str().is_empty() {
702-
writeln!(outfile, "statement error")?;
703-
} else {
704-
writeln!(outfile, "statement error {}", expected_error)?;
705-
}
706-
writeln!(outfile, "{}", sql)?;
707-
writeln!(outfile)?;
708-
}
709-
(Some(e), _) => {
710-
writeln!(outfile, "statement error {}", e)?;
711-
writeln!(outfile, "{}", sql)?;
712-
writeln!(outfile)?;
713-
}
714-
},
715-
(
716-
Record::Query {
717-
loc: _,
718-
conditions: _,
719-
type_string,
720-
sort_mode,
721-
label,
722-
expected_error,
723-
sql,
724-
expected_results,
725-
},
726-
RecordOutput::Query {
727-
types: _,
728-
rows,
729-
error,
730-
},
731-
) => {
732-
match (error, expected_error) {
733-
(None, _) => {}
734-
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
735-
writeln!(outfile, "query error {}", expected_error)?;
736-
writeln!(outfile, "{}", sql)?;
737-
writeln!(outfile)?;
738-
return Ok(());
739-
}
740-
(Some(e), _) => {
741-
writeln!(outfile, "query error {}", e)?;
742-
writeln!(outfile, "{}", sql)?;
743-
writeln!(outfile)?;
744-
return Ok(());
745-
}
746-
};
747-
748-
// FIXME: use output's types instead of orignal query's types
749-
write!(
750-
outfile,
751-
"query {}",
752-
type_string.iter().map(|c| format!("{c}")).join("")
753-
)?;
754-
if let Some(sort_mode) = sort_mode {
755-
write!(outfile, " {}", sort_mode.as_str())?;
756-
}
757-
if let Some(label) = label {
758-
write!(outfile, " {}", label)?;
759-
}
760-
writeln!(outfile)?;
761-
writeln!(outfile, "{}", sql)?;
762-
763-
#[allow(clippy::ptr_arg)]
764-
fn normalize_string(s: &String) -> String {
765-
s.trim().split_ascii_whitespace().join(" ")
766-
}
767-
768-
let normalized_rows = rows
769-
.iter()
770-
.map(|strs| strs.iter().map(normalize_string).join(" "))
771-
.collect_vec();
772-
773-
let normalized_expected = expected_results.iter().map(normalize_string).collect_vec();
774-
775-
writeln!(outfile, "----")?;
776-
777-
if normalized_expected == normalized_rows {
778-
// If the results are correct, do not format them.
779-
for result in expected_results {
780-
writeln!(outfile, "{}", result)?;
781-
}
782-
} else {
783-
for result in rows {
784-
writeln!(outfile, "{}", result.iter().format("\t"))?;
785-
}
786-
};
787-
writeln!(outfile)?;
666+
None => {
667+
writeln!(outfile, "{record}")?;
788668
}
789-
_ => unreachable!(),
790669
}
791670

792671
Ok(())

‎sqllogictest/src/runner.rs

+313-11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use futures::executor::block_on;
1212
use futures::{stream, Future, StreamExt};
1313
use itertools::Itertools;
1414
use owo_colors::OwoColorize;
15+
use regex::Regex;
1516
use tempfile::{tempdir, TempDir};
1617

1718
use crate::parser::*;
@@ -55,6 +56,7 @@ impl TryFrom<char> for ColumnType {
5556
}
5657
}
5758

59+
#[derive(Debug, Clone)]
5860
pub enum RecordOutput {
5961
Nothing,
6062
Query {
@@ -391,8 +393,18 @@ fn format_diff(
391393
///
392394
/// # Default
393395
///
394-
/// By default, we will use compare normalized results.
395-
pub type Validator = fn(&Vec<Vec<String>>, &Vec<String>) -> bool;
396+
/// By default ([`default_validator`]), we will use compare normalized results.
397+
pub type Validator = fn(actual: &[Vec<String>], expected: &[String]) -> bool;
398+
399+
pub fn default_validator(actual: &[Vec<String>], expected: &[String]) -> bool {
400+
let expected_results = expected.iter().map(normalize_string).collect_vec();
401+
// Default, we compare normalized results. Whitespace characters are ignored.
402+
let normalized_rows = actual
403+
.iter()
404+
.map(|strs| strs.iter().map(normalize_string).join(" "))
405+
.collect_vec();
406+
normalized_rows == expected_results
407+
}
396408

397409
/// Sqllogictest runner.
398410
pub struct Runner<D: AsyncDB> {
@@ -410,15 +422,7 @@ impl<D: AsyncDB> Runner<D> {
410422
pub fn new(db: D) -> Self {
411423
Runner {
412424
db,
413-
validator: |x, y| {
414-
let expected_results = y.iter().map(normalize_string).collect_vec();
415-
// Default, we compare normalized results. Whitespace characters are ignored.
416-
let normalized_rows = x
417-
.iter()
418-
.map(|strs| strs.iter().map(normalize_string).join(" "))
419-
.collect_vec();
420-
normalized_rows == expected_results
421-
},
425+
validator: default_validator,
422426
testdir: None,
423427
sort_mode: None,
424428
hash_threshold: 0,
@@ -859,3 +863,301 @@ impl<D: AsyncDB> Runner<D> {
859863
fn normalize_string(s: &String) -> String {
860864
s.trim().split_ascii_whitespace().join(" ")
861865
}
866+
867+
/// Updates the specified [`Record`] with the [`QueryOutput`] produced
868+
/// by a Database, returning `Some(new_record)`.
869+
///
870+
/// If an update is not supported, returns `None`
871+
pub fn update_record_with_output(
872+
record: &Record,
873+
record_output: &RecordOutput,
874+
col_separator: &str,
875+
validator: Validator,
876+
) -> Option<Record> {
877+
match (record.clone(), record_output) {
878+
(_, RecordOutput::Nothing) => None,
879+
// statement, query
880+
(
881+
Record::Statement {
882+
sql,
883+
loc,
884+
conditions,
885+
expected_error: None,
886+
expected_count,
887+
},
888+
RecordOutput::Query { error: None, .. },
889+
) => {
890+
// statement ok
891+
// SELECT ...
892+
//
893+
// This case can be used when we want to only ensure the query succeeds,
894+
// but don't care about the output.
895+
// DuckDB has a few of these.
896+
897+
Some(Record::Statement {
898+
sql,
899+
expected_error: None,
900+
loc,
901+
conditions,
902+
expected_count,
903+
})
904+
}
905+
// query, statement
906+
(
907+
Record::Query {
908+
sql,
909+
loc,
910+
conditions,
911+
..
912+
},
913+
RecordOutput::Statement { error: None, .. },
914+
) => Some(Record::Statement {
915+
sql,
916+
expected_error: None,
917+
loc,
918+
conditions,
919+
expected_count: None,
920+
}),
921+
// statement, statement
922+
(
923+
Record::Statement {
924+
loc,
925+
conditions,
926+
expected_error,
927+
sql,
928+
expected_count,
929+
},
930+
RecordOutput::Statement { count, error },
931+
) => match (error, expected_error) {
932+
// Ok
933+
(None, _) => Some(Record::Statement {
934+
sql,
935+
expected_error: None,
936+
loc,
937+
conditions,
938+
expected_count: expected_count.map(|_| *count),
939+
}),
940+
// Error match
941+
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
942+
Some(Record::Statement {
943+
sql,
944+
expected_error: Some(expected_error),
945+
loc,
946+
conditions,
947+
expected_count: None,
948+
})
949+
}
950+
// Error mismatch
951+
(Some(e), _) => Some(Record::Statement {
952+
sql,
953+
expected_error: Some(Regex::new(&e.to_string()).unwrap()),
954+
loc,
955+
conditions,
956+
expected_count: None,
957+
}),
958+
},
959+
// query, query
960+
(
961+
Record::Query {
962+
loc,
963+
conditions,
964+
type_string,
965+
sort_mode,
966+
label,
967+
expected_error,
968+
sql,
969+
expected_results,
970+
},
971+
RecordOutput::Query {
972+
// FIXME: maybe we should use output's types instead of orignal query's types
973+
// Fix it after https://github.com/risinglightdb/sqllogictest-rs/issues/36 is resolved.
974+
types: _,
975+
rows,
976+
error,
977+
},
978+
) => {
979+
match (error, expected_error) {
980+
(None, _) => {}
981+
// Error match
982+
(Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => {
983+
return Some(Record::Query {
984+
sql,
985+
expected_error: Some(expected_error),
986+
loc,
987+
conditions,
988+
type_string: vec![],
989+
sort_mode,
990+
label,
991+
expected_results: vec![],
992+
})
993+
}
994+
// Error mismatch
995+
(Some(e), _) => {
996+
return Some(Record::Query {
997+
sql,
998+
expected_error: Some(Regex::new(&e.to_string()).unwrap()),
999+
loc,
1000+
conditions,
1001+
type_string: vec![],
1002+
sort_mode,
1003+
label,
1004+
expected_results: vec![],
1005+
})
1006+
}
1007+
};
1008+
1009+
let results = if validator(rows, &expected_results) {
1010+
// If validation is successful, we respect the original file's expected results.
1011+
expected_results
1012+
} else {
1013+
rows.iter().map(|cols| cols.join(col_separator)).collect()
1014+
};
1015+
1016+
Some(Record::Query {
1017+
sql,
1018+
expected_error: None,
1019+
loc,
1020+
conditions,
1021+
type_string,
1022+
sort_mode,
1023+
label,
1024+
expected_results: results,
1025+
})
1026+
}
1027+
1028+
// No update possible, return the original record
1029+
_ => None,
1030+
}
1031+
}
1032+
1033+
#[cfg(test)]
1034+
mod tests {
1035+
use super::*;
1036+
1037+
#[test]
1038+
fn test_query_replacement() {
1039+
TestCase {
1040+
// input should be ignored
1041+
input: "query III\n\
1042+
select * from foo;\n\
1043+
----\n\
1044+
1 2",
1045+
1046+
// Model a run that produced a 3,4 as output
1047+
record_output: query_output(&[&["3", "4"]]),
1048+
1049+
expected: Some(
1050+
"query III\n\
1051+
select * from foo;\n\
1052+
----\n\
1053+
3 4",
1054+
),
1055+
}
1056+
.run()
1057+
}
1058+
1059+
#[test]
1060+
fn test_query_replacement_no_input() {
1061+
TestCase {
1062+
// input has no query results
1063+
input: "query III\n\
1064+
select * from foo;\n\
1065+
----",
1066+
1067+
// Model a run that produced a 3,4 as output
1068+
record_output: query_output(&[&["3", "4"]]),
1069+
1070+
expected: Some(
1071+
"query III\n\
1072+
select * from foo;\n\
1073+
----\n\
1074+
3 4",
1075+
),
1076+
}
1077+
.run()
1078+
}
1079+
1080+
#[test]
1081+
fn test_query_replacement_error() {
1082+
TestCase {
1083+
// input has no query results
1084+
input: "query III\n\
1085+
select * from foo;\n\
1086+
----",
1087+
1088+
// Model a run that produced a "MyAwesomeDB Error"
1089+
record_output: query_error("MyAwesomeDB Error"),
1090+
1091+
expected: Some(
1092+
"query error TestError: MyAwesomeDB Error\n\
1093+
select * from foo;\n",
1094+
),
1095+
}
1096+
.run()
1097+
}
1098+
1099+
#[derive(Debug)]
1100+
struct TestCase {
1101+
input: &'static str,
1102+
record_output: RecordOutput,
1103+
expected: Option<&'static str>,
1104+
}
1105+
1106+
impl TestCase {
1107+
fn run(self) {
1108+
let Self {
1109+
input,
1110+
record_output,
1111+
expected,
1112+
} = self;
1113+
println!("TestCase");
1114+
println!("**input:\n{input}\n");
1115+
println!("**record_output:\n{record_output:#?}\n");
1116+
println!("**expected:\n{}\n", expected.unwrap_or(""));
1117+
let input = parse_to_record(input);
1118+
let expected = expected.map(parse_to_record);
1119+
let output = update_record_with_output(&input, &record_output, " ", default_validator);
1120+
assert_eq!(output, expected);
1121+
}
1122+
}
1123+
1124+
fn parse_to_record(s: &str) -> Record {
1125+
let mut records = parse(s).unwrap();
1126+
assert_eq!(records.len(), 1);
1127+
records.pop().unwrap()
1128+
}
1129+
1130+
/// Returns a RecordOutput that models the successful execution of a query
1131+
fn query_output(rows: &[&[&str]]) -> RecordOutput {
1132+
let rows = rows
1133+
.iter()
1134+
.map(|cols| cols.iter().map(|c| c.to_string()).collect::<Vec<_>>())
1135+
.collect::<Vec<_>>();
1136+
1137+
let types = rows.iter().map(|_| ColumnType::Any).collect();
1138+
1139+
RecordOutput::Query {
1140+
types,
1141+
rows,
1142+
error: None,
1143+
}
1144+
}
1145+
1146+
/// Returns a RecordOutput that models the error of a query
1147+
fn query_error(error_message: &str) -> RecordOutput {
1148+
RecordOutput::Query {
1149+
types: vec![],
1150+
rows: vec![],
1151+
error: Some(Arc::new(TestError(error_message.to_string()))),
1152+
}
1153+
}
1154+
1155+
#[derive(Debug)]
1156+
struct TestError(String);
1157+
impl std::error::Error for TestError {}
1158+
impl std::fmt::Display for TestError {
1159+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1160+
write!(f, "TestError: {}", self.0)
1161+
}
1162+
}
1163+
}

0 commit comments

Comments
 (0)
Please sign in to comment.