@@ -451,72 +451,94 @@ func StdCreateSchema(ctx context.Context, db *sql.DB, dialect Dialect, schemaNam
451
451
return err
452
452
}
453
453
454
- type ColumnMigrationStep func (dialect Dialect , table Table , migration ColumnTypeMigration , tempColumnIdentifier string ) (string , error )
454
+ type MigrationInstruction struct {
455
+ TypeMigration ColumnTypeMigration
456
+ TempColumnIdentifier string
457
+ }
458
+
459
+ type ColumnMigrationStep func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error )
455
460
456
461
var StdMigrationSteps = []ColumnMigrationStep {
457
- func (dialect Dialect , table Table , migration ColumnTypeMigration , tempColumnIdentifier string ) (string , error ) {
458
- return fmt .Sprintf ("ALTER TABLE %s ADD COLUMN %s %s;" ,
459
- table .Identifier ,
460
- tempColumnIdentifier ,
461
- // Always create these new columns as nullable
462
- migration .NullableDDL ,
463
- ), nil
462
+ func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error ) {
463
+ var queries []string
464
+ for _ , ins := range instructions {
465
+ queries = append (
466
+ queries ,
467
+ fmt .Sprintf ("ALTER TABLE %s ADD COLUMN %s %s " ,
468
+ table .Identifier ,
469
+ ins .TempColumnIdentifier ,
470
+ ins .TypeMigration .NullableDDL ,
471
+ ),
472
+ )
473
+ }
474
+
475
+ return queries , nil
464
476
},
465
- func (dialect Dialect , table Table , migration ColumnTypeMigration , tempColumnIdentifier string ) (string , error ) {
466
- return fmt .Sprintf (
467
- // The WHERE filter is required by some warehouses (bigquery)
468
- "UPDATE %s SET %s = %s WHERE true;" ,
469
- table .Identifier ,
470
- tempColumnIdentifier ,
471
- migration .CastSQL (migration ),
472
- ), nil
477
+ func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error ) {
478
+ var query strings.Builder
479
+ query .WriteString (fmt .Sprintf ("UPDATE %s SET " , table .Identifier ))
480
+
481
+ for i , ins := range instructions {
482
+ if i > 0 {
483
+ query .WriteString (", " )
484
+ }
485
+ query .WriteString (fmt .Sprintf ("%s = %s" , ins .TempColumnIdentifier , ins .TypeMigration .CastSQL (ins .TypeMigration )))
486
+ }
487
+
488
+ // The WHERE filter is required by some warehouses (bigquery)
489
+ query .WriteString (" WHERE true;" )
490
+
491
+ return []string {query .String ()}, nil
473
492
},
474
- func (dialect Dialect , table Table , migration ColumnTypeMigration , _ string ) (string , error ) {
475
- return fmt .Sprintf (
476
- "ALTER TABLE %s DROP COLUMN %s;" ,
477
- table .Identifier ,
478
- migration .Identifier ,
479
- ), nil
493
+ func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error ) {
494
+ var queries []string
495
+ for _ , ins := range instructions {
496
+ queries = append (
497
+ queries ,
498
+ fmt .Sprintf ("ALTER TABLE %s DROP COLUMN %s" ,
499
+ table .Identifier ,
500
+ ins .TypeMigration .Identifier ,
501
+ ),
502
+ )
503
+ }
504
+
505
+ return queries , nil
480
506
},
481
- func (dialect Dialect , table Table , migration ColumnTypeMigration , tempColumnIdentifier string ) (string , error ) {
482
- return fmt .Sprintf (
483
- "ALTER TABLE %s RENAME COLUMN %s TO %s;" ,
484
- table .Identifier ,
485
- tempColumnIdentifier ,
486
- migration .Identifier ,
487
- ), nil
507
+ func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error ) {
508
+ var queries []string
509
+ for _ , ins := range instructions {
510
+ queries = append (queries ,
511
+ fmt .Sprintf ("ALTER TABLE %s RENAME COLUMN %s TO %s" ,
512
+ table .Identifier ,
513
+ ins .TempColumnIdentifier ,
514
+ ins .TypeMigration .Identifier ,
515
+ ),
516
+ )
517
+ }
518
+
519
+ return queries , nil
488
520
},
489
- func (dialect Dialect , table Table , migration ColumnTypeMigration , _ string ) (string , error ) {
490
- // If column was originally not nullable, we fix its DDL
491
- if migration .NullableDDL == migration .DDL {
492
- return "" , nil
521
+ func (dialect Dialect , table Table , instructions ... MigrationInstruction ) ([]string , error ) {
522
+ var queries []string
523
+
524
+ for _ , ins := range instructions {
525
+ if ins .TypeMigration .NullableDDL != ins .TypeMigration .DDL {
526
+ queries = append (
527
+ queries ,
528
+ fmt .Sprintf ("ALTER TABLE %s ALTER COLUMN %s SET NOT NULL" ,
529
+ table .Identifier ,
530
+ ins .TypeMigration .Identifier ,
531
+ ),
532
+ )
533
+ }
493
534
}
494
535
495
- return fmt .Sprintf (
496
- "ALTER TABLE %s ALTER COLUMN %s SET NOT NULL;" ,
497
- table .Identifier ,
498
- migration .Identifier ,
499
- ), nil
536
+ return queries , nil
500
537
},
501
538
}
502
539
503
- func StdColumnTypeMigration (ctx context.Context , dialect Dialect , table Table , migration ColumnTypeMigration , steps ... ColumnMigrationStep ) ([]string , error ) {
504
- var step = 0
505
- if migration .ProgressColumnExists && migration .OriginalColumnExists {
506
- step = 1
507
- } else if migration .ProgressColumnExists && ! migration .OriginalColumnExists {
508
- step = 3
509
- }
510
-
511
- log .WithFields (log.Fields {
512
- "table" : table .Identifier ,
513
- "ddl" : migration .DDL ,
514
- "field" : migration .Field ,
515
- "originalColumnExists" : migration .OriginalColumnExists ,
516
- "progressColumnExists" : migration .ProgressColumnExists ,
517
- "step" : step ,
518
- }).Info ("rendered queries for column migration using renaming" )
519
-
540
+ func StdColumnTypeMigrations (ctx context.Context , dialect Dialect , table Table , migrations []ColumnTypeMigration , steps ... ColumnMigrationStep ) ([]string , error ) {
541
+ // Connectors can provide custom steps, if they don't, we default to std steps
520
542
if len (steps ) == 0 {
521
543
steps = StdMigrationSteps
522
544
}
@@ -525,19 +547,45 @@ func StdColumnTypeMigration(ctx context.Context, dialect Dialect, table Table, m
525
547
return nil , fmt .Errorf ("must have at least %d steps" , len (StdMigrationSteps ))
526
548
}
527
549
528
- var tempColumnIdentifier = dialect . Identifier ( migration . Field + ColumnMigrationTemporarySuffix )
550
+ var stepInstructions = make ( map [ int ][] MigrationInstruction )
529
551
530
- var renderedSteps []string
531
- for i , s := range steps [step :] {
532
- newStep , err := s (dialect , table , migration , tempColumnIdentifier )
533
- if err != nil {
534
- return nil , fmt .Errorf ("rendering step %d: %w" , i , err )
552
+ for _ , migration := range migrations {
553
+ var step = 0
554
+ if migration .ProgressColumnExists && migration .OriginalColumnExists {
555
+ step = 1
556
+ } else if migration .ProgressColumnExists && ! migration .OriginalColumnExists {
557
+ step = 3
535
558
}
536
- if newStep == "" {
537
- continue
559
+
560
+ log .WithFields (log.Fields {
561
+ "table" : table .Identifier ,
562
+ "ddl" : migration .DDL ,
563
+ "field" : migration .Field ,
564
+ "originalColumnExists" : migration .OriginalColumnExists ,
565
+ "progressColumnExists" : migration .ProgressColumnExists ,
566
+ "step" : step ,
567
+ }).Info ("rendering queries for column migration using renaming" )
568
+
569
+ var tempColumnIdentifier = dialect .Identifier (migration .Field + ColumnMigrationTemporarySuffix )
570
+
571
+ var instruction = MigrationInstruction {
572
+ TypeMigration : migration ,
573
+ TempColumnIdentifier : tempColumnIdentifier ,
538
574
}
539
575
540
- renderedSteps = append (renderedSteps , newStep )
576
+ stepInstructions [step ] = append (stepInstructions [step ], instruction )
577
+ }
578
+
579
+ var renderedSteps []string
580
+ for step , instructions := range stepInstructions {
581
+ for i , s := range steps [step :] {
582
+ newStep , err := s (dialect , table , instructions ... )
583
+ if err != nil {
584
+ return nil , fmt .Errorf ("rendering step %d: %w" , i , err )
585
+ }
586
+
587
+ renderedSteps = append (renderedSteps , newStep ... )
588
+ }
541
589
}
542
590
543
591
return renderedSteps , nil
0 commit comments