From 3757619da03ae28f0cb8abe42a6d8f9cbc41b56f Mon Sep 17 00:00:00 2001 From: Jesse Geens Date: Thu, 9 Jan 2025 14:33:38 +0100 Subject: [PATCH] Created worker pool and migrate states as well --- cmd/migrator.go | 3 +- share/sql/migrate.go | 197 ++++++++++++++++++++++++++++++++----------- 2 files changed, 149 insertions(+), 51 deletions(-) diff --git a/cmd/migrator.go b/cmd/migrator.go index 14d312f..f7c1973 100644 --- a/cmd/migrator.go +++ b/cmd/migrator.go @@ -15,9 +15,10 @@ func main() { name := flag.String("name", "cernboxngcopy", "Database name") gatewaysvc := flag.String("gatewaysvc", "localhost:9142", "Gateway service location") token := flag.String("token", "", "JWT token for gateway svc") + dryRun := flag.Bool("dryrun", true, "Use dry run?") flag.Parse() fmt.Printf("Connecting to %s@%s:%d\n", *username, *host, *port) - sql.RunMigration(*username, *password, *host, *name, *gatewaysvc, *token, *port) + sql.RunMigration(*username, *password, *host, *name, *gatewaysvc, *token, *port, *dryRun) } diff --git a/share/sql/migrate.go b/share/sql/migrate.go index 21a8bea..16186cd 100644 --- a/share/sql/migrate.go +++ b/share/sql/migrate.go @@ -29,7 +29,42 @@ type ShareOrLink struct { Link *model.PublicLink } -func RunMigration(username, password, host, name, gatewaysvc, token string, port int) { +type OldShareEntry struct { + ID int + UIDOwner string + UIDInitiator string + Prefix string + ItemSource string + ItemType string + ShareWith string + Token string + Expiration string + Permissions int + ShareType int + ShareName string + STime int + FileTarget string + State int + Quicklink bool + Description string + NotifyUploads bool + NotifyUploadsExtraRecipients sql.NullString + Orphan bool +} + +type OldShareState struct { + id int + recipient string + state int +} + +const ( + bufferSize = 10 + numWorkers = 10 +) + +func RunMigration(username, password, host, name, gatewaysvc, token string, port int, dryRun bool) { + // Config config := map[string]interface{}{ "engine": "mysql", "db_username": username, @@ -38,13 +73,15 @@ func RunMigration(username, password, host, name, gatewaysvc, token string, port "db_port": port, "db_name": name, "gatewaysvc": gatewaysvc, - "dry_run": false, + "dry_run": dryRun, } + // Authenticate to gateway service tokenlessCtx, cancel := context.WithCancel(context.Background()) ctx := appctx.ContextSetToken(tokenlessCtx, token) ctx = metadata.AppendToOutgoingContext(ctx, appctx.TokenHeader, token) defer cancel() + // Set up migrator shareManager, err := New(ctx, config) if err != nil { fmt.Println("Failed to create shareManager: " + err.Error()) @@ -62,31 +99,26 @@ func RunMigration(username, password, host, name, gatewaysvc, token string, port ShareMgr: sharemgr, } - ch := make(chan *ShareOrLink, 100) - go getAllShares(ctx, migrator, ch) - for share := range ch { - // TODO error handling - if share.IsShare { - fmt.Printf("Creating share %d\n", share.Share.ID) - migrator.NewDb.Create(&share.Share) - } else { - fmt.Printf("Creating share %d\n", share.Link.ID) - migrator.NewDb.Create(&share.Link) - } + if dryRun { + migrator.NewDb = migrator.NewDb.Debug() } + migrateShares(ctx, migrator) + fmt.Println("---------------------------------") + migrateShareStatuses(ctx, migrator) + } -func getAllShares(ctx context.Context, migrator Migrator, ch chan *ShareOrLink) { - // First we find out what the highest ID is - count, err := getCount(migrator) +func migrateShares(ctx context.Context, migrator Migrator) { + // Check how many shares are to be migrated + count, err := getCount(migrator, "oc_share") if err != nil { - fmt.Println("Error getting highest id: " + err.Error()) - close(ch) + fmt.Println("Error getting count: " + err.Error()) return } fmt.Printf("Migrating %d shares\n", count) + // Get all old shares query := "select id, coalesce(uid_owner, '') as uid_owner, coalesce(uid_initiator, '') as uid_initiator, lower(coalesce(share_with, '')) as share_with, coalesce(fileid_prefix, '') as fileid_prefix, coalesce(item_source, '') as item_source, coalesce(item_type, '') as item_type, stime, permissions, share_type, orphan FROM oc_share order by id desc" // AND id=?" params := []interface{}{} @@ -94,48 +126,113 @@ func getAllShares(ctx context.Context, migrator Migrator, ch chan *ShareOrLink) if err != nil { fmt.Printf("Fatal error: %s", err.Error()) - close(ch) - return + os.Exit(1) + } + + // Create channel for workers + ch := make(chan *OldShareEntry, bufferSize) + defer close(ch) + + // Start all workers + for range numWorkers { + go workerShare(ctx, migrator, ch) } for res.Next() { var s OldShareEntry res.Scan(&s.ID, &s.UIDOwner, &s.UIDInitiator, &s.ShareWith, &s.Prefix, &s.ItemSource, &s.ItemType, &s.STime, &s.Permissions, &s.ShareType, &s.Orphan) - newShare, err := oldShareToNewShare(ctx, migrator, s) if err == nil { - ch <- newShare + ch <- &s } else { - fmt.Printf("Error occured for share %s: %s\n", s.ID, err.Error()) + fmt.Printf("Error occured for share %d: %s\n", s.ID, err.Error()) } } +} + +func migrateShareStatuses(ctx context.Context, migrator Migrator) { + // Check how many shares are to be migrated + count, err := getCount(migrator, "oc_share_status") + if err != nil { + fmt.Println("Error getting count: " + err.Error()) + return + } + fmt.Printf("Migrating %d share statuses\n", count) + + // Get all old shares + query := "select id, coalesce(recipient, '') as recipient, state FROM oc_share_status order by id desc" + params := []interface{}{} - close(ch) + res, err := migrator.OldDb.Query(query, params...) + + if err != nil { + fmt.Printf("Fatal error: %s", err.Error()) + os.Exit(1) + } + + // Create channel for workers + ch := make(chan *OldShareState, bufferSize) + defer close(ch) + + // Start all workers + for range numWorkers { + go workerState(ctx, migrator, ch) + } + + for res.Next() { + var s OldShareState + res.Scan(&s.id, &s.recipient, &s.state) + if err == nil { + ch <- &s + } else { + fmt.Printf("Error occured for share status%d: %s\n", s.id, err.Error()) + } + } } -type OldShareEntry struct { - ID int - UIDOwner string - UIDInitiator string - Prefix string - ItemSource string - ItemType string - ShareWith string - Token string - Expiration string - Permissions int - ShareType int - ShareName string - STime int - FileTarget string - State int - Quicklink bool - Description string - NotifyUploads bool - NotifyUploadsExtraRecipients sql.NullString - Orphan bool +func workerShare(ctx context.Context, migrator Migrator, ch chan *OldShareEntry) { + for share := range ch { + handleSingleShare(ctx, migrator, share) + } +} + +func workerState(ctx context.Context, migrator Migrator, ch chan *OldShareState) { + for state := range ch { + handleSingleState(ctx, migrator, state) + } +} + +func handleSingleShare(ctx context.Context, migrator Migrator, s *OldShareEntry) { + share, err := oldShareToNewShare(ctx, migrator, s) + if err != nil { + return + } + // TODO error handling + if share.IsShare { + migrator.NewDb.Create(&share.Share) + } else { + migrator.NewDb.Create(&share.Link) + } +} + +func handleSingleState(ctx context.Context, migrator Migrator, s *OldShareState) { + // case collaboration.ShareState_SHARE_STATE_REJECTED: + // state = -1 + // case collaboration.ShareState_SHARE_STATE_ACCEPTED: + // state = 1 + + newShareState := &model.ShareState{ + ShareID: uint(s.id), + Model: gorm.Model{ + ID: uint(s.id), + }, + User: s.recipient, + Hidden: s.state == -1, // Hidden if REJECTED + Synced: true, // for now, we always sync? or not? TODO + } + migrator.NewDb.Create(&newShareState) } -func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) (*ShareOrLink, error) { +func oldShareToNewShare(ctx context.Context, migrator Migrator, s *OldShareEntry) (*ShareOrLink, error) { expirationDate, expirationError := time.Parse("2006-01-02 15:04:05", s.Expiration) protoShare := model.ProtoShare{ @@ -146,7 +243,6 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) }, UIDOwner: s.UIDOwner, UIDInitiator: s.UIDInitiator, - Description: s.Description, Permissions: uint8(s.Permissions), Orphan: s.Orphan, // will be re-checked later Expiration: datatypes.Null[time.Time]{ @@ -171,7 +267,7 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) protoShare.Orphan = true } else { // We do not set, because of a general error - fmt.Printf("An error occured while statting (%s, %s): %s\n", protoShare.Instance, protoShare.Inode, err.Error()) + fmt.Printf("An error occured for share %d while statting (%s, %s): %s\n", s.ID, protoShare.Instance, protoShare.Inode, err.Error()) } } @@ -187,6 +283,7 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) ProtoShare: protoShare, ShareWith: s.ShareWith, SharedWithIsGroup: s.ShareType == 1, + Description: s.Description, }, }, nil } else if s.ShareType == 3 { @@ -211,9 +308,9 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) } } -func getCount(migrator Migrator) (int, error) { +func getCount(migrator Migrator, table string) (int, error) { res := 0 - query := "select count(*) from oc_share" + query := "select count(*) from " + table params := []interface{}{} if err := migrator.OldDb.QueryRow(query, params...).Scan(&res); err != nil {