@@ -98,6 +98,21 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt)
9898 sqlite3_errmsg (conn2sql (stmt -> db -> conn )));
9999}
100100
101+ static bool is_strict_constraint_error (struct db_stmt * stmt )
102+ {
103+ sqlite3 * sql = conn2sql (stmt -> db -> conn );
104+ const char * errmsg = sqlite3_errmsg (sql );
105+ int errcode = sqlite3_errcode (sql );
106+
107+ if (errcode != SQLITE_CONSTRAINT || !stmt -> db -> use_strict_tables )
108+ return false;
109+
110+ return (strstr (errmsg , "CHECK constraint failed" ) ||
111+ strstr (errmsg , "datatype mismatch" ) ||
112+ strstr (errmsg , "cannot store" ) ||
113+ strstr (errmsg , "NOT NULL constraint failed" ));
114+ }
115+
101116static bool db_sqlite3_setup (struct db * db , bool create )
102117{
103118 char * filename ;
@@ -205,16 +220,183 @@ static bool db_sqlite3_setup(struct db *db, bool create)
205220 "PRAGMA foreign_keys = ON;" , -1 , & stmt , NULL );
206221 err = sqlite3_step (stmt );
207222 sqlite3_finalize (stmt );
208- return err == SQLITE_DONE ;
223+
224+ if (err != SQLITE_DONE )
225+ return false;
226+
227+ bool is_testing = (getenv ("TEST_DB_PROVIDER" ) ||
228+ getenv ("PYTEST_PAR" ) ||
229+ getenv ("TEST_DEBUG" ) ||
230+ getenv ("VALGRIND" ));
231+
232+ /* SQLite 3.37.0 introduced STRICT table support */
233+ if ((db -> developer || is_testing ) && sqlite3_libversion_number () >= 3037000 )
234+ db -> use_strict_tables = true;
235+
236+ {
237+ static const char * security_pragmas [] = {
238+ "PRAGMA trusted_schema = OFF;" ,
239+ "PRAGMA cell_size_check = ON;" ,
240+ "PRAGMA secure_delete = ON;" ,
241+ NULL
242+ };
243+
244+ for (int i = 0 ; security_pragmas [i ]; i ++ ) {
245+ err = sqlite3_prepare_v2 (conn2sql (db -> conn ),
246+ security_pragmas [i ], -1 , & stmt , NULL );
247+ if (err == SQLITE_OK ) {
248+ err = sqlite3_step (stmt );
249+ sqlite3_finalize (stmt );
250+ }
251+ }
252+ }
253+
254+ return true;
255+ }
256+
257+ static bool is_standalone_type_keyword (const char * query , const char * pos ,
258+ const char * keyword , size_t keyword_len ,
259+ size_t query_len )
260+ {
261+ bool prefix_ok = (pos == query || (!isalnum (pos [-1 ]) && pos [-1 ] != '_' ));
262+ const char * after = pos + keyword_len ;
263+ bool suffix_ok = (after >= query + query_len ||
264+ (!isalnum (after [0 ]) && after [0 ] != '_' ));
265+
266+ return prefix_ok && suffix_ok ;
267+ }
268+
269+ static char * normalize_varchar_to_text (const tal_t * ctx , const char * query )
270+ {
271+ char * result ;
272+ const char * src ;
273+ char * dst ;
274+ size_t query_len ;
275+
276+ if (!query )
277+ return NULL ;
278+
279+ query_len = strlen (query );
280+
281+ #define MAX_SQL_STATEMENT_LENGTH 1048576 /* 1MB limit */
282+ if (query_len > MAX_SQL_STATEMENT_LENGTH )
283+ return NULL ;
284+
285+ /* INT(3) -> INTEGER(7) worst case: +4 bytes per conversion */
286+ size_t max_expansions = (query_len / 3 ) * 4 ;
287+ size_t buffer_size = query_len + max_expansions + 64 ;
288+
289+ if (buffer_size < query_len )
290+ return NULL ;
291+
292+ result = tal_arr (ctx , char , buffer_size );
293+ src = query ;
294+ dst = result ;
295+
296+ while (* src ) {
297+ if (strncasecmp (src , "BIGSERIAL" , 9 ) == 0 &&
298+ is_standalone_type_keyword (query , src , "BIGSERIAL" , 9 , query_len )) {
299+ strcpy (dst , "INTEGER" );
300+ dst += 7 ;
301+ src += 9 ;
302+ } else if (strncasecmp (src , "VARCHAR" , 7 ) == 0 &&
303+ is_standalone_type_keyword (query , src , "VARCHAR" , 7 , query_len )) {
304+ strcpy (dst , "TEXT" );
305+ dst += 4 ;
306+ src += 7 ;
307+
308+ if (* src == '(' ) {
309+ const char * paren_start = src ;
310+ while (* src && * src != ')' ) {
311+ src ++ ;
312+ /* Prevent runaway on malformed SQL */
313+ if (src - paren_start > 1000 )
314+ return NULL ;
315+ }
316+ if (* src == ')' ) src ++ ;
317+ }
318+ } else if (strncasecmp (src , "BIGINT" , 6 ) == 0 &&
319+ is_standalone_type_keyword (query , src , "BIGINT" , 6 , query_len )) {
320+ strcpy (dst , "INTEGER" );
321+ dst += 7 ;
322+ src += 6 ;
323+ } else if (strncasecmp (src , "INT" , 3 ) == 0 &&
324+ is_standalone_type_keyword (query , src , "INT" , 3 , query_len )) {
325+ strcpy (dst , "INTEGER" );
326+ dst += 7 ;
327+ src += 3 ;
328+ } else {
329+ * dst ++ = * src ++ ;
330+ }
331+ }
332+
333+ * dst = '\0' ;
334+ return result ;
335+ }
336+
337+ static char * add_strict_to_create_table (const tal_t * ctx , const char * query )
338+ {
339+ char * semicolon_pos ;
340+ ptrdiff_t prefix_len ;
341+
342+ if (!strcasestr (query , "CREATE TABLE" ))
343+ return tal_strdup (ctx , query );
344+
345+ if (strcasestr (query , "STRICT" ))
346+ return tal_strdup (ctx , query );
347+
348+ semicolon_pos = strrchr (query , ';' );
349+ if (!semicolon_pos )
350+ semicolon_pos = (char * )query + strlen (query );
351+
352+ prefix_len = semicolon_pos - query ;
353+ return tal_fmt (ctx , "%.*s STRICT%s" , (int )prefix_len ,
354+ query , semicolon_pos );
355+ }
356+
357+ static char * prepare_query_for_execution (const tal_t * ctx , struct db * db ,
358+ const char * query )
359+ {
360+ char * normalized_query ;
361+
362+ normalized_query = normalize_varchar_to_text (ctx , query );
363+ if (!normalized_query )
364+ return NULL ;
365+
366+ if (db -> use_strict_tables )
367+ return add_strict_to_create_table (ctx , normalized_query );
368+ else
369+ return normalized_query ;
209370}
210371
211372static bool db_sqlite3_query (struct db_stmt * stmt )
212373{
213374 sqlite3_stmt * s ;
214375 sqlite3 * conn = conn2sql (stmt -> db -> conn );
215376 int err ;
377+ char * query_to_execute ;
216378
217- err = sqlite3_prepare_v2 (conn , stmt -> query -> query , -1 , & s , NULL );
379+ query_to_execute = prepare_query_for_execution (stmt , stmt -> db ,
380+ stmt -> query -> query );
381+ bool should_free_query = (query_to_execute != stmt -> query -> query );
382+
383+ err = sqlite3_prepare_v2 (conn , query_to_execute , -1 , & s , NULL );
384+
385+ if (err != SQLITE_OK ) {
386+ if (should_free_query )
387+ tal_free (query_to_execute );
388+ tal_free (stmt -> error );
389+ if (is_strict_constraint_error (stmt )) {
390+ stmt -> error = tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
391+ db_sqlite3_fmt_error (stmt ));
392+ } else {
393+ stmt -> error = db_sqlite3_fmt_error (stmt );
394+ }
395+ return false;
396+ }
397+
398+ if (should_free_query )
399+ tal_free (query_to_execute );
218400
219401 for (size_t i = 0 ; i < stmt -> query -> placeholders ; i ++ ) {
220402 struct db_binding * b = & stmt -> bindings [i ];
@@ -246,12 +428,6 @@ static bool db_sqlite3_query(struct db_stmt *stmt)
246428 }
247429 }
248430
249- if (err != SQLITE_OK ) {
250- tal_free (stmt -> error );
251- stmt -> error = db_sqlite3_fmt_error (stmt );
252- return false;
253- }
254-
255431 stmt -> inner_stmt = s ;
256432 return true;
257433}
@@ -270,7 +446,12 @@ static bool db_sqlite3_exec(struct db_stmt *stmt)
270446 err = sqlite3_step (stmt -> inner_stmt );
271447 if (err != SQLITE_DONE ) {
272448 tal_free (stmt -> error );
273- stmt -> error = db_sqlite3_fmt_error (stmt );
449+ if (is_strict_constraint_error (stmt )) {
450+ stmt -> error = tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
451+ db_sqlite3_fmt_error (stmt ));
452+ } else {
453+ stmt -> error = db_sqlite3_fmt_error (stmt );
454+ }
274455 return false;
275456 }
276457
0 commit comments