From 03d7cadc91579e4e460eeb5a6364d44e268f9959 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 6 Dec 2022 20:30:47 +0100 Subject: [PATCH 1/7] feat: support formatting test files Signed-off-by: xxchan --- sqllogictest-bin/src/lib.rs | 242 +++++++++++++++++++++++++++++++----- sqllogictest/src/parser.rs | 182 ++++++++++++++++++++++++--- sqllogictest/src/runner.rs | 26 ++-- 3 files changed, 392 insertions(+), 58 deletions(-) diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index fdbc99b..7ccc029 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -1,6 +1,7 @@ mod engines; use std::collections::BTreeMap; +use std::fs::File; use std::io::{stdout, Write}; use std::path::{Path, PathBuf}; use std::time::{Duration, Instant}; @@ -14,7 +15,7 @@ use futures::StreamExt; use itertools::Itertools; use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite}; use rand::seq::SliceRandom; -use sqllogictest::{AsyncDB, Control, Record, Runner}; +use sqllogictest::{AsyncDB, Injected, Record, Runner}; #[derive(Copy, Clone, Debug, PartialEq, Eq, ArgEnum)] #[must_use] @@ -84,6 +85,13 @@ struct Opt { /// The database password. #[clap(short = 'w', long, default_value = "postgres")] pass: String, + + /// Overrides the test files with the actual output of the database. + #[clap(long)] + r#override: bool, + /// Reformats the test files. + #[clap(long)] + format: bool, } /// Connection configuration. @@ -123,6 +131,8 @@ pub async fn main_okk() -> Result<()> { db, user, pass, + r#override, + format, } = Opt::parse(); if host.len() != port.len() { @@ -171,6 +181,10 @@ pub async fn main_okk() -> Result<()> { pass, }; + if r#override || format { + return update_test_files(files, &engine, config, format).await; + } + let mut report = Report::new(junit.clone().unwrap_or_else(|| "sqllogictest".to_string())); report.set_timestamp(Local::now()); @@ -348,6 +362,28 @@ async fn run_serial( } } +/// * `format` - If true, will not run sqls, only formats the file. +async fn update_test_files( + files: Vec, + engine: &EngineConfig, + config: DBConfig, + format: bool, +) -> Result<()> { + for file in files { + let engine = engines::connect(engine, &config).await?; + let runner = Runner::new(engine); + + if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await { + { + println!("{}\n\n{:?}", style("[FAILED]").red().bold(), e); + println!(); + } + }; + } + + Ok(()) +} + async fn flush(out: &mut impl std::io::Write) -> std::io::Result<()> { tokio::task::block_in_place(|| out.flush()) } @@ -384,38 +420,9 @@ async fn run_test_file( begin_times.push(Instant::now()); - let finish = |out: &mut T, time_stack: &mut Vec, did_pop: &mut bool, file: &str| { - let begin_time = time_stack.pop().unwrap(); - - if *did_pop { - // start a new line if the result is not immediately after the item - write!( - out, - "\n{}{} {: <54} .. {} in {} ms", - "| ".repeat(time_stack.len()), - style("[END]").blue().bold(), - file, - style("[OK]").green().bold(), - begin_time.elapsed().as_millis() - )?; - } else { - // otherwise, append time to the previous line - write!( - out, - "{} in {} ms", - style("[OK]").green().bold(), - begin_time.elapsed().as_millis() - )?; - } - - *did_pop = true; - - Ok::<_, anyhow::Error>(()) - }; - for record in records { match &record { - Record::Control(Control::BeginInclude(file)) => { + Record::Injected(Injected::BeginInclude(file)) => { begin_times.push(Instant::now()); if !did_pop { writeln!(out, "{}", style("[BEGIN]").blue().bold())?; @@ -431,11 +438,12 @@ async fn run_test_file( )?; flush(out).await?; } - Record::Control(Control::EndInclude(file)) => { - finish(out, &mut begin_times, &mut did_pop, file)?; + Record::Injected(Injected::EndInclude(file)) => { + finish_test_file(out, &mut begin_times, &mut did_pop, file)?; } _ => {} } + runner .run_async(record) .await @@ -448,7 +456,7 @@ async fn run_test_file( let duration = begin_times[0].elapsed(); - finish( + finish_test_file( out, &mut begin_times, &mut did_pop, @@ -459,3 +467,169 @@ async fn run_test_file( Ok(duration) } + +fn finish_test_file( + out: &mut T, + time_stack: &mut Vec, + did_pop: &mut bool, + file: &str, +) -> Result<()> { + let begin_time = time_stack.pop().unwrap(); + + if *did_pop { + // start a new line if the result is not immediately after the item + write!( + out, + "\n{}{} {: <54} .. {} in {} ms", + "| ".repeat(time_stack.len()), + style("[END]").blue().bold(), + file, + style("[OK]").green().bold(), + begin_time.elapsed().as_millis() + )?; + } else { + // otherwise, append time to the previous line + write!( + out, + "{} in {} ms", + style("[OK]").green().bold(), + begin_time.elapsed().as_millis() + )?; + } + + *did_pop = true; + + Ok::<_, anyhow::Error>(()) +} + +async fn update_test_file( + out: &mut T, + _runner: Runner, + filename: impl AsRef, + format: bool, +) -> Result<()> { + let filename = filename.as_ref(); + let records = tokio::task::block_in_place(|| { + sqllogictest::parse_file(filename).map_err(|e| anyhow!("{:?}", e)) + }) + .context("failed to parse sqllogictest file")?; + + let mut begin_times = vec![]; + let mut did_pop = false; + + write!(out, "{: <60} .. ", filename.to_string_lossy())?; + flush(out).await?; + + begin_times.push(Instant::now()); + + fn create_outfile(filename: impl AsRef) -> std::io::Result<(PathBuf, File)> { + let filename = filename.as_ref(); + let outfilename = filename.file_name().unwrap().to_str().unwrap().to_owned() + ".temp"; + let outfilename = filename.parent().unwrap().join(&outfilename); + let outfile = File::create(&outfilename)?; + Ok((outfilename, outfile)) + } + + struct Item { + filename: String, + outfilename: PathBuf, + outfile: File, + first_record: bool, + } + let (outfilename, outfile) = create_outfile(filename)?; + let mut stack = vec![Item { + filename: filename.to_string_lossy().to_string(), + outfilename, + outfile, + first_record: true, + }]; + + for record in records { + let Item { + filename, + outfilename, + outfile, + first_record, + } = stack.last_mut().unwrap(); + + match &record { + Record::Injected(Injected::BeginInclude(filename)) => { + let (outfilename, outfile) = create_outfile(filename)?; + stack.push(Item { + filename: filename.clone(), + outfilename, + outfile, + first_record: true, + }); + + begin_times.push(Instant::now()); + if !did_pop { + writeln!(out, "{}", style("[BEGIN]").blue().bold())?; + } else { + writeln!(out)?; + } + did_pop = false; + write!( + out, + "{}{: <60} .. ", + "| ".repeat(begin_times.len() - 1), + filename + )?; + flush(out).await?; + } + Record::Injected(Injected::EndInclude(file)) => { + // override the original file with the updated one + std::fs::rename(outfilename, filename)?; + stack.pop(); + finish_test_file(out, &mut begin_times, &mut did_pop, file)?; + } + _ => { + if !*first_record { + writeln!(outfile)?; + } + + update_record(outfile, record, format) + .context(format!("failed to run `{}`", style(filename).bold()))?; + *first_record = false; + } + } + } + + finish_test_file( + out, + &mut begin_times, + &mut did_pop, + &filename.to_string_lossy(), + )?; + + writeln!(out)?; + + // override the original file with the updated one + let Item { + filename, + outfilename, + outfile: _, + first_record: _, + } = stack.last().unwrap(); + std::fs::rename(outfilename, filename)?; + + Ok(()) +} + +fn update_record(outfile: &mut File, record: Record, format: bool) -> Result<()> { + if format { + todo!() + } + + match record { + Record::Injected(_) => { + unreachable!() + } + _ => { + record.unparse(outfile)?; + writeln!(outfile)?; + } + } + + Ok(()) +} diff --git a/sqllogictest/src/parser.rs b/sqllogictest/src/parser.rs index 980f33b..1b528e7 100644 --- a/sqllogictest/src/parser.rs +++ b/sqllogictest/src/parser.rs @@ -70,7 +70,10 @@ impl Location { #[non_exhaustive] pub enum Record { /// An include copies all records from another files. - Include { loc: Location, filename: String }, + Include { + loc: Location, + filename: String, + }, /// A statement is an SQL command that is to be evaluated but from which we do not expect to /// get results (other than success or failure). Statement { @@ -101,12 +104,20 @@ pub enum Record { expected_results: Vec, }, /// A sleep period. - Sleep { loc: Location, duration: Duration }, + Sleep { + loc: Location, + duration: Duration, + }, /// Subtest. - Subtest { loc: Location, name: String }, + Subtest { + loc: Location, + name: String, + }, /// A halt record merely causes sqllogictest to ignore the rest of the test script. /// For debugging use only. - Halt { loc: Location }, + Halt { + loc: Location, + }, /// Control statements. Control(Control), /// Set the maximum number of result values that will be accepted @@ -115,13 +126,129 @@ pub enum Record { /// is the only result. /// /// If the threshold is 0, then hashing is never used. - HashThreshold { loc: Location, threshold: u64 }, + HashThreshold { + loc: Location, + threshold: u64, + }, + Condition(Condition), + Comment(Vec), + /// Internally injected record which should not occur in the test file. + Injected(Injected), +} + +impl Record { + /// Unparses the record to its string representation in the test file. + /// + /// # Panics + /// If the record is an internally injected record which should not occur in the test file. + pub fn unparse(&self, w: &mut impl std::io::Write) -> std::io::Result<()> { + match self { + Record::Include { loc: _, filename } => { + write!(w, "include {}", filename) + } + Record::Statement { + loc: _, + conditions: _, + expected_error, + sql, + expected_count, + } => { + write!(w, "statement ")?; + match (expected_count, expected_error) { + (None, None) => write!(w, "ok")?, + (None, Some(err)) => { + if err.as_str().is_empty() { + write!(w, "error")?; + } else { + write!(w, "error {}", err)?; + } + } + (Some(cnt), None) => write!(w, "count {}", cnt)?, + (Some(_), Some(_)) => unreachable!(), + } + writeln!(w)?; + write!(w, "{}", sql) + } + Record::Query { + loc: _, + conditions: _, + type_string, + sort_mode, + label, + expected_error, + sql, + expected_results, + } => { + write!(w, "query",)?; + if let Some(err) = expected_error { + writeln!(w, " error {}", err)?; + return write!(w, "{}", sql); + } + + write!( + w, + " {}", + type_string.iter().map(|c| format!("{c}")).join("") + )?; + if let Some(sort_mode) = sort_mode { + write!(w, " {}", sort_mode.as_str())?; + } + if let Some(label) = label { + write!(w, " {}", label)?; + } + writeln!(w)?; + writeln!(w, "{}", sql)?; + + write!(w, "----")?; + for result in expected_results { + write!(w, "\n{}", result)?; + } + Ok(()) + } + Record::Sleep { loc: _, duration } => { + write!(w, "sleep {}", humantime::format_duration(*duration)) + } + Record::Subtest { loc: _, name } => { + write!(w, "subtest {}", name) + } + Record::Halt { loc: _ } => { + write!(w, "halt") + } + Record::Control(c) => match c { + Control::SortMode(m) => write!(w, "control sortmode {}", m.as_str()), + }, + Record::Condition(cond) => match cond { + Condition::OnlyIf { engine_name } => { + write!(w, "onlyif {}", engine_name) + } + Condition::SkipIf { engine_name } => { + write!(w, "skipif {}", engine_name) + } + }, + Record::HashThreshold { loc: _, threshold } => { + write!(w, "hash-threshold {}", threshold) + } + Record::Comment(comment) => { + let mut iter = comment.iter(); + write!(w, "#{}", iter.next().unwrap().trim_end())?; + for line in iter { + write!(w, "\n#{}", line.trim_end())?; + } + Ok(()) + } + Record::Injected(p) => panic!("unexpected injected record: {:?}", p), + } + } } #[derive(Debug, PartialEq, Eq, Clone)] pub enum Control { /// Control sort mode. SortMode(SortMode), +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Injected { /// Pseudo control command to indicate the begin of an include statement. Automatically /// injected by sqllogictest parser. BeginInclude(String), @@ -246,12 +373,30 @@ fn parse_inner(loc: &Location, script: &str) -> Result, ParseError> let mut lines = script.split('\n').enumerate(); let mut records = vec![]; let mut conditions = vec![]; - while let Some((num, line)) = lines.next() { - if line.is_empty() || line.starts_with('#') { + + while let Some((mut num, mut line)) = lines.next() { + if let Some(text) = line.strip_prefix('#') { + let mut comments = vec![text.to_string()]; + for (num_, line_) in lines.by_ref() { + num = num_; + line = line_; + if let Some(text) = line.strip_prefix('#') { + comments.push(text.to_string()); + } else { + break; + } + } + + records.push(Record::Comment(comments)); + } + + if line.is_empty() { continue; } + let mut loc = loc.clone(); loc.line = num as u32 + 1; + let tokens: Vec<&str> = line.split_whitespace().collect(); match tokens.as_slice() { [] => continue, @@ -261,7 +406,6 @@ fn parse_inner(loc: &Location, script: &str) -> Result, ParseError> }), ["halt"] => { records.push(Record::Halt { loc }); - break; } ["subtest", name] => { records.push(Record::Subtest { @@ -278,14 +422,18 @@ fn parse_inner(loc: &Location, script: &str) -> Result, ParseError> }); } ["skipif", engine_name] => { - conditions.push(Condition::SkipIf { + let cond = Condition::SkipIf { engine_name: engine_name.to_string(), - }); + }; + conditions.push(cond.clone()); + records.push(Record::Condition(cond)); } ["onlyif", engine_name] => { - conditions.push(Condition::OnlyIf { + let cond = Condition::OnlyIf { engine_name: engine_name.to_string(), - }); + }; + conditions.push(cond.clone()); + records.push(Record::Condition(cond)); } ["statement", res @ ..] => { let mut expected_count = None; @@ -412,7 +560,7 @@ fn parse_inner(loc: &Location, script: &str) -> Result, ParseError> Ok(records) } -/// Parse a sqllogictest file and link all included scripts together. +/// Parse a sqllogictest file. The included scripts are inserted after the `include` record. pub fn parse_file(filename: impl AsRef) -> Result, ParseError> { let filename = filename.as_ref().to_str().unwrap(); parse_file_inner(Location::new(filename, 0)) @@ -426,6 +574,8 @@ fn parse_file_inner(loc: Location) -> Result, ParseError> { let script = std::fs::read_to_string(path).unwrap(); let mut records = vec![]; for rec in parse_inner(&loc, &script)? { + records.push(rec.clone()); + if let Record::Include { filename, loc } = rec { let complete_filename = { let mut path_buf = path.to_path_buf(); @@ -440,14 +590,12 @@ fn parse_file_inner(loc: Location) -> Result, ParseError> { { let included_file = included_file.as_os_str().to_string_lossy().to_string(); - records.push(Record::Control(Control::BeginInclude( + records.push(Record::Injected(Injected::BeginInclude( included_file.clone(), ))); records.extend(parse_file_inner(loc.include(&included_file))?); - records.push(Record::Control(Control::EndInclude(included_file))); + records.push(Record::Injected(Injected::EndInclude(included_file))); } - } else { - records.push(rec); } } Ok(records) diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 2934de7..c820318 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -26,6 +26,17 @@ pub enum ColumnType { Any, } +impl Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ColumnType::Text => write!(f, "T"), + ColumnType::Integer => write!(f, "I"), + ColumnType::FloatingPoint => write!(f, "R"), + ColumnType::Any => write!(f, "?"), + } + } +} + impl TryFrom for ColumnType { type Error = ParseErrorKind; @@ -33,7 +44,8 @@ impl TryFrom for ColumnType { match c { 'T' => Ok(Self::Text), 'I' => Ok(Self::Integer), - 'F' => Ok(Self::FloatingPoint), + 'R' => Ok(Self::FloatingPoint), + '?' => Ok(Self::Any), // FIXME: // _ => Err(ParseErrorKind::InvalidType(c)), _ => Ok(Self::Any), @@ -606,18 +618,18 @@ impl Runner { } } Record::Sleep { duration, .. } => D::sleep(duration).await, - Record::Halt { .. } => {} - Record::Subtest { .. } => {} - Record::Include { loc, .. } => { - unreachable!("include should be rewritten during link: at {}", loc) - } + Record::Include { .. } => {} Record::Control(control) => match control { Control::SortMode(sort_mode) => { self.sort_mode = Some(sort_mode); } - Control::BeginInclude(_) | Control::EndInclude(_) => {} }, Record::HashThreshold { loc: _, threshold } => self.hash_threshold = threshold as usize, + Record::Comment(_) + | Record::Subtest { .. } + | Record::Halt { .. } + | Record::Injected(_) + | Record::Condition(_) => {} } Ok(()) } From 97c6f2fe619f46ec258238752320239c98d37d73 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 6 Dec 2022 22:52:07 +0100 Subject: [PATCH 2/7] feat: support generate test outputs Signed-off-by: xxchan --- sqllogictest-bin/src/lib.rs | 124 ++++++++++++-- sqllogictest/src/parser.rs | 4 +- sqllogictest/src/runner.rs | 313 ++++++++++++++++++++++-------------- 3 files changed, 306 insertions(+), 135 deletions(-) diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index 7ccc029..9cd89c7 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -15,7 +15,7 @@ use futures::StreamExt; use itertools::Itertools; use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite}; use rand::seq::SliceRandom; -use sqllogictest::{AsyncDB, Injected, Record, Runner}; +use sqllogictest::{AsyncDB, Injected, Record, RecordOutput, Runner}; #[derive(Copy, Clone, Debug, PartialEq, Eq, ArgEnum)] #[must_use] @@ -504,7 +504,7 @@ fn finish_test_file( async fn update_test_file( out: &mut T, - _runner: Runner, + mut runner: Runner, filename: impl AsRef, format: bool, ) -> Result<()> { @@ -588,7 +588,8 @@ async fn update_test_file( writeln!(outfile)?; } - update_record(outfile, record, format) + update_record(outfile, &mut runner, record, format) + .await .context(format!("failed to run `{}`", style(filename).bold()))?; *first_record = false; } @@ -616,19 +617,122 @@ async fn update_test_file( Ok(()) } -fn update_record(outfile: &mut File, record: Record, format: bool) -> Result<()> { +async fn update_record( + outfile: &mut File, + runner: &mut Runner, + record: Record, + format: bool, +) -> Result<()> { + assert!(!matches!(record, Record::Injected(_))); + if format { - todo!() + record.unparse(outfile)?; + writeln!(outfile)?; + return Ok(()); } - match record { - Record::Injected(_) => { - unreachable!() - } - _ => { + match (record.clone(), runner.apply_record(record).await) { + (record, RecordOutput::Nothing) => { record.unparse(outfile)?; writeln!(outfile)?; } + ( + Record::Statement { + loc: _, + conditions: _, + expected_error, + sql, + expected_count, + }, + RecordOutput::Statement { count, error }, + ) => match (error, expected_error) { + (None, _) => { + if expected_count.is_some() { + writeln!(outfile, "statement count {count}")?; + writeln!(outfile, "{}", sql)?; + } else { + writeln!(outfile, "statement ok")?; + writeln!(outfile, "{}", sql)?; + } + } + (Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => { + writeln!(outfile, "statement error {}", expected_error)?; + writeln!(outfile, "{}", sql)?; + } + (Some(e), _) => { + writeln!(outfile, "statement error {}", e)?; + writeln!(outfile, "{}", sql)?; + } + }, + ( + Record::Query { + loc: _, + conditions: _, + type_string, + sort_mode, + label, + expected_error, + sql, + expected_results, + }, + RecordOutput::Query { + types: _, + rows, + error, + }, + ) => { + match (error, expected_error) { + (None, _) => {} + (Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => { + writeln!(outfile, "query error {}", expected_error)?; + writeln!(outfile, "{}", sql)?; + } + (Some(e), _) => { + writeln!(outfile, "query error {}", e)?; + writeln!(outfile, "{}", sql)?; + } + }; + + write!( + outfile, + "query {}", + type_string.iter().map(|c| format!("{c}")).join("") + )?; + if let Some(sort_mode) = sort_mode { + write!(outfile, " {}", sort_mode.as_str())?; + } + if let Some(label) = label { + write!(outfile, " {}", label)?; + } + writeln!(outfile)?; + writeln!(outfile, "{}", sql)?; + + #[allow(clippy::ptr_arg)] + fn normalize_string(s: &String) -> String { + s.trim().split_ascii_whitespace().join(" ") + } + + let normalized_rows = rows + .into_iter() + .map(|strs| strs.iter().map(normalize_string).join(" ")) + .collect_vec(); + + let normalized_expected = expected_results.iter().map(normalize_string).collect_vec(); + + writeln!(outfile, "----")?; + + if normalized_expected == normalized_rows { + // If the results are correct, do not format them. + for result in expected_results { + writeln!(outfile, "{}", result)?; + } + } else { + for result in normalized_rows { + writeln!(outfile, "{}", result)?; + } + }; + } + _ => unreachable!(), } Ok(()) diff --git a/sqllogictest/src/parser.rs b/sqllogictest/src/parser.rs index 1b528e7..6a25f26 100644 --- a/sqllogictest/src/parser.rs +++ b/sqllogictest/src/parser.rs @@ -179,7 +179,7 @@ impl Record { sql, expected_results, } => { - write!(w, "query",)?; + write!(w, "query")?; if let Some(err) = expected_error { writeln!(w, " error {}", err)?; return write!(w, "{}", sql); @@ -608,6 +608,6 @@ mod tests { #[test] fn test_include_glob() { let records = parse_file("../examples/include/include_1.slt").unwrap(); - assert_eq!(12, records.len()); + assert_eq!(14, records.len()); } } diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index c820318..44c3fc7 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -53,6 +53,19 @@ impl TryFrom for ColumnType { } } +pub enum RecordOutput { + Nothing, + Query { + types: Vec, + rows: Vec>, + error: Option>, + }, + Statement { + count: u64, + error: Option>, + }, +} + #[non_exhaustive] pub enum DBOutput { Rows { @@ -426,102 +439,206 @@ impl Runner { self.validator = validator; } - /// Run a single record. - pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { - tracing::info!(?record, "testing"); + pub async fn apply_record(&mut self, record: Record) -> RecordOutput { match record { - Record::Statement { conditions, .. } if self.should_skip(&conditions) => {} + Record::Statement { conditions, .. } if self.should_skip(&conditions) => { + RecordOutput::Nothing + } Record::Statement { conditions: _, - - expected_error, sql, - loc, - expected_count, + + // compare result in run_async + expected_error: _, + expected_count: _, + loc: _, } => { let sql = self.replace_keywords(sql); let ret = self.db.run(&sql).await; - match (ret, expected_error) { - (Ok(_), Some(_)) => { - return Err(TestErrorKind::Ok { - sql, - kind: RecordKind::Statement, + let ret = match ret { + Ok(out) => match out { + DBOutput::Rows { .. } => panic!("DB should not return rows for statement"), + DBOutput::StatementComplete(count) => { + RecordOutput::Statement { count, error: None } + } + }, + Err(e) => RecordOutput::Statement { + count: 0, + error: Some(Arc::new(e)), + }, + }; + if let Some(hook) = &mut self.hook { + hook.on_stmt_complete(&sql).await; + } + ret + } + Record::Query { conditions, .. } if self.should_skip(&conditions) => { + RecordOutput::Nothing + } + Record::Query { + conditions: _, + sql, + sort_mode, + + // compare result in run_async + type_string: _, + expected_error: _, + expected_results: _, + loc: _, + + // not handle yet, + label: _, + } => { + let sql = self.replace_keywords(sql); + let (types, mut rows) = match self.db.run(&sql).await { + Ok(out) => match out { + DBOutput::Rows { types, rows } => (types, rows), + DBOutput::StatementComplete(_) => panic!("DB should return rows for query"), + }, + Err(e) => { + return RecordOutput::Query { + error: Some(Arc::new(e)), + types: vec![], + rows: vec![], } - .at(loc)) } - (Ok(result), None) => { - if let Some(expected_count) = expected_count { - let count = match result { - DBOutput::Rows { types: _, rows } => { - return Err(TestErrorKind::StatementResultMismatch { - sql, - expected: expected_count, - actual: format!("got rows {:?}", rows), - } - .at(loc)); - } - DBOutput::StatementComplete(count) => count, - }; - - if expected_count != count { - return Err(TestErrorKind::StatementResultMismatch { - sql, - expected: expected_count, - actual: format!("affected {count} rows"), - } - .at(loc)); - } + }; + + match sort_mode.as_ref().or(self.sort_mode.as_ref()) { + None | Some(SortMode::NoSort) => {} + Some(SortMode::RowSort) => { + rows.sort_unstable(); + } + Some(SortMode::ValueSort) => todo!("value sort"), + }; + + if self.hash_threshold > 0 && rows.len() > self.hash_threshold { + let mut md5 = md5::Context::new(); + for line in &rows { + for value in line { + md5.consume(value.as_bytes()); + md5.consume(b"\n"); } } - (Err(e), Some(expected_error)) => { - if !expected_error.is_match(&e.to_string()) { - return Err(TestErrorKind::ErrorMismatch { + let hash = md5.compute(); + rows = vec![vec![format!( + "{} values hashing to {:?}", + rows.len() * rows[0].len(), + hash + )]]; + } + + if let Some(hook) = &mut self.hook { + hook.on_query_complete(&sql).await; + } + + RecordOutput::Query { + error: None, + types, + rows, + } + } + Record::Sleep { duration, .. } => { + D::sleep(duration).await; + RecordOutput::Nothing + } + Record::Control(control) => match control { + Control::SortMode(sort_mode) => { + self.sort_mode = Some(sort_mode); + RecordOutput::Nothing + } + }, + Record::HashThreshold { loc: _, threshold } => { + self.hash_threshold = threshold as usize; + RecordOutput::Nothing + } + Record::Include { .. } + | Record::Comment(_) + | Record::Subtest { .. } + | Record::Halt { .. } + | Record::Injected(_) + | Record::Condition(_) => RecordOutput::Nothing, + } + } + + /// Run a single record. + pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { + tracing::info!(?record, "testing"); + + match (record.clone(), self.apply_record(record).await) { + (_, RecordOutput::Nothing) => {} + ( + Record::Statement { + loc, + conditions: _, + expected_error, + sql, + expected_count, + }, + RecordOutput::Statement { count, error }, + ) => match (error, expected_error) { + (None, Some(_)) => { + return Err(TestErrorKind::Ok { + sql, + kind: RecordKind::Statement, + } + .at(loc)) + } + (None, None) => { + if let Some(expected_count) = expected_count { + if expected_count != count { + return Err(TestErrorKind::StatementResultMismatch { sql, - err: Arc::new(e), - expected_err: expected_error.to_string(), - kind: RecordKind::Statement, + expected: expected_count, + actual: format!("affected {count} rows"), } .at(loc)); } } - (Err(e), None) => { - return Err(TestErrorKind::Fail { + } + (Some(e), Some(expected_error)) => { + if !expected_error.is_match(&e.to_string()) { + return Err(TestErrorKind::ErrorMismatch { sql, err: Arc::new(e), + expected_err: expected_error.to_string(), kind: RecordKind::Statement, } .at(loc)); } } - if let Some(hook) = &mut self.hook { - hook.on_stmt_complete(&sql).await; + (Some(e), None) => { + return Err(TestErrorKind::Fail { + sql, + err: Arc::new(e), + kind: RecordKind::Statement, + } + .at(loc)); } - } - Record::Query { conditions, .. } if self.should_skip(&conditions) => {} - Record::Query { - conditions: _, - - loc, - sql, - expected_error, - expected_results, - sort_mode, - type_string, - - // not handle yet, - label: _, - } => { - let sql = self.replace_keywords(sql); - let output = match (self.db.run(&sql).await, expected_error) { - (Ok(_), Some(_)) => { + }, + ( + Record::Query { + loc, + conditions: _, + type_string, + sort_mode: _, + label: _, + expected_error, + sql, + expected_results, + }, + RecordOutput::Query { types, rows, error }, + ) => { + match (error, expected_error) { + (None, Some(_)) => { return Err(TestErrorKind::Ok { sql, kind: RecordKind::Query, } .at(loc)) } - (Ok(output), None) => output, - (Err(e), Some(expected_error)) => { + (None, None) => {} + (Some(e), Some(expected_error)) => { if !expected_error.is_match(&e.to_string()) { return Err(TestErrorKind::ErrorMismatch { sql, @@ -533,7 +650,7 @@ impl Runner { } return Ok(()); } - (Err(e), None) => { + (Some(e), None) => { return Err(TestErrorKind::Fail { sql, err: Arc::new(e), @@ -543,18 +660,6 @@ impl Runner { } }; - let (types, mut output) = match output { - DBOutput::Rows { types, rows } => (types, rows), - DBOutput::StatementComplete(_) => { - return Err(TestErrorKind::QueryResultMismatch { - sql, - expected: expected_results.join("\n"), - actual: "statement complete".to_string(), - } - .at(loc)) - } - }; - // check number of columns if types.len() != type_string.len() { // FIXME: do not validate type-string now @@ -574,63 +679,25 @@ impl Runner { } } - match sort_mode.as_ref().or(self.sort_mode.as_ref()) { - None | Some(SortMode::NoSort) => {} - Some(SortMode::RowSort) => { - output.sort_unstable(); - } - Some(SortMode::ValueSort) => todo!("value sort"), - }; - - if self.hash_threshold > 0 && output.len() > self.hash_threshold { - let mut md5 = md5::Context::new(); - for line in &output { - for value in line { - md5.consume(value.as_bytes()); - md5.consume(b"\n"); - } - } - let hash = md5.compute(); - output = vec![vec![format!( - "{} values hashing to {:?}", - output.len() * output[0].len(), - hash - )]]; - } - // We compare normalized results. Whitespace characters are ignored. - let output = output + let normalized_rows = rows .into_iter() .map(|strs| strs.iter().map(normalize_string).join(" ")) .collect_vec(); - let expected_results = expected_results.iter().map(normalize_string).collect_vec(); - if !(self.validator)(&output, &expected_results) { + let expected_results = expected_results.iter().map(normalize_string).collect_vec(); + if !(self.validator)(&normalized_rows, &expected_results) { return Err(TestErrorKind::QueryResultMismatch { sql, expected: expected_results.join("\n"), - actual: output.join("\n"), + actual: normalized_rows.join("\n"), } .at(loc)); } - if let Some(hook) = &mut self.hook { - hook.on_query_complete(&sql).await; - } } - Record::Sleep { duration, .. } => D::sleep(duration).await, - Record::Include { .. } => {} - Record::Control(control) => match control { - Control::SortMode(sort_mode) => { - self.sort_mode = Some(sort_mode); - } - }, - Record::HashThreshold { loc: _, threshold } => self.hash_threshold = threshold as usize, - Record::Comment(_) - | Record::Subtest { .. } - | Record::Halt { .. } - | Record::Injected(_) - | Record::Condition(_) => {} + _ => unreachable!(), } + Ok(()) } From a2d79c23d2460e91a22b721bfb056b42d6d90408 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 6 Dec 2022 22:58:52 +0100 Subject: [PATCH 3/7] use tab to format query results Signed-off-by: xxchan --- sqllogictest-bin/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index 9cd89c7..91e9a8e 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -713,7 +713,7 @@ async fn update_record( } let normalized_rows = rows - .into_iter() + .iter() .map(|strs| strs.iter().map(normalize_string).join(" ")) .collect_vec(); @@ -727,8 +727,8 @@ async fn update_record( writeln!(outfile, "{}", result)?; } } else { - for result in normalized_rows { - writeln!(outfile, "{}", result)?; + for result in rows { + writeln!(outfile, "{}", result.iter().format("\t"))?; } }; } From a813bf3ea37a624ebc1764b502fc1e5ed4117229 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 6 Dec 2022 23:05:15 +0100 Subject: [PATCH 4/7] do not update type-string now Signed-off-by: xxchan --- sqllogictest-bin/src/lib.rs | 1 + sqllogictest/src/runner.rs | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index 91e9a8e..c538fe3 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -693,6 +693,7 @@ async fn update_record( } }; + // FIXME: use output's types instead of orignal query's types write!( outfile, "query {}", diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 44c3fc7..0e8cd6f 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -24,6 +24,7 @@ pub enum ColumnType { FloatingPoint, /// Do not check the type of the column. Any, + Unknown(char), } impl Display for ColumnType { @@ -33,6 +34,7 @@ impl Display for ColumnType { ColumnType::Integer => write!(f, "I"), ColumnType::FloatingPoint => write!(f, "R"), ColumnType::Any => write!(f, "?"), + ColumnType::Unknown(c) => write!(f, "{}", c), } } } @@ -48,7 +50,7 @@ impl TryFrom for ColumnType { '?' => Ok(Self::Any), // FIXME: // _ => Err(ParseErrorKind::InvalidType(c)), - _ => Ok(Self::Any), + _ => Ok(Self::Unknown(c)), } } } From c0b75c492938022c2a64c461e828b80fcf8ea2b2 Mon Sep 17 00:00:00 2001 From: xxchan Date: Tue, 6 Dec 2022 23:34:16 +0100 Subject: [PATCH 5/7] tolerate result type mismatch Signed-off-by: xxchan --- sqllogictest-bin/src/lib.rs | 23 +++++++++++++++++++++++ sqllogictest/src/runner.rs | 30 ++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index c538fe3..9dfb1fc 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -636,6 +636,29 @@ async fn update_record( record.unparse(outfile)?; writeln!(outfile)?; } + ( + Record::Statement { sql, .. }, + RecordOutput::Query { + types, + rows, + error: None, + }, + ) => { + writeln!( + outfile, + "query {}", + types.iter().map(|c| format!("{c}")).join("") + )?; + writeln!(outfile, "{}", sql)?; + writeln!(outfile, "----")?; + for result in rows { + writeln!(outfile, "{}", result.iter().format("\t"))?; + } + } + (Record::Query { sql, .. }, RecordOutput::Statement { error: None, .. }) => { + writeln!(outfile, "statement ok")?; + writeln!(outfile, "{}", sql)?; + } ( Record::Statement { loc: _, diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 0e8cd6f..340ed9f 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -459,7 +459,11 @@ impl Runner { let ret = self.db.run(&sql).await; let ret = match ret { Ok(out) => match out { - DBOutput::Rows { .. } => panic!("DB should not return rows for statement"), + DBOutput::Rows { types, rows } => RecordOutput::Query { + types, + rows, + error: None, + }, DBOutput::StatementComplete(count) => { RecordOutput::Statement { count, error: None } } @@ -495,7 +499,9 @@ impl Runner { let (types, mut rows) = match self.db.run(&sql).await { Ok(out) => match out { DBOutput::Rows { types, rows } => (types, rows), - DBOutput::StatementComplete(_) => panic!("DB should return rows for query"), + DBOutput::StatementComplete(count) => { + return RecordOutput::Statement { count, error: None } + } }, Err(e) => { return RecordOutput::Query { @@ -569,6 +575,26 @@ impl Runner { match (record.clone(), self.apply_record(record).await) { (_, RecordOutput::Nothing) => {} + // Tolerate the mismatched return type... + (Record::Statement { .. }, RecordOutput::Query { error: None, .. }) => {} + ( + Record::Query { + expected_results, + loc, + sql, + .. + }, + RecordOutput::Statement { error: None, .. }, + ) => { + if !expected_results.is_empty() { + return Err(TestErrorKind::QueryResultMismatch { + sql, + expected: expected_results.join("\n"), + actual: "".to_string(), + } + .at(loc)); + } + } ( Record::Statement { loc, From 0302c8bfafac1cf1c06e6be74cc9795f9f3d8f10 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 7 Dec 2022 13:33:53 +0100 Subject: [PATCH 6/7] pgengine: retrieve column info for empty query result Signed-off-by: xxchan --- sqllogictest-bin/src/engines/postgres.rs | 3 ++- sqllogictest-bin/src/engines/postgres_extended.rs | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sqllogictest-bin/src/engines/postgres.rs b/sqllogictest-bin/src/engines/postgres.rs index 513b0c9..99ecf79 100644 --- a/sqllogictest-bin/src/engines/postgres.rs +++ b/sqllogictest-bin/src/engines/postgres.rs @@ -98,8 +98,9 @@ impl sqllogictest::AsyncDB for Postgres { } if output.is_empty() { + let stmt = self.client.prepare(sql).await?; Ok(DBOutput::Rows { - types: vec![], + types: vec![ColumnType::Any; stmt.columns().len()], rows: vec![], }) } else { diff --git a/sqllogictest-bin/src/engines/postgres_extended.rs b/sqllogictest-bin/src/engines/postgres_extended.rs index 36ac386..45645af 100644 --- a/sqllogictest-bin/src/engines/postgres_extended.rs +++ b/sqllogictest-bin/src/engines/postgres_extended.rs @@ -338,8 +338,9 @@ impl sqllogictest::AsyncDB for PostgresExtended { } if output.is_empty() { + let stmt = self.client.prepare(sql).await?; Ok(DBOutput::Rows { - types: vec![], + types: vec![ColumnType::Any; stmt.columns().len()], rows: vec![], }) } else { From d78b302f721fab77377d1a6de44960704426f46b Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 7 Dec 2022 15:10:45 +0100 Subject: [PATCH 7/7] remove hook Signed-off-by: xxchan --- sqllogictest/src/runner.rs | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 340ed9f..a0c2461 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -394,16 +394,6 @@ fn format_diff( /// By default, we will use `|x, y| x == y`. pub type Validator = fn(&Vec, &Vec) -> bool; -/// A collection of hook functions. -#[async_trait] -pub trait Hook: Send { - /// Called after each statement completes. - async fn on_stmt_complete(&mut self, _sql: &str) {} - - /// Called after each query completes. - async fn on_query_complete(&mut self, _sql: &str) {} -} - /// Sqllogictest runner. pub struct Runner { db: D, @@ -411,7 +401,6 @@ pub struct Runner { validator: Validator, testdir: Option, sort_mode: Option, - hook: Option>, /// 0 means never hashing hash_threshold: usize, } @@ -424,7 +413,6 @@ impl Runner { validator: |x, y| x == y, testdir: None, sort_mode: None, - hook: None, hash_threshold: 0, } } @@ -457,7 +445,7 @@ impl Runner { } => { let sql = self.replace_keywords(sql); let ret = self.db.run(&sql).await; - let ret = match ret { + match ret { Ok(out) => match out { DBOutput::Rows { types, rows } => RecordOutput::Query { types, @@ -472,11 +460,7 @@ impl Runner { count: 0, error: Some(Arc::new(e)), }, - }; - if let Some(hook) = &mut self.hook { - hook.on_stmt_complete(&sql).await; } - ret } Record::Query { conditions, .. } if self.should_skip(&conditions) => { RecordOutput::Nothing @@ -536,10 +520,6 @@ impl Runner { )]]; } - if let Some(hook) = &mut self.hook { - hook.on_query_complete(&sql).await; - } - RecordOutput::Query { error: None, types, @@ -862,11 +842,6 @@ impl Runner { .iter() .any(|c| c.should_skip(self.db.engine_name())) } - - /// Set hook functions. - pub fn set_hook(&mut self, hook: impl Hook + 'static) { - self.hook = Some(Box::new(hook)); - } } /// Trim and replace multiple whitespaces with one.