diff --git a/sqllogictest-bin/src/engines/postgres.rs b/sqllogictest-bin/src/engines/postgres.rs index 513b0c9..99ecf79 100644 --- a/sqllogictest-bin/src/engines/postgres.rs +++ b/sqllogictest-bin/src/engines/postgres.rs @@ -98,8 +98,9 @@ impl sqllogictest::AsyncDB for Postgres { } if output.is_empty() { + let stmt = self.client.prepare(sql).await?; Ok(DBOutput::Rows { - types: vec![], + types: vec![ColumnType::Any; stmt.columns().len()], rows: vec![], }) } else { diff --git a/sqllogictest-bin/src/engines/postgres_extended.rs b/sqllogictest-bin/src/engines/postgres_extended.rs index 36ac386..45645af 100644 --- a/sqllogictest-bin/src/engines/postgres_extended.rs +++ b/sqllogictest-bin/src/engines/postgres_extended.rs @@ -338,8 +338,9 @@ impl sqllogictest::AsyncDB for PostgresExtended { } if output.is_empty() { + let stmt = self.client.prepare(sql).await?; Ok(DBOutput::Rows { - types: vec![], + types: vec![ColumnType::Any; stmt.columns().len()], rows: vec![], }) } else { diff --git a/sqllogictest-bin/src/lib.rs b/sqllogictest-bin/src/lib.rs index fdbc99b..9dfb1fc 100644 --- a/sqllogictest-bin/src/lib.rs +++ b/sqllogictest-bin/src/lib.rs @@ -1,6 +1,7 @@ mod engines; use std::collections::BTreeMap; +use std::fs::File; use std::io::{stdout, Write}; use std::path::{Path, PathBuf}; use std::time::{Duration, Instant}; @@ -14,7 +15,7 @@ use futures::StreamExt; use itertools::Itertools; use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite}; use rand::seq::SliceRandom; -use sqllogictest::{AsyncDB, Control, Record, Runner}; +use sqllogictest::{AsyncDB, Injected, Record, RecordOutput, Runner}; #[derive(Copy, Clone, Debug, PartialEq, Eq, ArgEnum)] #[must_use] @@ -84,6 +85,13 @@ struct Opt { /// The database password. #[clap(short = 'w', long, default_value = "postgres")] pass: String, + + /// Overrides the test files with the actual output of the database. + #[clap(long)] + r#override: bool, + /// Reformats the test files. + #[clap(long)] + format: bool, } /// Connection configuration. @@ -123,6 +131,8 @@ pub async fn main_okk() -> Result<()> { db, user, pass, + r#override, + format, } = Opt::parse(); if host.len() != port.len() { @@ -171,6 +181,10 @@ pub async fn main_okk() -> Result<()> { pass, }; + if r#override || format { + return update_test_files(files, &engine, config, format).await; + } + let mut report = Report::new(junit.clone().unwrap_or_else(|| "sqllogictest".to_string())); report.set_timestamp(Local::now()); @@ -348,6 +362,28 @@ async fn run_serial( } } +/// * `format` - If true, will not run sqls, only formats the file. +async fn update_test_files( + files: Vec<PathBuf>, + engine: &EngineConfig, + config: DBConfig, + format: bool, +) -> Result<()> { + for file in files { + let engine = engines::connect(engine, &config).await?; + let runner = Runner::new(engine); + + if let Err(e) = update_test_file(&mut std::io::stdout(), runner, &file, format).await { + { + println!("{}\n\n{:?}", style("[FAILED]").red().bold(), e); + println!(); + } + }; + } + + Ok(()) +} + async fn flush(out: &mut impl std::io::Write) -> std::io::Result<()> { tokio::task::block_in_place(|| out.flush()) } @@ -384,38 +420,9 @@ async fn run_test_file<T: std::io::Write, D: AsyncDB>( begin_times.push(Instant::now()); - let finish = |out: &mut T, time_stack: &mut Vec<Instant>, did_pop: &mut bool, file: &str| { - let begin_time = time_stack.pop().unwrap(); - - if *did_pop { - // start a new line if the result is not immediately after the item - write!( - out, - "\n{}{} {: <54} .. {} in {} ms", - "| ".repeat(time_stack.len()), - style("[END]").blue().bold(), - file, - style("[OK]").green().bold(), - begin_time.elapsed().as_millis() - )?; - } else { - // otherwise, append time to the previous line - write!( - out, - "{} in {} ms", - style("[OK]").green().bold(), - begin_time.elapsed().as_millis() - )?; - } - - *did_pop = true; - - Ok::<_, anyhow::Error>(()) - }; - for record in records { match &record { - Record::Control(Control::BeginInclude(file)) => { + Record::Injected(Injected::BeginInclude(file)) => { begin_times.push(Instant::now()); if !did_pop { writeln!(out, "{}", style("[BEGIN]").blue().bold())?; @@ -431,11 +438,12 @@ async fn run_test_file<T: std::io::Write, D: AsyncDB>( )?; flush(out).await?; } - Record::Control(Control::EndInclude(file)) => { - finish(out, &mut begin_times, &mut did_pop, file)?; + Record::Injected(Injected::EndInclude(file)) => { + finish_test_file(out, &mut begin_times, &mut did_pop, file)?; } _ => {} } + runner .run_async(record) .await @@ -448,7 +456,7 @@ async fn run_test_file<T: std::io::Write, D: AsyncDB>( let duration = begin_times[0].elapsed(); - finish( + finish_test_file( out, &mut begin_times, &mut did_pop, @@ -459,3 +467,297 @@ async fn run_test_file<T: std::io::Write, D: AsyncDB>( Ok(duration) } + +fn finish_test_file<T: std::io::Write>( + out: &mut T, + time_stack: &mut Vec<Instant>, + did_pop: &mut bool, + file: &str, +) -> Result<()> { + let begin_time = time_stack.pop().unwrap(); + + if *did_pop { + // start a new line if the result is not immediately after the item + write!( + out, + "\n{}{} {: <54} .. {} in {} ms", + "| ".repeat(time_stack.len()), + style("[END]").blue().bold(), + file, + style("[OK]").green().bold(), + begin_time.elapsed().as_millis() + )?; + } else { + // otherwise, append time to the previous line + write!( + out, + "{} in {} ms", + style("[OK]").green().bold(), + begin_time.elapsed().as_millis() + )?; + } + + *did_pop = true; + + Ok::<_, anyhow::Error>(()) +} + +async fn update_test_file<T: std::io::Write, D: AsyncDB>( + out: &mut T, + mut runner: Runner<D>, + filename: impl AsRef<Path>, + format: bool, +) -> Result<()> { + let filename = filename.as_ref(); + let records = tokio::task::block_in_place(|| { + sqllogictest::parse_file(filename).map_err(|e| anyhow!("{:?}", e)) + }) + .context("failed to parse sqllogictest file")?; + + let mut begin_times = vec![]; + let mut did_pop = false; + + write!(out, "{: <60} .. ", filename.to_string_lossy())?; + flush(out).await?; + + begin_times.push(Instant::now()); + + fn create_outfile(filename: impl AsRef<Path>) -> std::io::Result<(PathBuf, File)> { + let filename = filename.as_ref(); + let outfilename = filename.file_name().unwrap().to_str().unwrap().to_owned() + ".temp"; + let outfilename = filename.parent().unwrap().join(&outfilename); + let outfile = File::create(&outfilename)?; + Ok((outfilename, outfile)) + } + + struct Item { + filename: String, + outfilename: PathBuf, + outfile: File, + first_record: bool, + } + let (outfilename, outfile) = create_outfile(filename)?; + let mut stack = vec![Item { + filename: filename.to_string_lossy().to_string(), + outfilename, + outfile, + first_record: true, + }]; + + for record in records { + let Item { + filename, + outfilename, + outfile, + first_record, + } = stack.last_mut().unwrap(); + + match &record { + Record::Injected(Injected::BeginInclude(filename)) => { + let (outfilename, outfile) = create_outfile(filename)?; + stack.push(Item { + filename: filename.clone(), + outfilename, + outfile, + first_record: true, + }); + + begin_times.push(Instant::now()); + if !did_pop { + writeln!(out, "{}", style("[BEGIN]").blue().bold())?; + } else { + writeln!(out)?; + } + did_pop = false; + write!( + out, + "{}{: <60} .. ", + "| ".repeat(begin_times.len() - 1), + filename + )?; + flush(out).await?; + } + Record::Injected(Injected::EndInclude(file)) => { + // override the original file with the updated one + std::fs::rename(outfilename, filename)?; + stack.pop(); + finish_test_file(out, &mut begin_times, &mut did_pop, file)?; + } + _ => { + if !*first_record { + writeln!(outfile)?; + } + + update_record(outfile, &mut runner, record, format) + .await + .context(format!("failed to run `{}`", style(filename).bold()))?; + *first_record = false; + } + } + } + + finish_test_file( + out, + &mut begin_times, + &mut did_pop, + &filename.to_string_lossy(), + )?; + + writeln!(out)?; + + // override the original file with the updated one + let Item { + filename, + outfilename, + outfile: _, + first_record: _, + } = stack.last().unwrap(); + std::fs::rename(outfilename, filename)?; + + Ok(()) +} + +async fn update_record<D: AsyncDB>( + outfile: &mut File, + runner: &mut Runner<D>, + record: Record, + format: bool, +) -> Result<()> { + assert!(!matches!(record, Record::Injected(_))); + + if format { + record.unparse(outfile)?; + writeln!(outfile)?; + return Ok(()); + } + + match (record.clone(), runner.apply_record(record).await) { + (record, RecordOutput::Nothing) => { + record.unparse(outfile)?; + writeln!(outfile)?; + } + ( + Record::Statement { sql, .. }, + RecordOutput::Query { + types, + rows, + error: None, + }, + ) => { + writeln!( + outfile, + "query {}", + types.iter().map(|c| format!("{c}")).join("") + )?; + writeln!(outfile, "{}", sql)?; + writeln!(outfile, "----")?; + for result in rows { + writeln!(outfile, "{}", result.iter().format("\t"))?; + } + } + (Record::Query { sql, .. }, RecordOutput::Statement { error: None, .. }) => { + writeln!(outfile, "statement ok")?; + writeln!(outfile, "{}", sql)?; + } + ( + Record::Statement { + loc: _, + conditions: _, + expected_error, + sql, + expected_count, + }, + RecordOutput::Statement { count, error }, + ) => match (error, expected_error) { + (None, _) => { + if expected_count.is_some() { + writeln!(outfile, "statement count {count}")?; + writeln!(outfile, "{}", sql)?; + } else { + writeln!(outfile, "statement ok")?; + writeln!(outfile, "{}", sql)?; + } + } + (Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => { + writeln!(outfile, "statement error {}", expected_error)?; + writeln!(outfile, "{}", sql)?; + } + (Some(e), _) => { + writeln!(outfile, "statement error {}", e)?; + writeln!(outfile, "{}", sql)?; + } + }, + ( + Record::Query { + loc: _, + conditions: _, + type_string, + sort_mode, + label, + expected_error, + sql, + expected_results, + }, + RecordOutput::Query { + types: _, + rows, + error, + }, + ) => { + match (error, expected_error) { + (None, _) => {} + (Some(e), Some(expected_error)) if expected_error.is_match(&e.to_string()) => { + writeln!(outfile, "query error {}", expected_error)?; + writeln!(outfile, "{}", sql)?; + } + (Some(e), _) => { + writeln!(outfile, "query error {}", e)?; + writeln!(outfile, "{}", sql)?; + } + }; + + // FIXME: use output's types instead of orignal query's types + write!( + outfile, + "query {}", + type_string.iter().map(|c| format!("{c}")).join("") + )?; + if let Some(sort_mode) = sort_mode { + write!(outfile, " {}", sort_mode.as_str())?; + } + if let Some(label) = label { + write!(outfile, " {}", label)?; + } + writeln!(outfile)?; + writeln!(outfile, "{}", sql)?; + + #[allow(clippy::ptr_arg)] + fn normalize_string(s: &String) -> String { + s.trim().split_ascii_whitespace().join(" ") + } + + let normalized_rows = rows + .iter() + .map(|strs| strs.iter().map(normalize_string).join(" ")) + .collect_vec(); + + let normalized_expected = expected_results.iter().map(normalize_string).collect_vec(); + + writeln!(outfile, "----")?; + + if normalized_expected == normalized_rows { + // If the results are correct, do not format them. + for result in expected_results { + writeln!(outfile, "{}", result)?; + } + } else { + for result in rows { + writeln!(outfile, "{}", result.iter().format("\t"))?; + } + }; + } + _ => unreachable!(), + } + + Ok(()) +} diff --git a/sqllogictest/src/parser.rs b/sqllogictest/src/parser.rs index 980f33b..6a25f26 100644 --- a/sqllogictest/src/parser.rs +++ b/sqllogictest/src/parser.rs @@ -70,7 +70,10 @@ impl Location { #[non_exhaustive] pub enum Record { /// An include copies all records from another files. - Include { loc: Location, filename: String }, + Include { + loc: Location, + filename: String, + }, /// A statement is an SQL command that is to be evaluated but from which we do not expect to /// get results (other than success or failure). Statement { @@ -101,12 +104,20 @@ pub enum Record { expected_results: Vec<String>, }, /// A sleep period. - Sleep { loc: Location, duration: Duration }, + Sleep { + loc: Location, + duration: Duration, + }, /// Subtest. - Subtest { loc: Location, name: String }, + Subtest { + loc: Location, + name: String, + }, /// A halt record merely causes sqllogictest to ignore the rest of the test script. /// For debugging use only. - Halt { loc: Location }, + Halt { + loc: Location, + }, /// Control statements. Control(Control), /// Set the maximum number of result values that will be accepted @@ -115,13 +126,129 @@ pub enum Record { /// is the only result. /// /// If the threshold is 0, then hashing is never used. - HashThreshold { loc: Location, threshold: u64 }, + HashThreshold { + loc: Location, + threshold: u64, + }, + Condition(Condition), + Comment(Vec<String>), + /// Internally injected record which should not occur in the test file. + Injected(Injected), +} + +impl Record { + /// Unparses the record to its string representation in the test file. + /// + /// # Panics + /// If the record is an internally injected record which should not occur in the test file. + pub fn unparse(&self, w: &mut impl std::io::Write) -> std::io::Result<()> { + match self { + Record::Include { loc: _, filename } => { + write!(w, "include {}", filename) + } + Record::Statement { + loc: _, + conditions: _, + expected_error, + sql, + expected_count, + } => { + write!(w, "statement ")?; + match (expected_count, expected_error) { + (None, None) => write!(w, "ok")?, + (None, Some(err)) => { + if err.as_str().is_empty() { + write!(w, "error")?; + } else { + write!(w, "error {}", err)?; + } + } + (Some(cnt), None) => write!(w, "count {}", cnt)?, + (Some(_), Some(_)) => unreachable!(), + } + writeln!(w)?; + write!(w, "{}", sql) + } + Record::Query { + loc: _, + conditions: _, + type_string, + sort_mode, + label, + expected_error, + sql, + expected_results, + } => { + write!(w, "query")?; + if let Some(err) = expected_error { + writeln!(w, " error {}", err)?; + return write!(w, "{}", sql); + } + + write!( + w, + " {}", + type_string.iter().map(|c| format!("{c}")).join("") + )?; + if let Some(sort_mode) = sort_mode { + write!(w, " {}", sort_mode.as_str())?; + } + if let Some(label) = label { + write!(w, " {}", label)?; + } + writeln!(w)?; + writeln!(w, "{}", sql)?; + + write!(w, "----")?; + for result in expected_results { + write!(w, "\n{}", result)?; + } + Ok(()) + } + Record::Sleep { loc: _, duration } => { + write!(w, "sleep {}", humantime::format_duration(*duration)) + } + Record::Subtest { loc: _, name } => { + write!(w, "subtest {}", name) + } + Record::Halt { loc: _ } => { + write!(w, "halt") + } + Record::Control(c) => match c { + Control::SortMode(m) => write!(w, "control sortmode {}", m.as_str()), + }, + Record::Condition(cond) => match cond { + Condition::OnlyIf { engine_name } => { + write!(w, "onlyif {}", engine_name) + } + Condition::SkipIf { engine_name } => { + write!(w, "skipif {}", engine_name) + } + }, + Record::HashThreshold { loc: _, threshold } => { + write!(w, "hash-threshold {}", threshold) + } + Record::Comment(comment) => { + let mut iter = comment.iter(); + write!(w, "#{}", iter.next().unwrap().trim_end())?; + for line in iter { + write!(w, "\n#{}", line.trim_end())?; + } + Ok(()) + } + Record::Injected(p) => panic!("unexpected injected record: {:?}", p), + } + } } #[derive(Debug, PartialEq, Eq, Clone)] pub enum Control { /// Control sort mode. SortMode(SortMode), +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Injected { /// Pseudo control command to indicate the begin of an include statement. Automatically /// injected by sqllogictest parser. BeginInclude(String), @@ -246,12 +373,30 @@ fn parse_inner(loc: &Location, script: &str) -> Result<Vec<Record>, ParseError> let mut lines = script.split('\n').enumerate(); let mut records = vec![]; let mut conditions = vec![]; - while let Some((num, line)) = lines.next() { - if line.is_empty() || line.starts_with('#') { + + while let Some((mut num, mut line)) = lines.next() { + if let Some(text) = line.strip_prefix('#') { + let mut comments = vec![text.to_string()]; + for (num_, line_) in lines.by_ref() { + num = num_; + line = line_; + if let Some(text) = line.strip_prefix('#') { + comments.push(text.to_string()); + } else { + break; + } + } + + records.push(Record::Comment(comments)); + } + + if line.is_empty() { continue; } + let mut loc = loc.clone(); loc.line = num as u32 + 1; + let tokens: Vec<&str> = line.split_whitespace().collect(); match tokens.as_slice() { [] => continue, @@ -261,7 +406,6 @@ fn parse_inner(loc: &Location, script: &str) -> Result<Vec<Record>, ParseError> }), ["halt"] => { records.push(Record::Halt { loc }); - break; } ["subtest", name] => { records.push(Record::Subtest { @@ -278,14 +422,18 @@ fn parse_inner(loc: &Location, script: &str) -> Result<Vec<Record>, ParseError> }); } ["skipif", engine_name] => { - conditions.push(Condition::SkipIf { + let cond = Condition::SkipIf { engine_name: engine_name.to_string(), - }); + }; + conditions.push(cond.clone()); + records.push(Record::Condition(cond)); } ["onlyif", engine_name] => { - conditions.push(Condition::OnlyIf { + let cond = Condition::OnlyIf { engine_name: engine_name.to_string(), - }); + }; + conditions.push(cond.clone()); + records.push(Record::Condition(cond)); } ["statement", res @ ..] => { let mut expected_count = None; @@ -412,7 +560,7 @@ fn parse_inner(loc: &Location, script: &str) -> Result<Vec<Record>, ParseError> Ok(records) } -/// Parse a sqllogictest file and link all included scripts together. +/// Parse a sqllogictest file. The included scripts are inserted after the `include` record. pub fn parse_file(filename: impl AsRef<Path>) -> Result<Vec<Record>, ParseError> { let filename = filename.as_ref().to_str().unwrap(); parse_file_inner(Location::new(filename, 0)) @@ -426,6 +574,8 @@ fn parse_file_inner(loc: Location) -> Result<Vec<Record>, ParseError> { let script = std::fs::read_to_string(path).unwrap(); let mut records = vec![]; for rec in parse_inner(&loc, &script)? { + records.push(rec.clone()); + if let Record::Include { filename, loc } = rec { let complete_filename = { let mut path_buf = path.to_path_buf(); @@ -440,14 +590,12 @@ fn parse_file_inner(loc: Location) -> Result<Vec<Record>, ParseError> { { let included_file = included_file.as_os_str().to_string_lossy().to_string(); - records.push(Record::Control(Control::BeginInclude( + records.push(Record::Injected(Injected::BeginInclude( included_file.clone(), ))); records.extend(parse_file_inner(loc.include(&included_file))?); - records.push(Record::Control(Control::EndInclude(included_file))); + records.push(Record::Injected(Injected::EndInclude(included_file))); } - } else { - records.push(rec); } } Ok(records) @@ -460,6 +608,6 @@ mod tests { #[test] fn test_include_glob() { let records = parse_file("../examples/include/include_1.slt").unwrap(); - assert_eq!(12, records.len()); + assert_eq!(14, records.len()); } } diff --git a/sqllogictest/src/runner.rs b/sqllogictest/src/runner.rs index 2934de7..a0c2461 100644 --- a/sqllogictest/src/runner.rs +++ b/sqllogictest/src/runner.rs @@ -24,6 +24,19 @@ pub enum ColumnType { FloatingPoint, /// Do not check the type of the column. Any, + Unknown(char), +} + +impl Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ColumnType::Text => write!(f, "T"), + ColumnType::Integer => write!(f, "I"), + ColumnType::FloatingPoint => write!(f, "R"), + ColumnType::Any => write!(f, "?"), + ColumnType::Unknown(c) => write!(f, "{}", c), + } + } } impl TryFrom<char> for ColumnType { @@ -33,14 +46,28 @@ impl TryFrom<char> for ColumnType { match c { 'T' => Ok(Self::Text), 'I' => Ok(Self::Integer), - 'F' => Ok(Self::FloatingPoint), + 'R' => Ok(Self::FloatingPoint), + '?' => Ok(Self::Any), // FIXME: // _ => Err(ParseErrorKind::InvalidType(c)), - _ => Ok(Self::Any), + _ => Ok(Self::Unknown(c)), } } } +pub enum RecordOutput { + Nothing, + Query { + types: Vec<ColumnType>, + rows: Vec<Vec<String>>, + error: Option<Arc<dyn std::error::Error + Send + Sync>>, + }, + Statement { + count: u64, + error: Option<Arc<dyn std::error::Error + Send + Sync>>, + }, +} + #[non_exhaustive] pub enum DBOutput { Rows { @@ -367,16 +394,6 @@ fn format_diff( /// By default, we will use `|x, y| x == y`. pub type Validator = fn(&Vec<String>, &Vec<String>) -> bool; -/// A collection of hook functions. -#[async_trait] -pub trait Hook: Send { - /// Called after each statement completes. - async fn on_stmt_complete(&mut self, _sql: &str) {} - - /// Called after each query completes. - async fn on_query_complete(&mut self, _sql: &str) {} -} - /// Sqllogictest runner. pub struct Runner<D: AsyncDB> { db: D, @@ -384,7 +401,6 @@ pub struct Runner<D: AsyncDB> { validator: Validator, testdir: Option<TempDir>, sort_mode: Option<SortMode>, - hook: Option<Box<dyn Hook>>, /// 0 means never hashing hash_threshold: usize, } @@ -397,7 +413,6 @@ impl<D: AsyncDB> Runner<D> { validator: |x, y| x == y, testdir: None, sort_mode: None, - hook: None, hash_threshold: 0, } } @@ -414,102 +429,224 @@ impl<D: AsyncDB> Runner<D> { self.validator = validator; } - /// Run a single record. - pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { - tracing::info!(?record, "testing"); + pub async fn apply_record(&mut self, record: Record) -> RecordOutput { match record { - Record::Statement { conditions, .. } if self.should_skip(&conditions) => {} + Record::Statement { conditions, .. } if self.should_skip(&conditions) => { + RecordOutput::Nothing + } Record::Statement { conditions: _, - - expected_error, sql, - loc, - expected_count, + + // compare result in run_async + expected_error: _, + expected_count: _, + loc: _, } => { let sql = self.replace_keywords(sql); let ret = self.db.run(&sql).await; - match (ret, expected_error) { - (Ok(_), Some(_)) => { - return Err(TestErrorKind::Ok { - sql, - kind: RecordKind::Statement, + match ret { + Ok(out) => match out { + DBOutput::Rows { types, rows } => RecordOutput::Query { + types, + rows, + error: None, + }, + DBOutput::StatementComplete(count) => { + RecordOutput::Statement { count, error: None } + } + }, + Err(e) => RecordOutput::Statement { + count: 0, + error: Some(Arc::new(e)), + }, + } + } + Record::Query { conditions, .. } if self.should_skip(&conditions) => { + RecordOutput::Nothing + } + Record::Query { + conditions: _, + sql, + sort_mode, + + // compare result in run_async + type_string: _, + expected_error: _, + expected_results: _, + loc: _, + + // not handle yet, + label: _, + } => { + let sql = self.replace_keywords(sql); + let (types, mut rows) = match self.db.run(&sql).await { + Ok(out) => match out { + DBOutput::Rows { types, rows } => (types, rows), + DBOutput::StatementComplete(count) => { + return RecordOutput::Statement { count, error: None } + } + }, + Err(e) => { + return RecordOutput::Query { + error: Some(Arc::new(e)), + types: vec![], + rows: vec![], } - .at(loc)) } - (Ok(result), None) => { - if let Some(expected_count) = expected_count { - let count = match result { - DBOutput::Rows { types: _, rows } => { - return Err(TestErrorKind::StatementResultMismatch { - sql, - expected: expected_count, - actual: format!("got rows {:?}", rows), - } - .at(loc)); - } - DBOutput::StatementComplete(count) => count, - }; - - if expected_count != count { - return Err(TestErrorKind::StatementResultMismatch { - sql, - expected: expected_count, - actual: format!("affected {count} rows"), - } - .at(loc)); - } + }; + + match sort_mode.as_ref().or(self.sort_mode.as_ref()) { + None | Some(SortMode::NoSort) => {} + Some(SortMode::RowSort) => { + rows.sort_unstable(); + } + Some(SortMode::ValueSort) => todo!("value sort"), + }; + + if self.hash_threshold > 0 && rows.len() > self.hash_threshold { + let mut md5 = md5::Context::new(); + for line in &rows { + for value in line { + md5.consume(value.as_bytes()); + md5.consume(b"\n"); } } - (Err(e), Some(expected_error)) => { - if !expected_error.is_match(&e.to_string()) { - return Err(TestErrorKind::ErrorMismatch { + let hash = md5.compute(); + rows = vec![vec![format!( + "{} values hashing to {:?}", + rows.len() * rows[0].len(), + hash + )]]; + } + + RecordOutput::Query { + error: None, + types, + rows, + } + } + Record::Sleep { duration, .. } => { + D::sleep(duration).await; + RecordOutput::Nothing + } + Record::Control(control) => match control { + Control::SortMode(sort_mode) => { + self.sort_mode = Some(sort_mode); + RecordOutput::Nothing + } + }, + Record::HashThreshold { loc: _, threshold } => { + self.hash_threshold = threshold as usize; + RecordOutput::Nothing + } + Record::Include { .. } + | Record::Comment(_) + | Record::Subtest { .. } + | Record::Halt { .. } + | Record::Injected(_) + | Record::Condition(_) => RecordOutput::Nothing, + } + } + + /// Run a single record. + pub async fn run_async(&mut self, record: Record) -> Result<(), TestError> { + tracing::info!(?record, "testing"); + + match (record.clone(), self.apply_record(record).await) { + (_, RecordOutput::Nothing) => {} + // Tolerate the mismatched return type... + (Record::Statement { .. }, RecordOutput::Query { error: None, .. }) => {} + ( + Record::Query { + expected_results, + loc, + sql, + .. + }, + RecordOutput::Statement { error: None, .. }, + ) => { + if !expected_results.is_empty() { + return Err(TestErrorKind::QueryResultMismatch { + sql, + expected: expected_results.join("\n"), + actual: "".to_string(), + } + .at(loc)); + } + } + ( + Record::Statement { + loc, + conditions: _, + expected_error, + sql, + expected_count, + }, + RecordOutput::Statement { count, error }, + ) => match (error, expected_error) { + (None, Some(_)) => { + return Err(TestErrorKind::Ok { + sql, + kind: RecordKind::Statement, + } + .at(loc)) + } + (None, None) => { + if let Some(expected_count) = expected_count { + if expected_count != count { + return Err(TestErrorKind::StatementResultMismatch { sql, - err: Arc::new(e), - expected_err: expected_error.to_string(), - kind: RecordKind::Statement, + expected: expected_count, + actual: format!("affected {count} rows"), } .at(loc)); } } - (Err(e), None) => { - return Err(TestErrorKind::Fail { + } + (Some(e), Some(expected_error)) => { + if !expected_error.is_match(&e.to_string()) { + return Err(TestErrorKind::ErrorMismatch { sql, err: Arc::new(e), + expected_err: expected_error.to_string(), kind: RecordKind::Statement, } .at(loc)); } } - if let Some(hook) = &mut self.hook { - hook.on_stmt_complete(&sql).await; + (Some(e), None) => { + return Err(TestErrorKind::Fail { + sql, + err: Arc::new(e), + kind: RecordKind::Statement, + } + .at(loc)); } - } - Record::Query { conditions, .. } if self.should_skip(&conditions) => {} - Record::Query { - conditions: _, - - loc, - sql, - expected_error, - expected_results, - sort_mode, - type_string, - - // not handle yet, - label: _, - } => { - let sql = self.replace_keywords(sql); - let output = match (self.db.run(&sql).await, expected_error) { - (Ok(_), Some(_)) => { + }, + ( + Record::Query { + loc, + conditions: _, + type_string, + sort_mode: _, + label: _, + expected_error, + sql, + expected_results, + }, + RecordOutput::Query { types, rows, error }, + ) => { + match (error, expected_error) { + (None, Some(_)) => { return Err(TestErrorKind::Ok { sql, kind: RecordKind::Query, } .at(loc)) } - (Ok(output), None) => output, - (Err(e), Some(expected_error)) => { + (None, None) => {} + (Some(e), Some(expected_error)) => { if !expected_error.is_match(&e.to_string()) { return Err(TestErrorKind::ErrorMismatch { sql, @@ -521,7 +658,7 @@ impl<D: AsyncDB> Runner<D> { } return Ok(()); } - (Err(e), None) => { + (Some(e), None) => { return Err(TestErrorKind::Fail { sql, err: Arc::new(e), @@ -531,18 +668,6 @@ impl<D: AsyncDB> Runner<D> { } }; - let (types, mut output) = match output { - DBOutput::Rows { types, rows } => (types, rows), - DBOutput::StatementComplete(_) => { - return Err(TestErrorKind::QueryResultMismatch { - sql, - expected: expected_results.join("\n"), - actual: "statement complete".to_string(), - } - .at(loc)) - } - }; - // check number of columns if types.len() != type_string.len() { // FIXME: do not validate type-string now @@ -562,63 +687,25 @@ impl<D: AsyncDB> Runner<D> { } } - match sort_mode.as_ref().or(self.sort_mode.as_ref()) { - None | Some(SortMode::NoSort) => {} - Some(SortMode::RowSort) => { - output.sort_unstable(); - } - Some(SortMode::ValueSort) => todo!("value sort"), - }; - - if self.hash_threshold > 0 && output.len() > self.hash_threshold { - let mut md5 = md5::Context::new(); - for line in &output { - for value in line { - md5.consume(value.as_bytes()); - md5.consume(b"\n"); - } - } - let hash = md5.compute(); - output = vec![vec![format!( - "{} values hashing to {:?}", - output.len() * output[0].len(), - hash - )]]; - } - // We compare normalized results. Whitespace characters are ignored. - let output = output + let normalized_rows = rows .into_iter() .map(|strs| strs.iter().map(normalize_string).join(" ")) .collect_vec(); - let expected_results = expected_results.iter().map(normalize_string).collect_vec(); - if !(self.validator)(&output, &expected_results) { + let expected_results = expected_results.iter().map(normalize_string).collect_vec(); + if !(self.validator)(&normalized_rows, &expected_results) { return Err(TestErrorKind::QueryResultMismatch { sql, expected: expected_results.join("\n"), - actual: output.join("\n"), + actual: normalized_rows.join("\n"), } .at(loc)); } - if let Some(hook) = &mut self.hook { - hook.on_query_complete(&sql).await; - } } - Record::Sleep { duration, .. } => D::sleep(duration).await, - Record::Halt { .. } => {} - Record::Subtest { .. } => {} - Record::Include { loc, .. } => { - unreachable!("include should be rewritten during link: at {}", loc) - } - Record::Control(control) => match control { - Control::SortMode(sort_mode) => { - self.sort_mode = Some(sort_mode); - } - Control::BeginInclude(_) | Control::EndInclude(_) => {} - }, - Record::HashThreshold { loc: _, threshold } => self.hash_threshold = threshold as usize, + _ => unreachable!(), } + Ok(()) } @@ -755,11 +842,6 @@ impl<D: AsyncDB> Runner<D> { .iter() .any(|c| c.should_skip(self.db.engine_name())) } - - /// Set hook functions. - pub fn set_hook(&mut self, hook: impl Hook + 'static) { - self.hook = Some(Box::new(hook)); - } } /// Trim and replace multiple whitespaces with one.