@@ -12,6 +12,7 @@ use futures::executor::block_on;
12
12
use futures:: { stream, Future , StreamExt } ;
13
13
use itertools:: Itertools ;
14
14
use owo_colors:: OwoColorize ;
15
+ use regex:: Regex ;
15
16
use tempfile:: { tempdir, TempDir } ;
16
17
17
18
use crate :: parser:: * ;
@@ -55,6 +56,7 @@ impl TryFrom<char> for ColumnType {
55
56
}
56
57
}
57
58
59
+ #[ derive( Debug , Clone ) ]
58
60
pub enum RecordOutput {
59
61
Nothing ,
60
62
Query {
@@ -391,8 +393,18 @@ fn format_diff(
391
393
///
392
394
/// # Default
393
395
///
394
- /// By default, we will use compare normalized results.
395
- pub type Validator = fn ( & Vec < Vec < String > > , & Vec < String > ) -> bool ;
396
+ /// By default ([`default_validator`]), we will use compare normalized results.
397
+ pub type Validator = fn ( actual : & [ Vec < String > ] , expected : & [ String ] ) -> bool ;
398
+
399
+ pub fn default_validator ( actual : & [ Vec < String > ] , expected : & [ String ] ) -> bool {
400
+ let expected_results = expected. iter ( ) . map ( normalize_string) . collect_vec ( ) ;
401
+ // Default, we compare normalized results. Whitespace characters are ignored.
402
+ let normalized_rows = actual
403
+ . iter ( )
404
+ . map ( |strs| strs. iter ( ) . map ( normalize_string) . join ( " " ) )
405
+ . collect_vec ( ) ;
406
+ normalized_rows == expected_results
407
+ }
396
408
397
409
/// Sqllogictest runner.
398
410
pub struct Runner < D : AsyncDB > {
@@ -410,15 +422,7 @@ impl<D: AsyncDB> Runner<D> {
410
422
pub fn new ( db : D ) -> Self {
411
423
Runner {
412
424
db,
413
- validator : |x, y| {
414
- let expected_results = y. iter ( ) . map ( normalize_string) . collect_vec ( ) ;
415
- // Default, we compare normalized results. Whitespace characters are ignored.
416
- let normalized_rows = x
417
- . iter ( )
418
- . map ( |strs| strs. iter ( ) . map ( normalize_string) . join ( " " ) )
419
- . collect_vec ( ) ;
420
- normalized_rows == expected_results
421
- } ,
425
+ validator : default_validator,
422
426
testdir : None ,
423
427
sort_mode : None ,
424
428
hash_threshold : 0 ,
@@ -859,3 +863,301 @@ impl<D: AsyncDB> Runner<D> {
859
863
fn normalize_string ( s : & String ) -> String {
860
864
s. trim ( ) . split_ascii_whitespace ( ) . join ( " " )
861
865
}
866
+
867
+ /// Updates the specified [`Record`] with the [`QueryOutput`] produced
868
+ /// by a Database, returning `Some(new_record)`.
869
+ ///
870
+ /// If an update is not supported, returns `None`
871
+ pub fn update_record_with_output (
872
+ record : & Record ,
873
+ record_output : & RecordOutput ,
874
+ col_separator : & str ,
875
+ validator : Validator ,
876
+ ) -> Option < Record > {
877
+ match ( record. clone ( ) , record_output) {
878
+ ( _, RecordOutput :: Nothing ) => None ,
879
+ // statement, query
880
+ (
881
+ Record :: Statement {
882
+ sql,
883
+ loc,
884
+ conditions,
885
+ expected_error : None ,
886
+ expected_count,
887
+ } ,
888
+ RecordOutput :: Query { error : None , .. } ,
889
+ ) => {
890
+ // statement ok
891
+ // SELECT ...
892
+ //
893
+ // This case can be used when we want to only ensure the query succeeds,
894
+ // but don't care about the output.
895
+ // DuckDB has a few of these.
896
+
897
+ Some ( Record :: Statement {
898
+ sql,
899
+ expected_error : None ,
900
+ loc,
901
+ conditions,
902
+ expected_count,
903
+ } )
904
+ }
905
+ // query, statement
906
+ (
907
+ Record :: Query {
908
+ sql,
909
+ loc,
910
+ conditions,
911
+ ..
912
+ } ,
913
+ RecordOutput :: Statement { error : None , .. } ,
914
+ ) => Some ( Record :: Statement {
915
+ sql,
916
+ expected_error : None ,
917
+ loc,
918
+ conditions,
919
+ expected_count : None ,
920
+ } ) ,
921
+ // statement, statement
922
+ (
923
+ Record :: Statement {
924
+ loc,
925
+ conditions,
926
+ expected_error,
927
+ sql,
928
+ expected_count,
929
+ } ,
930
+ RecordOutput :: Statement { count, error } ,
931
+ ) => match ( error, expected_error) {
932
+ // Ok
933
+ ( None , _) => Some ( Record :: Statement {
934
+ sql,
935
+ expected_error : None ,
936
+ loc,
937
+ conditions,
938
+ expected_count : expected_count. map ( |_| * count) ,
939
+ } ) ,
940
+ // Error match
941
+ ( Some ( e) , Some ( expected_error) ) if expected_error. is_match ( & e. to_string ( ) ) => {
942
+ Some ( Record :: Statement {
943
+ sql,
944
+ expected_error : Some ( expected_error) ,
945
+ loc,
946
+ conditions,
947
+ expected_count : None ,
948
+ } )
949
+ }
950
+ // Error mismatch
951
+ ( Some ( e) , _) => Some ( Record :: Statement {
952
+ sql,
953
+ expected_error : Some ( Regex :: new ( & e. to_string ( ) ) . unwrap ( ) ) ,
954
+ loc,
955
+ conditions,
956
+ expected_count : None ,
957
+ } ) ,
958
+ } ,
959
+ // query, query
960
+ (
961
+ Record :: Query {
962
+ loc,
963
+ conditions,
964
+ type_string,
965
+ sort_mode,
966
+ label,
967
+ expected_error,
968
+ sql,
969
+ expected_results,
970
+ } ,
971
+ RecordOutput :: Query {
972
+ // FIXME: maybe we should use output's types instead of orignal query's types
973
+ // Fix it after https://github.com/risinglightdb/sqllogictest-rs/issues/36 is resolved.
974
+ types : _,
975
+ rows,
976
+ error,
977
+ } ,
978
+ ) => {
979
+ match ( error, expected_error) {
980
+ ( None , _) => { }
981
+ // Error match
982
+ ( Some ( e) , Some ( expected_error) ) if expected_error. is_match ( & e. to_string ( ) ) => {
983
+ return Some ( Record :: Query {
984
+ sql,
985
+ expected_error : Some ( expected_error) ,
986
+ loc,
987
+ conditions,
988
+ type_string : vec ! [ ] ,
989
+ sort_mode,
990
+ label,
991
+ expected_results : vec ! [ ] ,
992
+ } )
993
+ }
994
+ // Error mismatch
995
+ ( Some ( e) , _) => {
996
+ return Some ( Record :: Query {
997
+ sql,
998
+ expected_error : Some ( Regex :: new ( & e. to_string ( ) ) . unwrap ( ) ) ,
999
+ loc,
1000
+ conditions,
1001
+ type_string : vec ! [ ] ,
1002
+ sort_mode,
1003
+ label,
1004
+ expected_results : vec ! [ ] ,
1005
+ } )
1006
+ }
1007
+ } ;
1008
+
1009
+ let results = if validator ( rows, & expected_results) {
1010
+ // If validation is successful, we respect the original file's expected results.
1011
+ expected_results
1012
+ } else {
1013
+ rows. iter ( ) . map ( |cols| cols. join ( col_separator) ) . collect ( )
1014
+ } ;
1015
+
1016
+ Some ( Record :: Query {
1017
+ sql,
1018
+ expected_error : None ,
1019
+ loc,
1020
+ conditions,
1021
+ type_string,
1022
+ sort_mode,
1023
+ label,
1024
+ expected_results : results,
1025
+ } )
1026
+ }
1027
+
1028
+ // No update possible, return the original record
1029
+ _ => None ,
1030
+ }
1031
+ }
1032
+
1033
+ #[ cfg( test) ]
1034
+ mod tests {
1035
+ use super :: * ;
1036
+
1037
+ #[ test]
1038
+ fn test_query_replacement ( ) {
1039
+ TestCase {
1040
+ // input should be ignored
1041
+ input : "query III\n \
1042
+ select * from foo;\n \
1043
+ ----\n \
1044
+ 1 2",
1045
+
1046
+ // Model a run that produced a 3,4 as output
1047
+ record_output : query_output ( & [ & [ "3" , "4" ] ] ) ,
1048
+
1049
+ expected : Some (
1050
+ "query III\n \
1051
+ select * from foo;\n \
1052
+ ----\n \
1053
+ 3 4",
1054
+ ) ,
1055
+ }
1056
+ . run ( )
1057
+ }
1058
+
1059
+ #[ test]
1060
+ fn test_query_replacement_no_input ( ) {
1061
+ TestCase {
1062
+ // input has no query results
1063
+ input : "query III\n \
1064
+ select * from foo;\n \
1065
+ ----",
1066
+
1067
+ // Model a run that produced a 3,4 as output
1068
+ record_output : query_output ( & [ & [ "3" , "4" ] ] ) ,
1069
+
1070
+ expected : Some (
1071
+ "query III\n \
1072
+ select * from foo;\n \
1073
+ ----\n \
1074
+ 3 4",
1075
+ ) ,
1076
+ }
1077
+ . run ( )
1078
+ }
1079
+
1080
+ #[ test]
1081
+ fn test_query_replacement_error ( ) {
1082
+ TestCase {
1083
+ // input has no query results
1084
+ input : "query III\n \
1085
+ select * from foo;\n \
1086
+ ----",
1087
+
1088
+ // Model a run that produced a "MyAwesomeDB Error"
1089
+ record_output : query_error ( "MyAwesomeDB Error" ) ,
1090
+
1091
+ expected : Some (
1092
+ "query error TestError: MyAwesomeDB Error\n \
1093
+ select * from foo;\n ",
1094
+ ) ,
1095
+ }
1096
+ . run ( )
1097
+ }
1098
+
1099
+ #[ derive( Debug ) ]
1100
+ struct TestCase {
1101
+ input : & ' static str ,
1102
+ record_output : RecordOutput ,
1103
+ expected : Option < & ' static str > ,
1104
+ }
1105
+
1106
+ impl TestCase {
1107
+ fn run ( self ) {
1108
+ let Self {
1109
+ input,
1110
+ record_output,
1111
+ expected,
1112
+ } = self ;
1113
+ println ! ( "TestCase" ) ;
1114
+ println ! ( "**input:\n {input}\n " ) ;
1115
+ println ! ( "**record_output:\n {record_output:#?}\n " ) ;
1116
+ println ! ( "**expected:\n {}\n " , expected. unwrap_or( "" ) ) ;
1117
+ let input = parse_to_record ( input) ;
1118
+ let expected = expected. map ( parse_to_record) ;
1119
+ let output = update_record_with_output ( & input, & record_output, " " , default_validator) ;
1120
+ assert_eq ! ( output, expected) ;
1121
+ }
1122
+ }
1123
+
1124
+ fn parse_to_record ( s : & str ) -> Record {
1125
+ let mut records = parse ( s) . unwrap ( ) ;
1126
+ assert_eq ! ( records. len( ) , 1 ) ;
1127
+ records. pop ( ) . unwrap ( )
1128
+ }
1129
+
1130
+ /// Returns a RecordOutput that models the successful execution of a query
1131
+ fn query_output ( rows : & [ & [ & str ] ] ) -> RecordOutput {
1132
+ let rows = rows
1133
+ . iter ( )
1134
+ . map ( |cols| cols. iter ( ) . map ( |c| c. to_string ( ) ) . collect :: < Vec < _ > > ( ) )
1135
+ . collect :: < Vec < _ > > ( ) ;
1136
+
1137
+ let types = rows. iter ( ) . map ( |_| ColumnType :: Any ) . collect ( ) ;
1138
+
1139
+ RecordOutput :: Query {
1140
+ types,
1141
+ rows,
1142
+ error : None ,
1143
+ }
1144
+ }
1145
+
1146
+ /// Returns a RecordOutput that models the error of a query
1147
+ fn query_error ( error_message : & str ) -> RecordOutput {
1148
+ RecordOutput :: Query {
1149
+ types : vec ! [ ] ,
1150
+ rows : vec ! [ ] ,
1151
+ error : Some ( Arc :: new ( TestError ( error_message. to_string ( ) ) ) ) ,
1152
+ }
1153
+ }
1154
+
1155
+ #[ derive( Debug ) ]
1156
+ struct TestError ( String ) ;
1157
+ impl std:: error:: Error for TestError { }
1158
+ impl std:: fmt:: Display for TestError {
1159
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
1160
+ write ! ( f, "TestError: {}" , self . 0 )
1161
+ }
1162
+ }
1163
+ }
0 commit comments