diff --git a/docs/proposals/MULTI_PROVIDER_ACCOUNT_BINDING_DESIGN.md b/docs/proposals/MULTI_PROVIDER_ACCOUNT_BINDING_DESIGN.md new file mode 100644 index 0000000..3f473fe --- /dev/null +++ b/docs/proposals/MULTI_PROVIDER_ACCOUNT_BINDING_DESIGN.md @@ -0,0 +1,552 @@ +# 多 Provider 账号统一绑定设计稿 + +## 1. 背景与目标 + +当前系统已支持 Casdoor 下多个 provider 登录,并已完成以下基础能力: + +1. Casdoor JWT / userinfo 归一化 +2. `properties.oauth_*` provider 原始资料提取 +3. `external_key` 稳定身份键生成 +4. 本地 `users` 表中主身份摘要字段持久化 + +但现有实现仍主要围绕“一个本地 user 对应一个主外部身份”展开。对于同一自然用户混合使用: + +- Github +- idtrust +- phone + +等多种 provider 登录的场景,当前模型无法完整表达“一个本地用户绑定多个外部身份”的关系,因此需要新增统一绑定能力。 + +本设计稿明确采用: + +> **方案 B:显式绑定** + +即系统不自动猜测和合并不同 provider 的账号,而要求用户在已登录状态下显式发起绑定。系统侧仅负责: + +1. 保存绑定关系 +2. 在后续登录时解析绑定关系并落到同一个本地用户 +3. 按 provider 优先级自动升级主资料源 + +## 2. 非目标 + +本方案当前不包括: + +1. 不支持自动账号归并 +2. 不提供“用户手工切换主资料源”功能 +3. 不处理两个已存在本地用户的自动合并 +4. 不支持管理员后台强制合并账号 + +## 3. 业务策略 + +## 3.1 采用显式绑定 + +只有在以下条件满足时,系统才允许建立新的 provider 绑定: + +1. 当前用户已经登录 +2. 当前用户主动发起绑定流程 +3. 第二个 provider OAuth / 登录校验完整成功 +4. 回调 state 验证通过 + +## 3.2 主资料源自动升级 + +系统不提供“主资料源”手工配置入口,而采用固定 provider 优先级自动升级: + +```text +idtrust > github > phone +``` + +规则如下: + +1. 用户首次登录创建账号时,首条 identity 自动成为 primary identity +2. 后续绑定新 identity 时: + - 若新 provider 优先级更高,则自动切换为 primary identity + - 若优先级相同或更低,则保持原 primary identity +3. 解绑当前 primary identity 时,系统从剩余 identity 中重新选择优先级最高的一条作为新的 primary identity + +## 3.3 字段级资料聚合 + +自动升级 primary identity 不等于每个字段都直接被最后绑定的 provider 覆盖。建议采用以下字段策略: + +| 字段 | 规则 | +|------|------| +| `display_name` | 优先 primary identity 的显示名 | +| `avatar_url` | 优先 primary identity;若为空则优先 Github 头像 | +| `email` | 仅使用合法邮箱;优先 primary identity,再 fallback 其他 identity | +| `phone` | 优先 phone provider,再 fallback primary identity 中合法手机号 | +| `organization` | 优先 primary identity,尤其 idtrust | +| `username` | 尽量稳定,不随 provider 自动频繁变更 | + +## 4. 数据模型设计 + +## 4.1 现有 `users` 表定位调整 + +当前 `users` 表中的以下字段: + +- `auth_provider` +- `external_key` +- `provider_user_id` +- `phone` + +在引入多 provider 绑定后,不再表示“用户唯一真实身份”,而调整为: + +> **当前 primary identity 的摘要字段 / 聚合展示字段** + +因此 `users` 表继续承担: + +1. 本地统一用户实体 +2. 聚合后的稳定展示资料 +3. 当前主身份摘要 + +## 4.2 新增 `user_auth_identities` 表 + +建议新增模型: + +```go +type UserAuthIdentity struct { + ID uint `gorm:"primaryKey"` + UserSubjectID string `gorm:"index;not null"` + + Provider string `gorm:"size:64;not null"` + Issuer string `gorm:"size:255"` + ExternalKey string `gorm:"uniqueIndex;size:255;not null"` + ExternalSubject string `gorm:"size:191"` + ExternalUserID string `gorm:"size:191"` + ProviderUserID string `gorm:"size:191"` + + DisplayName *string `gorm:"size:191"` + Email *string `gorm:"size:191"` + Phone *string `gorm:"size:64"` + AvatarURL *string `gorm:"type:text"` + Organization *string `gorm:"size:191"` + + IsPrimary bool `gorm:"not null;default:false"` + LastLoginAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time +} +``` + +## 4.3 关键索引与约束 + +建议: + +1. `unique index idx_user_auth_identities_external_key (external_key)` +2. `index idx_user_auth_identities_user_subject_id (user_subject_id)` + +业务约束: + +1. 一个 `external_key` 只能绑定一个本地用户 +2. 一个本地用户可以有多条外部身份记录 +3. 同一 `user_subject_id` 在任意时刻只能有一条 `is_primary = true` + +## 5. Provider 优先级规则 + +建议新增统一优先级函数: + +```go +func ProviderRank(provider string) int { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "idtrust": + return 300 + case "github": + return 200 + case "phone": + return 100 + default: + return 0 + } +} +``` + +使用原则: + +1. rank 越高,优先级越高 +2. 新绑定 identity 的 rank 高于当前 primary identity 时,自动升级 +3. rank 相同则保持原 primary,不做切换 + +## 6. 核心服务设计 + +建议新增独立服务,例如: + +- `server/internal/authidentity/service.go` +或 +- `server/internal/useridentity/service.go` + +## 6.1 服务接口建议 + +```go +type AuthIdentityService struct { + db *gorm.DB +} + +func (s *AuthIdentityService) ResolveOrCreateUserByIdentity(identity *NormalizedIdentity) (*models.User, error) + +func (s *AuthIdentityService) BindIdentityToUser(userSubjectID string, identity *NormalizedIdentity) error + +func (s *AuthIdentityService) ListUserIdentities(userSubjectID string) ([]models.UserAuthIdentity, error) + +func (s *AuthIdentityService) UnbindIdentity(userSubjectID string, identityID uint) error + +func (s *AuthIdentityService) RefreshUserProfileFromIdentities(userSubjectID string) error +``` + +## 6.2 方法职责 + +### `ResolveOrCreateUserByIdentity` + +用于登录主流程: + +1. 先按 `external_key` 查 `user_auth_identities` +2. 命中则找到对应本地 user +3. 未命中则创建新 user + 首条 identity +4. 更新 `last_login_at` +5. 必要时刷新聚合用户资料 + +### `BindIdentityToUser` + +用于绑定流程: + +1. 检查该 `external_key` 是否已存在 +2. 若不存在,则绑定到当前 user +3. 若已绑定当前 user,则幂等成功 +4. 若已绑定其他 user,则返回冲突错误 +5. 比较 provider 优先级,必要时自动升级 primary identity +6. 刷新用户聚合资料 + +### `ListUserIdentities` + +返回当前用户所有已绑定 identity 列表,供前端账号设置页展示。 + +### `UnbindIdentity` + +解绑规则: + +1. 确认 identity 属于当前 user +2. 若这是唯一绑定方式,则禁止解绑 +3. 若解绑的是当前 primary identity,则重新选出新的 primary identity +4. 解绑后刷新用户聚合资料 + +### `RefreshUserProfileFromIdentities` + +该方法是“主资料源自动升级 + 字段聚合”的核心。 + +它负责: + +1. 查询用户全部 identity +2. 找出 `is_primary = true` 的 identity +3. 若没有 primary,则按 provider rank 选出一条 +4. 根据字段规则重新计算 `users` 表中的聚合资料 +5. 将 `users.auth_provider / external_key / provider_user_id` 更新为当前 primary identity 摘要 + +## 7. 登录流程设计 + +## 7.1 首次登录 + +流程: + +1. 登录回调 / 中间件解析得到 `NormalizedIdentity` +2. 先查 `user_auth_identities.external_key` +3. 未命中: + - 创建 `users` + - 创建首条 `user_auth_identities` + - 该 identity 标记 `is_primary = true` +4. 刷新 `users` 聚合资料 + +## 7.2 已绑定 provider 登录 + +流程: + +1. 登录回调 / 中间件解析 `NormalizedIdentity` +2. 按 `external_key` 查 identity 表 +3. 找到对应 `user_subject_id` +4. 返回同一 `users` 记录 +5. 更新 identity 的 `last_login_at` + +## 8. 绑定流程设计 + +## 8.1 发起绑定 + +建议接口: + +```http +POST /api/auth/bind/start +``` + +请求: + +```json +{ + "provider": "Github" +} +``` + +返回: + +```json +{ + "authUrl": "https://casdoor..." +} +``` + +## 8.2 bind state 设计 + +建议 state 至少携带: + +```json +{ + "action": "bind", + "userSubjectId": "usr_xxx", + "provider": "Github", + "nonce": "random", + "expiredAt": 1710000000 +} +``` + +要求: + +1. state 必须签名或加密 +2. 与当前登录会话绑定 +3. 有较短有效期 + +## 8.3 绑定回调 + +建议接口: + +```http +GET /api/auth/bind/callback +``` + +流程: + +1. 验证当前用户已登录 +2. 验证 bind state +3. 用 code 换 token +4. 归一化得到新 identity +5. 调用 `BindIdentityToUser(currentUser, identity)` +6. 成功后重定向前端账号设置页 + +## 8.4 冲突处理 + +若新 identity 已绑定其他本地用户,则返回冲突错误,例如: + +```json +{ + "error": "identity_already_bound", + "message": "该登录方式已绑定其他账号" +} +``` + +## 9. 解绑流程设计 + +建议接口: + +```http +POST /api/auth/identities/:id/unbind +``` + +规则: + +1. 只允许解绑当前登录用户自己的 identity +2. 不允许解绑最后一个 identity +3. 如果解绑的是 primary identity,则重新按 provider 优先级选出新的 primary identity +4. 完成后刷新 `users` 聚合资料 + +## 10. 资料聚合规则 + +建议统一通过 `RefreshUserProfileFromIdentities()` 聚合写回 `users` 表。 + +## 10.1 选择 primary identity + +```text +优先使用 is_primary = true +若不存在,则按 provider rank 最高者选 primary +``` + +## 10.2 DisplayName 聚合 + +规则: + +1. primary identity.display_name +2. 其他 identity 中优先级最高的 display_name +3. 若都为空,则回退到 `users.username` + +## 10.3 Avatar 聚合 + +规则: + +1. primary identity.avatar_url +2. Github identity.avatar_url +3. 其他 identity.avatar_url + +## 10.4 Email 聚合 + +规则: + +1. primary identity 的合法邮箱 +2. 其他 identity 的合法邮箱 +3. 非法邮箱格式不写入 `users.email` + +## 10.5 Phone 聚合 + +规则: + +1. `phone` provider 的手机号 +2. primary identity 中的合法手机号 +3. 其他 identity 中的合法手机号 + +## 10.6 Organization 聚合 + +规则: + +1. primary identity.organization +2. 其他 identity 中优先级最高的 organization + +## 10.7 Username 策略 + +建议保持保守: + +1. 初次创建时初始化 username +2. 后续不因 provider 切换频繁自动变更 +3. 仅在旧 username 明显低质量时,允许升级,例如: + - `phone_` + - UUID 风格随机值 + +## 11. 迁移策略 + +## 11.1 新建 identity 表 + +新增 `user_auth_identities` 表与索引。 + +## 11.2 从 `users` 表回填首批 identity + +从当前 `users` 表已存在字段回填: + +- `auth_provider` +- `external_key` +- `provider_user_id` +- `casdoor_universal_id` +- `casdoor_id` +- `casdoor_sub` +- `email` +- `phone` +- `display_name` +- `avatar_url` +- `organization` + +回填原则: + +1. 每个 user 至少生成一条 identity +2. 若 `external_key` 存在,则直接用作 identity 主键字段 +3. 回填 identity 后,将其标记为 `is_primary = true` + +## 11.3 登录主流程切换 + +切换为: + +1. 优先查 `user_auth_identities.external_key` +2. 再 fallback `users.external_key` 与旧 Casdoor 字段 +3. 兼容期内命中旧逻辑时自动补 identity 表 + +## 12. 安全规则 + +## 12.1 绑定必须基于已登录会话 + +绑定不是普通登录,必须确保当前用户已登录。 + +## 12.2 严格校验 state + +必须防止: + +- CSRF +- 会话串绑 +- provider 回调错绑 + +## 12.3 不允许覆盖已绑定他人的 identity + +对于已绑定其他 user 的 `external_key`: + +1. 不允许 silent rebind +2. 不允许自动迁移 +3. 直接报冲突错误 + +## 12.4 不允许解绑最后一个登录方式 + +防止用户把自己锁死。 + +## 13. 代码落点建议 + +建议涉及以下模块: + +### 模型层 + +- `server/internal/models/models.go` + +新增: + +- `UserAuthIdentity` + +### 服务层 + +- `server/internal/authidentity/service.go` +或 +- `server/internal/user/identity_service.go` + +### Handler 层 + +新增接口: + +- `StartBindAuth` +- `BindAuthCallback` +- `ListBoundIdentities` +- `UnbindIdentity` + +### 登录流程 + +改造: + +- `AuthCallback` +- `middleware auth` +- `GetOrCreateUser / ResolveOrCreateUserByIdentity` + +### 迁移层 + +- `server/cmd/migrate/main.go` + +## 14. 实施建议 + +建议按以下顺序推进: + +### Phase 1 + +1. 新增 `user_auth_identities` 表 +2. 回填首批 identity 数据 +3. 新增 `AuthIdentityService` + +### Phase 2 + +1. 登录流程优先切到 identity 表 +2. 兼容旧字段 fallback +3. 自动刷新 `users` 聚合资料 + +### Phase 3 + +1. 新增绑定 / 解绑接口 +2. 前端账号设置页接入 + +### Phase 4 + +1. 观察 identity 命中率与冲突情况 +2. 收敛旧 `users` 表上的查找逻辑 + +## 15. 结论 + +在显式绑定方案下,系统不需要自动猜测用户是否应被合并,而是通过: + +1. 新增 `user_auth_identities` 表表达一对多身份关系 +2. 在已登录态下显式绑定第二 provider +3. 通过 provider 优先级自动升级 primary identity +4. 通过字段级聚合规则刷新 `users` 展示资料 + +来实现“同一自然用户可以混合使用多个 provider 登录同一账号”的能力。 + +这一方案既能控制安全风险,又能在产品复杂度可控的前提下满足多 provider 统一账号诉求。 diff --git a/docs/proposals/casdoor-identity-normalization/05-service-and-flow.md b/docs/proposals/casdoor-identity-normalization/05-service-and-flow.md index 5577c36..9655760 100644 --- a/docs/proposals/casdoor-identity-normalization/05-service-and-flow.md +++ b/docs/proposals/casdoor-identity-normalization/05-service-and-flow.md @@ -173,7 +173,7 @@ syncProfile: 2. `phone` 需通过手机号格式校验 3. 不符合语义的字段不应因字段名而强制落库 -例如:`properties.oauth_Custom_email = 15986746954` 时,不应直接写入 `email`,而应优先识别为 `phone` 候选值。 +例如:`properties.oauth_Custom_email = 15500000001` 时,不应直接写入 `email`,而应优先识别为 `phone` 候选值。 ## 5.8 时序建议 diff --git a/server/cmd/api/main.go b/server/cmd/api/main.go index 5eb1fc6..dd22c34 100644 --- a/server/cmd/api/main.go +++ b/server/cmd/api/main.go @@ -221,6 +221,7 @@ func main() { auth.GET("/callback", handlers.AuthCallback) auth.GET("/login", handlers.Login) auth.POST("/logout", handlers.Logout) + auth.GET("/bind/callback", handlers.AuthCallback) } // === Public read-only endpoints (OptionalAuth, no login required) === @@ -280,6 +281,9 @@ func main() { authed.Use(middleware.RequireAuth(casdoorEndpoint, jwksProvider)) { authed.GET("/auth/me", handlers.GetCurrentUser) + authed.GET("/auth/identities", handlers.ListBoundIdentities) + authed.POST("/auth/bind/start", handlers.StartBindAuth) + authed.POST("/auth/identities/:id/unbind", handlers.UnbindIdentity) usage := authed.Group("/usage") { diff --git a/server/cmd/migrate/main.go b/server/cmd/migrate/main.go index 8596d8f..8c96ffd 100644 --- a/server/cmd/migrate/main.go +++ b/server/cmd/migrate/main.go @@ -11,6 +11,7 @@ import ( "path/filepath" "sort" "strings" + "time" "github.com/costrict/costrict-web/server/internal/config" "github.com/costrict/costrict-web/server/internal/database" @@ -176,6 +177,9 @@ func main() { if err := ensureUserIdentityColumns(db); err != nil { log.Fatalf("Failed to ensure user identity columns: %v", err) } + if err := ensureUserAuthIdentitiesTable(db); err != nil { + log.Fatalf("Failed to ensure user auth identities table: %v", err) + } if err := backfillCapabilityContentVersioning(db); err != nil { log.Fatalf("Failed to backfill capability content versioning: %v", err) @@ -183,6 +187,9 @@ func main() { if err := backfillUserExternalIdentities(db, false); err != nil { log.Fatalf("Failed to backfill user external identities: %v", err) } + if err := backfillUserAuthIdentities(db, false); err != nil { + log.Fatalf("Failed to backfill user auth identities: %v", err) + } log.Println("All migrations completed successfully") } @@ -228,6 +235,39 @@ func ensureUserIdentityColumns(db *gorm.DB) error { return nil } +func ensureUserAuthIdentitiesTable(db *gorm.DB) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS user_auth_identities ( + id BIGSERIAL PRIMARY KEY, + user_subject_id text NOT NULL, + provider text NOT NULL, + issuer text, + external_key text NOT NULL, + external_subject text, + external_user_id text, + provider_user_id text, + display_name text, + email text, + phone text, + avatar_url text, + organization text, + is_primary boolean NOT NULL DEFAULT false, + last_login_at timestamptz, + created_at timestamptz, + updated_at timestamptz, + deleted_at timestamptz + )`, + `CREATE UNIQUE INDEX IF NOT EXISTS idx_user_auth_identities_external_key ON user_auth_identities(external_key)`, + `CREATE INDEX IF NOT EXISTS idx_user_auth_identities_user_subject_id ON user_auth_identities(user_subject_id)`, + } + for _, stmt := range stmts { + if err := db.Exec(stmt).Error; err != nil { + return fmt.Errorf("ensure user_auth_identities failed (%s): %w", stmt, err) + } + } + return nil +} + func backfillUserExternalIdentities(db *gorm.DB, dryRun bool) error { hasPhone := db.Migrator().HasColumn(&models.User{}, "phone") hasAuthProvider := db.Migrator().HasColumn(&models.User{}, "auth_provider") @@ -315,6 +355,98 @@ func backfillUserExternalIdentities(db *gorm.DB, dryRun bool) error { }) } +func backfillUserAuthIdentities(db *gorm.DB, dryRun bool) error { + type userRow struct { + SubjectID string + DisplayName *string + Email *string + Phone *string + AvatarURL *string + Organization *string + AuthProvider *string + ExternalKey *string + ProviderUserID *string + CasdoorUniversalID *string + CasdoorID *string + CasdoorSub *string + } + var users []userRow + if err := db.Table("users").Select("subject_id, display_name, email, phone, avatar_url, organization, auth_provider, external_key, provider_user_id, casdoor_universal_id, casdoor_id, casdoor_sub").Find(&users).Error; err != nil { + return fmt.Errorf("load users for auth identity backfill: %w", err) + } + created := 0 + return db.Transaction(func(tx *gorm.DB) error { + for _, user := range users { + if strings.TrimSpace(user.SubjectID) == "" { + continue + } + externalKey := "" + if user.ExternalKey != nil { + externalKey = strings.TrimSpace(*user.ExternalKey) + } + if externalKey == "" { + if user.CasdoorUniversalID != nil && *user.CasdoorUniversalID != "" { + externalKey = "casdoor:" + *user.CasdoorUniversalID + } else if user.CasdoorSub != nil && *user.CasdoorSub != "" { + externalKey = "casdoor-sub:" + *user.CasdoorSub + } else if user.CasdoorID != nil && *user.CasdoorID != "" { + externalKey = "casdoor-id:" + *user.CasdoorID + } + } + if externalKey == "" { + continue + } + var count int64 + if err := tx.Table("user_auth_identities").Where("external_key = ?", externalKey).Count(&count).Error; err != nil { + return err + } + if count > 0 { + continue + } + provider := "casdoor" + if user.AuthProvider != nil && strings.TrimSpace(*user.AuthProvider) != "" { + provider = strings.ToLower(strings.TrimSpace(*user.AuthProvider)) + } + created++ + if dryRun { + continue + } + if err := tx.Table("user_auth_identities").Create(map[string]any{ + "user_subject_id": user.SubjectID, + "provider": provider, + "external_key": externalKey, + "external_subject": coalesceStringPtr(user.CasdoorUniversalID, user.CasdoorSub), + "external_user_id": user.CasdoorID, + "provider_user_id": user.ProviderUserID, + "display_name": user.DisplayName, + "email": user.Email, + "phone": user.Phone, + "avatar_url": user.AvatarURL, + "organization": user.Organization, + "is_primary": true, + "created_at": time.Now(), + "updated_at": time.Now(), + }).Error; err != nil { + return fmt.Errorf("create backfilled auth identity for %s: %w", user.SubjectID, err) + } + } + log.Printf("user auth identity summary (dry-run=%v): created identities=%d", dryRun, created) + if dryRun { + return errDryRunRollback + } + return nil + }) +} + +func coalesceStringPtr(values ...*string) *string { + for _, value := range values { + if value != nil && strings.TrimSpace(*value) != "" { + return value + } + } + return nil +} + func isLikelyPhoneValue(v string) bool { v = strings.TrimSpace(v) if v == "" { diff --git a/server/cmd/migrate/main_test.go b/server/cmd/migrate/main_test.go index 8e7351d..48ad0f3 100644 --- a/server/cmd/migrate/main_test.go +++ b/server/cmd/migrate/main_test.go @@ -92,7 +92,7 @@ func TestBackfillUserExternalIdentities(t *testing.T) { } u1 := models.User{SubjectID: "u1", Username: "alice", CasdoorUniversalID: strPtr("uuid-1"), IsActive: true} - u2 := models.User{SubjectID: "u2", Username: "phone_15986746954", Email: strPtr("15986746954"), IsActive: true} + u2 := models.User{SubjectID: "u2", Username: "phone_15500000001", Email: strPtr("15500000001"), IsActive: true} if err := db.Create(&u1).Error; err != nil { t.Fatalf("create u1: %v", err) } @@ -120,11 +120,32 @@ func TestBackfillUserExternalIdentities(t *testing.T) { if got2.AuthProvider == nil || *got2.AuthProvider != "phone" { t.Fatalf("expected u2 auth_provider backfilled, got %+v", got2) } - if got2.Phone == nil || *got2.Phone != "15986746954" { + if got2.Phone == nil || *got2.Phone != "15500000001" { t.Fatalf("expected u2 phone backfilled from legacy email-like value, got %+v", got2) } } +func TestBackfillUserAuthIdentities(t *testing.T) { + db := newMigrateTestDB(t) + if err := db.AutoMigrate(&models.User{}, &models.UserAuthIdentity{}); err != nil { + t.Fatalf("migrate users/auth identities: %v", err) + } + u := models.User{SubjectID: "u1", Username: "alice", AuthProvider: strPtr("github"), ExternalKey: strPtr("casdoor:uuid-1"), ProviderUserID: strPtr("18633160"), CasdoorUniversalID: strPtr("uuid-1"), IsActive: true} + if err := db.Create(&u).Error; err != nil { + t.Fatalf("create user: %v", err) + } + if err := backfillUserAuthIdentities(db, false); err != nil { + t.Fatalf("backfill auth identities: %v", err) + } + var count int64 + if err := db.Model(&models.UserAuthIdentity{}).Where("user_subject_id = ?", "u1").Count(&count).Error; err != nil { + t.Fatalf("count identities: %v", err) + } + if count != 1 { + t.Fatalf("expected 1 backfilled identity, got %d", count) + } +} + func TestBackfillCapabilityContentVersioning_SingleFile(t *testing.T) { db := newMigrateTestDB(t) diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index 88a6df1..ffa86cb 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -1,7 +1,10 @@ package handlers import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "net/http" @@ -25,6 +28,19 @@ var cookieSecure bool // whether to set Secure flag on auth cookies var defaultFrontendURL string // first entry from FRONTEND_URLS, used as fallback var allowedOrigins map[string]bool // whitelist of allowed frontend origins var UserModule *userpkg.Module +var bindStateSecret string + +var exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { + return CasdoorClient.ExchangeCodeForToken(code, callbackURL) +} + +var getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { + return CasdoorClient.GetUserInfo(accessToken) +} + +var getLoginURLWithCallbackFunc = func(state, callbackURL string) string { + return CasdoorClient.GetLoginURLWithCallback(state, callbackURL) +} type authUserDTO struct { ID string `json:"id"` @@ -38,6 +54,18 @@ type authUserDTO struct { Auth map[string]any `json:"auth,omitempty"` } +type authIdentityDTO struct { + ID uint `json:"id"` + Provider string `json:"provider"` + ProviderUserID *string `json:"providerUserId,omitempty"` + DisplayName *string `json:"displayName,omitempty"` + Email *string `json:"email,omitempty"` + Phone *string `json:"phone,omitempty"` + ExternalKey string `json:"externalKey"` + IsPrimary bool `json:"isPrimary"` + LastLoginAt *time.Time `json:"lastLoginAt,omitempty"` +} + func InitCasdoor(cfg *config.CasdoorConfig) { CasdoorClient = casdoor.NewClient(cfg) } @@ -49,6 +77,10 @@ func InitUserModule(module *userpkg.Module) { // InitCookieConfig sets cookie-related configuration from the global config. func InitCookieConfig(cfg *config.Config) { cookieSecure = cfg.CookieSecure + bindStateSecret = cfg.InternalSecret + if strings.TrimSpace(bindStateSecret) == "" { + bindStateSecret = cfg.Casdoor.Secret + } // Build the allowed origins whitelist from FRONTEND_URLS. allowedOrigins = make(map[string]bool) @@ -96,6 +128,16 @@ type oauthState struct { CallbackURL string // full callback URL (http://localhost:3000/api/auth/callback) } +type bindState struct { + Action string `json:"action"` + UserSubjectID string `json:"userSubjectId"` + Provider string `json:"provider"` + RedirectTo string `json:"redirectTo"` + CallbackURL string `json:"callbackUrl"` + ExpiresAt int64 `json:"expiresAt"` + Nonce string `json:"nonce"` +} + func encodeOAuthState(s oauthState) string { // Extract common origin from redirect_to and store paths only. origin, redirectPath := splitOriginPath(s.RedirectTo) @@ -121,6 +163,49 @@ func decodeOAuthState(encoded string) oauthState { } } +func encodeBindState(s bindState) string { + b, _ := json.Marshal(s) + payload := base64.RawURLEncoding.EncodeToString(b) + return payload + "." + signBindStatePayload(payload) +} + +func decodeBindState(encoded string) bindState { + parts := strings.Split(encoded, ".") + if len(parts) != 2 { + return bindState{} + } + payload, sig := parts[0], parts[1] + if !verifyBindStatePayload(payload, sig) { + return bindState{} + } + b, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + return bindState{} + } + var out bindState + if err := json.Unmarshal(b, &out); err != nil { + return bindState{} + } + if out.ExpiresAt > 0 && time.Now().Unix() > out.ExpiresAt { + return bindState{} + } + return out +} + +func signBindStatePayload(payload string) string { + key := strings.TrimSpace(bindStateSecret) + if key == "" { + key = "costrict-bind-state-default" + } + h := hmac.New(sha256.New, []byte(key)) + _, _ = h.Write([]byte(payload)) + return hex.EncodeToString(h.Sum(nil)) +} + +func verifyBindStatePayload(payload, sig string) bool { + return hmac.Equal([]byte(signBindStatePayload(payload)), []byte(sig)) +} + func buildAuthUserDTOFromModel(user *models.User) authUserDTO { name := user.Username if user.DisplayName != nil && *user.DisplayName != "" { @@ -178,6 +263,20 @@ func buildAuthUserDTOFromClaims(claims *userpkg.JWTClaims) authUserDTO { } } +func buildAuthIdentityDTO(identity *models.UserAuthIdentity) authIdentityDTO { + return authIdentityDTO{ + ID: identity.ID, + Provider: identity.Provider, + ProviderUserID: identity.ProviderUserID, + DisplayName: identity.DisplayName, + Email: identity.Email, + Phone: identity.Phone, + ExternalKey: identity.ExternalKey, + IsPrimary: identity.IsPrimary, + LastLoginAt: identity.LastLoginAt, + } +} + // splitOriginPath splits a full URL into origin (scheme://host) and path. // For non-URL strings it returns ("", original). func splitOriginPath(rawURL string) (string, string) { @@ -206,6 +305,17 @@ func splitOriginPath(rawURL string) (string, string) { // @Failure 500 {object} object{error=string} // @Router /auth/callback [get] func AuthCallback(c *gin.Context) { + if rawState := c.Query("state"); rawState != "" { + if bind := decodeBindState(rawState); bind.Action == "bind" { + bindAuthCallback(c, bind) + return + } + if strings.Contains(rawState, ".") { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_or_expired_state"}) + return + } + } + code := c.Query("code") if code == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "code is required"}) @@ -215,7 +325,7 @@ func AuthCallback(c *gin.Context) { // Decode state to recover callback_url (needed for token exchange) and redirect target. state := decodeOAuthState(c.Query("state")) - tokenResp, err := CasdoorClient.ExchangeCodeForToken(code, state.CallbackURL) + tokenResp, err := exchangeCodeForTokenFunc(code, state.CallbackURL) if err != nil || tokenResp.AccessToken == "" { fmt.Printf("[ERROR] ExchangeCodeForToken failed: err=%v, tokenResp=%+v\n", err, tokenResp) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to exchange code for token: %v", err)}) @@ -223,7 +333,7 @@ func AuthCallback(c *gin.Context) { } if UserModule != nil { - if userInfo, userErr := CasdoorClient.GetUserInfo(tokenResp.AccessToken); userErr == nil && userInfo != nil && userInfo.User != nil { + if userInfo, userErr := getUserInfoFunc(tokenResp.AccessToken); userErr == nil && userInfo != nil && userInfo.User != nil { claims := &userpkg.JWTClaims{ ID: userInfo.User.Id, Sub: userInfo.User.Sub, @@ -264,6 +374,106 @@ func AuthCallback(c *gin.Context) { c.Redirect(http.StatusFound, redirectURL) } +// StartBindAuth godoc +// @Summary Start binding another auth provider +// @Description Returns an OAuth URL for binding another provider to the current logged-in user +// @Tags auth +// @Produce json +// @Security BearerAuth +// @Param body body object{provider=string,redirectTo=string,callbackUrl=string} true "Bind request" +// @Success 200 {object} object{authUrl=string} +// @Failure 400 {object} object{error=string} +// @Failure 401 {object} object{error=string} +// @Router /auth/bind/start [post] +func StartBindAuth(c *gin.Context) { + currentUserID := c.GetString(middleware.UserIDKey) + if currentUserID == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + var req struct { + Provider string `json:"provider" binding:"required"` + RedirectTo string `json:"redirectTo"` + CallbackURL string `json:"callbackUrl"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request"}) + return + } + if req.CallbackURL != "" && !isAllowedOrigin(req.CallbackURL) { + req.CallbackURL = "" + } + if req.RedirectTo == "" { + req.RedirectTo = defaultFrontendURL + "/settings/account" + } + state := encodeBindState(bindState{ + Action: "bind", + UserSubjectID: currentUserID, + Provider: req.Provider, + RedirectTo: req.RedirectTo, + CallbackURL: req.CallbackURL, + ExpiresAt: time.Now().Add(10 * time.Minute).Unix(), + Nonce: uuid.NewString(), + }) + loginURL := getLoginURLWithCallbackFunc(state, req.CallbackURL) + c.JSON(http.StatusOK, gin.H{"authUrl": loginURL}) +} + +func bindAuthCallback(c *gin.Context, state bindState) { + if state.ExpiresAt == 0 || time.Now().Unix() > state.ExpiresAt { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_or_expired_state"}) + return + } + code := c.Query("code") + if code == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "code is required"}) + return + } + currentToken := middleware.ExtractToken(c) + if currentToken == "" || UserModule == nil || UserModule.Service == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + currentClaims, err := userpkg.ParseJWTClaimsFromAccessToken(currentToken) + if err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid current session"}) + return + } + currentUser, err := UserModule.Service.GetOrCreateUser(currentClaims) + if err != nil || currentUser == nil || currentUser.SubjectID != state.UserSubjectID { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid binding session"}) + return + } + + tokenResp, err := exchangeCodeForTokenFunc(code, state.CallbackURL) + if err != nil || tokenResp.AccessToken == "" { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to exchange code for token: %v", err)}) + return + } + claims, err := userpkg.ParseJWTClaimsFromAccessToken(tokenResp.AccessToken) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse bound identity"}) + return + } + if state.Provider != "" && !strings.EqualFold(claims.Provider, state.Provider) { + c.JSON(http.StatusConflict, gin.H{"error": "provider_mismatch"}) + return + } + if err := UserModule.Service.BindIdentityToUser(currentUser.SubjectID, claims); err != nil { + if err.Error() == "identity_already_bound" { + c.JSON(http.StatusConflict, gin.H{"error": "identity_already_bound", "message": "该登录方式已绑定其他账号"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to bind identity"}) + return + } + redirectURL := state.RedirectTo + if redirectURL == "" { + redirectURL = defaultFrontendURL + "/settings/account?bind=success" + } + c.Redirect(http.StatusFound, redirectURL) +} + // Login godoc // @Summary OAuth login redirect // @Description Redirect to Casdoor OAuth authorization page @@ -360,7 +570,7 @@ func GetCurrentUser(c *gin.Context) { return } - userInfo, err := CasdoorClient.GetUserInfo(token.(string)) + userInfo, err := getUserInfoFunc(token.(string)) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid or expired token"}) return @@ -390,6 +600,66 @@ func GetCurrentUser(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"user": buildAuthUserDTOFromClaims(claims)}) } +// ListBoundIdentities godoc +// @Summary List bound auth identities +// @Description Lists all auth identities bound to the current user +// @Tags auth +// @Produce json +// @Security BearerAuth +// @Success 200 {object} object{identities=[]object} +// @Failure 401 {object} object{error=string} +// @Router /auth/identities [get] +func ListBoundIdentities(c *gin.Context) { + currentUserID := c.GetString(middleware.UserIDKey) + if currentUserID == "" || UserModule == nil || UserModule.Service == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + identities, err := UserModule.Service.ListUserIdentities(currentUserID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list identities"}) + return + } + out := make([]authIdentityDTO, 0, len(identities)) + for _, identity := range identities { + out = append(out, buildAuthIdentityDTO(identity)) + } + c.JSON(http.StatusOK, gin.H{"identities": out}) +} + +// UnbindIdentity godoc +// @Summary Unbind auth identity +// @Description Unbinds an auth identity from the current user +// @Tags auth +// @Produce json +// @Security BearerAuth +// @Param id path int true "Identity ID" +// @Success 200 {object} object{message=string} +// @Failure 400 {object} object{error=string} +// @Failure 401 {object} object{error=string} +// @Router /auth/identities/{id}/unbind [post] +func UnbindIdentity(c *gin.Context) { + currentUserID := c.GetString(middleware.UserIDKey) + if currentUserID == "" || UserModule == nil || UserModule.Service == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"}) + return + } + var identityID uint + if _, err := fmt.Sscanf(c.Param("id"), "%d", &identityID); err != nil || identityID == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid identity id"}) + return + } + if err := UserModule.Service.UnbindIdentity(currentUserID, identityID); err != nil { + if err.Error() == "cannot unbind last identity" { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to unbind identity"}) + return + } + c.JSON(http.StatusOK, gin.H{"message": "Identity unbound successfully"}) +} + func stringPtr(v string) *string { if v == "" { return nil diff --git a/server/internal/handlers/handlers_test.go b/server/internal/handlers/handlers_test.go index 02293a7..fc29ce1 100644 --- a/server/internal/handlers/handlers_test.go +++ b/server/internal/handlers/handlers_test.go @@ -1,17 +1,38 @@ package handlers import ( + "crypto/rand" + "crypto/rsa" "encoding/json" "net/http" + "net/http/httptest" + "strings" "testing" + "time" + "github.com/costrict/costrict-web/server/internal/casdoor" "github.com/costrict/costrict-web/server/internal/database" "github.com/costrict/costrict-web/server/internal/middleware" "github.com/costrict/costrict-web/server/internal/models" userpkg "github.com/costrict/costrict-web/server/internal/user" "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v4" ) +func signHandlersTestJWT(t *testing.T, claims jwt.MapClaims) string { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate rsa key: %v", err) + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err := token.SignedString(key) + if err != nil { + t.Fatalf("sign jwt: %v", err) + } + return tokenString +} + func newRepoRouter(userID string) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() @@ -45,6 +66,10 @@ func newAuthRouter(userID string) *gin.Engine { c.Next() } r.GET("/api/auth/me", injectUser, GetCurrentUser) + r.GET("/api/auth/identities", injectUser, ListBoundIdentities) + r.POST("/api/auth/bind/start", injectUser, StartBindAuth) + r.POST("/api/auth/identities/:id/unbind", injectUser, UnbindIdentity) + r.GET("/api/auth/callback", injectUser, AuthCallback) return r } @@ -127,6 +152,242 @@ func TestGetCurrentUserReturnsLocalSubjectUser(t *testing.T) { } } +func TestListBoundIdentities(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + if err := database.DB.Create(&models.User{SubjectID: "usr_local_1", Username: "alice", IsActive: true}).Error; err != nil { + t.Fatalf("seed user: %v", err) + } + if err := database.DB.Create(&models.UserAuthIdentity{UserSubjectID: "usr_local_1", Provider: "github", ExternalKey: "casdoor:uuid-1", IsPrimary: true}).Error; err != nil { + t.Fatalf("seed identity: %v", err) + } + w := get(newAuthRouter("usr_local_1"), "/api/auth/identities") + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var body struct{ Identities []map[string]any `json:"identities"` } + if err := json.NewDecoder(w.Body).Decode(&body); err != nil { + t.Fatalf("decode response: %v", err) + } + if len(body.Identities) != 1 { + t.Fatalf("expected 1 identity, got %+v", body) + } +} + +func TestUnbindIdentityRejectsLastIdentity(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + if err := database.DB.Create(&models.User{SubjectID: "usr_local_1", Username: "alice", IsActive: true}).Error; err != nil { + t.Fatalf("seed user: %v", err) + } + identity := models.UserAuthIdentity{UserSubjectID: "usr_local_1", Provider: "github", ExternalKey: "casdoor:uuid-1", IsPrimary: true} + if err := database.DB.Create(&identity).Error; err != nil { + t.Fatalf("seed identity: %v", err) + } + w := postJSON(newAuthRouter("usr_local_1"), "/api/auth/identities/1/unbind", map[string]any{}) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestEncodeDecodeBindStateWithSignature(t *testing.T) { + bindStateSecret = "test-secret" + encoded := encodeBindState(bindState{Action: "bind", UserSubjectID: "usr_1", Provider: "github", ExpiresAt: time.Now().Add(time.Minute).Unix(), Nonce: "n1"}) + decoded := decodeBindState(encoded) + if decoded.Action != "bind" || decoded.UserSubjectID != "usr_1" || decoded.Provider != "github" { + t.Fatalf("unexpected decoded state: %+v", decoded) + } + if len(strings.Split(encoded, ".")) != 2 { + t.Fatalf("expected signed state payload, got %q", encoded) + } +} + +func TestDecodeBindStateRejectsTamperedState(t *testing.T) { + bindStateSecret = "test-secret" + encoded := encodeBindState(bindState{Action: "bind", UserSubjectID: "usr_1", Provider: "github", ExpiresAt: time.Now().Add(time.Minute).Unix(), Nonce: "n1"}) + parts := strings.Split(encoded, ".") + tampered := parts[0] + ".deadbeef" + decoded := decodeBindState(tampered) + if decoded.Action != "" { + t.Fatalf("expected tampered state to be rejected, got %+v", decoded) + } +} + +func TestStartBindAuthReturnsSignedURL(t *testing.T) { + defer setupTestDB(t)() + bindStateSecret = "test-secret" + getLoginURLWithCallbackFunc = func(state, callbackURL string) string { + return "https://casdoor.example/login?state=" + state + } + defer func() { getLoginURLWithCallbackFunc = func(state, callbackURL string) string { return CasdoorClient.GetLoginURLWithCallback(state, callbackURL) } }() + w := postJSON(newAuthRouter("usr_local_1"), "/api/auth/bind/start", map[string]any{"provider": "github", "redirectTo": "https://zgsm.sangfor.com/cloud/settings/account"}) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var body struct{ AuthURL string `json:"authUrl"` } + if err := json.NewDecoder(w.Body).Decode(&body); err != nil { + t.Fatalf("decode response: %v", err) + } + if !strings.Contains(body.AuthURL, "state=") { + t.Fatalf("expected authUrl to contain state, got %q", body.AuthURL) + } +} + +func TestBindCallbackRejectsExpiredState(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + bindStateSecret = "test-secret" + state := encodeBindState(bindState{Action: "bind", UserSubjectID: "usr_local_1", Provider: "github", ExpiresAt: time.Now().Add(-time.Minute).Unix(), Nonce: "n1"}) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/auth/callback?code=abc&state="+state, nil) + newAuthRouter("usr_local_1").ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestBindCallbackRejectsProviderMismatch(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + bindStateSecret = "test-secret" + currentToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "current-id", "sub": "current-sub", "universal_id": "current-uuid", "name": "acct_alpha", "provider": "phone", "phone_number": "15500000001"}) + currentUser, err := UserModule.Service.GetOrCreateUser(&userpkg.JWTClaims{ID: "current-id", Sub: "current-sub", UniversalID: "current-uuid", Name: "acct_alpha", PreferredUsername: "Account Alpha", Provider: "phone", Phone: "15500000001"}) + if err != nil { + t.Fatalf("seed current user: %v", err) + } + r := gin.New() + r.GET("/api/auth/callback", func(c *gin.Context) { + c.Set(middleware.UserIDKey, currentUser.SubjectID) + c.Request.Header.Set("Authorization", "Bearer "+currentToken) + c.Set("accessToken", currentToken) + AuthCallback(c) + }) + defer func() { + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { return CasdoorClient.ExchangeCodeForToken(code, callbackURL) } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { return CasdoorClient.GetUserInfo(accessToken) } + }() + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { + boundToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "bound-id", "sub": "bound-sub", "universal_id": "bound-uuid", "name": "bound_user", "provider": "idtrust", "properties": map[string]any{"oauth_Custom_id": "custom-user-001", "oauth_Custom_username": "custom_user", "oauth_Custom_displayName": "Display Custom User"}}) + return &casdoor.CasdoorTokenResponse{AccessToken: boundToken}, nil + } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { + return &casdoor.CasdoorUserInfoResponse{User: &casdoor.CasdoorUser{Id: "bound-id", Sub: "bound-sub", UniversalID: "bound-uuid", Name: "bound"}}, nil + } + state := encodeBindState(bindState{Action: "bind", UserSubjectID: currentUser.SubjectID, Provider: "github", ExpiresAt: time.Now().Add(time.Minute).Unix(), Nonce: "n1"}) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/auth/callback?code=abc&state="+state, nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", w.Code, w.Body.String()) + } +} + +func TestBindCallbackSuccess(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + bindStateSecret = "test-secret" + defer func() { + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { return CasdoorClient.ExchangeCodeForToken(code, callbackURL) } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { return CasdoorClient.GetUserInfo(accessToken) } + }() + + currentToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "current-id", "sub": "current-sub", "universal_id": "current-uuid", "name": "acct_alpha", "provider": "phone", "phone_number": "15500000001"}) + currentUser, err := UserModule.Service.GetOrCreateUser(&userpkg.JWTClaims{ID: "current-id", Sub: "current-sub", UniversalID: "current-uuid", Name: "acct_alpha", PreferredUsername: "Account Alpha", Provider: "phone", Phone: "15500000001"}) + if err != nil { + t.Fatalf("seed current user: %v", err) + } + + r := gin.New() + r.GET("/api/auth/callback", func(c *gin.Context) { + c.Set(middleware.UserIDKey, currentUser.SubjectID) + c.Request.Header.Set("Authorization", "Bearer "+currentToken) + c.Set("accessToken", currentToken) + AuthCallback(c) + }) + + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { + boundToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "bound-gh-id", "sub": "bound-gh-sub", "universal_id": "bound-gh-uuid", "name": "acct_github_user", "provider": "github", "properties": map[string]any{"oauth_GitHub_id": "provider-gh-001", "oauth_GitHub_username": "acct_github_user", "oauth_GitHub_displayName": "Display Github User", "oauth_GitHub_email": "user_github@example.com"}}) + return &casdoor.CasdoorTokenResponse{AccessToken: boundToken}, nil + } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { + return &casdoor.CasdoorUserInfoResponse{User: &casdoor.CasdoorUser{Id: "bound-gh-id", Sub: "bound-gh-sub", UniversalID: "bound-gh-uuid", Name: "acct_github_user", PreferredUsername: "Display Github User", Email: "user_github@example.com"}}, nil + } + + state := encodeBindState(bindState{Action: "bind", UserSubjectID: currentUser.SubjectID, Provider: "github", ExpiresAt: time.Now().Add(time.Minute).Unix(), Nonce: "n-success", RedirectTo: "https://example.test/account"}) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/auth/callback?code=ok&state="+state, nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusFound { + t.Fatalf("expected 302, got %d: %s", w.Code, w.Body.String()) + } + if location := w.Header().Get("Location"); location != "https://example.test/account" { + t.Fatalf("expected redirect to account page, got %q", location) + } + identities, err := UserModule.Service.ListUserIdentities(currentUser.SubjectID) + if err != nil { + t.Fatalf("list identities: %v", err) + } + if len(identities) != 2 { + t.Fatalf("expected 2 identities after successful bind, got %d", len(identities)) + } +} + +func TestBindCallbackRejectsIdentityAlreadyBound(t *testing.T) { + defer setupTestDB(t)() + defer InitUserModule(nil) + InitUserModule(userpkg.New(database.DB)) + bindStateSecret = "test-secret" + defer func() { + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { return CasdoorClient.ExchangeCodeForToken(code, callbackURL) } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { return CasdoorClient.GetUserInfo(accessToken) } + }() + + currentToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "current-id", "sub": "current-sub", "universal_id": "current-uuid", "name": "acct_alpha", "provider": "phone", "phone_number": "15500000001"}) + currentUser, err := UserModule.Service.GetOrCreateUser(&userpkg.JWTClaims{ID: "current-id", Sub: "current-sub", UniversalID: "current-uuid", Name: "acct_alpha", PreferredUsername: "Account Alpha", Provider: "phone", Phone: "15500000001"}) + if err != nil { + t.Fatalf("seed current user: %v", err) + } + otherUser, err := UserModule.Service.GetOrCreateUser(&userpkg.JWTClaims{ID: "other-id", Sub: "other-sub", UniversalID: "other-uuid", Name: "acct_beta", PreferredUsername: "Account Beta", Provider: "github", ProviderUserID: "provider-gh-occupied"}) + if err != nil { + t.Fatalf("seed other user: %v", err) + } + if err := UserModule.Service.BindIdentityToUser(otherUser.SubjectID, &userpkg.JWTClaims{ID: "bound-gh-id", Sub: "bound-gh-sub", UniversalID: "bound-gh-uuid", Name: "acct_github_user", PreferredUsername: "Display Github User", Provider: "github", ProviderUserID: "provider-gh-001"}); err != nil { + t.Fatalf("seed occupied identity: %v", err) + } + + r := gin.New() + r.GET("/api/auth/callback", func(c *gin.Context) { + c.Set(middleware.UserIDKey, currentUser.SubjectID) + c.Request.Header.Set("Authorization", "Bearer "+currentToken) + c.Set("accessToken", currentToken) + AuthCallback(c) + }) + + exchangeCodeForTokenFunc = func(code, callbackURL string) (*casdoor.CasdoorTokenResponse, error) { + boundToken := signHandlersTestJWT(t, jwt.MapClaims{"id": "bound-gh-id", "sub": "bound-gh-sub", "universal_id": "bound-gh-uuid", "name": "acct_github_user", "provider": "github", "properties": map[string]any{"oauth_GitHub_id": "provider-gh-001", "oauth_GitHub_username": "acct_github_user", "oauth_GitHub_displayName": "Display Github User"}}) + return &casdoor.CasdoorTokenResponse{AccessToken: boundToken}, nil + } + getUserInfoFunc = func(accessToken string) (*casdoor.CasdoorUserInfoResponse, error) { + return &casdoor.CasdoorUserInfoResponse{User: &casdoor.CasdoorUser{Id: "bound-gh-id", Sub: "bound-gh-sub", UniversalID: "bound-gh-uuid", Name: "acct_github_user", PreferredUsername: "Display Github User"}}, nil + } + + state := encodeBindState(bindState{Action: "bind", UserSubjectID: currentUser.SubjectID, Provider: "github", ExpiresAt: time.Now().Add(time.Minute).Unix(), Nonce: "n-conflict"}) + w := httptest.NewRecorder() + req, _ := http.NewRequest(http.MethodGet, "/api/auth/callback?code=conflict&state="+state, nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusConflict { + t.Fatalf("expected 409, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "identity_already_bound") { + t.Fatalf("expected identity_already_bound error, got %s", w.Body.String()) + } +} + // --------------------------------------------------------------------------- // ListRepositories // --------------------------------------------------------------------------- diff --git a/server/internal/handlers/registry_test.go b/server/internal/handlers/registry_test.go index edb66e2..1c4ec4c 100644 --- a/server/internal/handlers/registry_test.go +++ b/server/internal/handlers/registry_test.go @@ -239,6 +239,27 @@ func setupTestDB(t *testing.T) func() { deleted_at DATETIME, UNIQUE(subject_id) )`, + `CREATE TABLE IF NOT EXISTS user_auth_identities ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_subject_id TEXT NOT NULL, + provider TEXT NOT NULL, + issuer TEXT, + external_key TEXT NOT NULL, + external_subject TEXT, + external_user_id TEXT, + provider_user_id TEXT, + display_name TEXT, + email TEXT, + phone TEXT, + avatar_url TEXT, + organization TEXT, + is_primary INTEGER DEFAULT 0, + last_login_at DATETIME, + created_at DATETIME, + updated_at DATETIME, + deleted_at DATETIME, + UNIQUE(external_key) + )`, `CREATE TABLE IF NOT EXISTS user_system_roles ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL, diff --git a/server/internal/models/models.go b/server/internal/models/models.go index 813b7b8..4a3b9d5 100644 --- a/server/internal/models/models.go +++ b/server/internal/models/models.go @@ -572,3 +572,29 @@ type User struct { func (User) TableName() string { return "users" } + +// UserAuthIdentity stores one external login identity bound to a local user. +type UserAuthIdentity struct { + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + UserSubjectID string `gorm:"index:idx_user_auth_identities_user_subject_id;not null;size:191" json:"user_subject_id"` + Provider string `gorm:"size:64;not null" json:"provider"` + Issuer *string `gorm:"size:255" json:"issuer"` + ExternalKey string `gorm:"uniqueIndex:idx_user_auth_identities_external_key;not null;size:255" json:"external_key"` + ExternalSubject *string `gorm:"size:191" json:"external_subject"` + ExternalUserID *string `gorm:"size:191" json:"external_user_id"` + ProviderUserID *string `gorm:"size:191" json:"provider_user_id"` + DisplayName *string `gorm:"size:191" json:"display_name"` + Email *string `gorm:"size:191" json:"email"` + Phone *string `gorm:"size:64" json:"phone"` + AvatarURL *string `gorm:"type:text" json:"avatar_url"` + Organization *string `gorm:"size:191" json:"organization"` + IsPrimary bool `gorm:"not null;default:false" json:"is_primary"` + LastLoginAt *time.Time `json:"last_login_at"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` +} + +func (UserAuthIdentity) TableName() string { + return "user_auth_identities" +} diff --git a/server/internal/user/service.go b/server/internal/user/service.go index 5188c40..e5b4e8a 100644 --- a/server/internal/user/service.go +++ b/server/internal/user/service.go @@ -2,6 +2,7 @@ package user import ( "fmt" + "strings" "time" "github.com/costrict/costrict-web/server/internal/authidentity" @@ -112,6 +113,97 @@ func (s *UserService) ResolveSubjectID(claims *JWTClaims) (string, string, error return user.SubjectID, name, nil } +func (s *UserService) ListUserIdentities(userSubjectID string) ([]*models.UserAuthIdentity, error) { + var identities []*models.UserAuthIdentity + err := s.db.Where("user_subject_id = ?", userSubjectID).Order("is_primary DESC, id ASC").Find(&identities).Error + return identities, err +} + +func (s *UserService) BindIdentityToUser(userSubjectID string, claims *JWTClaims) error { + if strings.TrimSpace(userSubjectID) == "" { + return fmt.Errorf("user_subject_id is required") + } + claims = normalizeJWTClaims(claims) + if claims == nil { + return fmt.Errorf("nil JWT claims") + } + externalKey := buildExternalKey(claims) + if externalKey == "" { + return fmt.Errorf("external key is required") + } + + return s.db.Transaction(func(tx *gorm.DB) error { + var existing models.UserAuthIdentity + err := tx.Where("external_key = ?", externalKey).Take(&existing).Error + if err == nil { + if existing.UserSubjectID != userSubjectID { + return fmt.Errorf("identity_already_bound") + } + return s.refreshUserProfileFromIdentitiesTx(tx, userSubjectID) + } + if err != nil && err != gorm.ErrRecordNotFound { + return err + } + + identity := buildUserAuthIdentity(userSubjectID, claims) + var currentPrimary models.UserAuthIdentity + primaryExists := tx.Where("user_subject_id = ? AND is_primary = ?", userSubjectID, true).Take(¤tPrimary).Error == nil + if !primaryExists { + identity.IsPrimary = true + } else if providerRank(identity.Provider) > providerRank(currentPrimary.Provider) { + if err := tx.Model(&models.UserAuthIdentity{}).Where("user_subject_id = ?", userSubjectID).Update("is_primary", false).Error; err != nil { + return err + } + identity.IsPrimary = true + } + + if err := tx.Create(&identity).Error; err != nil { + return err + } + return s.refreshUserProfileFromIdentitiesTx(tx, userSubjectID) + }) +} + +func (s *UserService) UnbindIdentity(userSubjectID string, identityID uint) error { + return s.db.Transaction(func(tx *gorm.DB) error { + var identity models.UserAuthIdentity + if err := tx.Where("id = ? AND user_subject_id = ?", identityID, userSubjectID).Take(&identity).Error; err != nil { + return err + } + + var count int64 + if err := tx.Model(&models.UserAuthIdentity{}).Where("user_subject_id = ?", userSubjectID).Count(&count).Error; err != nil { + return err + } + if count <= 1 { + return fmt.Errorf("cannot unbind last identity") + } + + wasPrimary := identity.IsPrimary + if err := tx.Delete(&identity).Error; err != nil { + return err + } + + if wasPrimary { + var remaining []*models.UserAuthIdentity + if err := tx.Where("user_subject_id = ?", userSubjectID).Find(&remaining).Error; err != nil { + return err + } + best := selectBestPrimary(remaining) + if best != nil { + if err := tx.Model(&models.UserAuthIdentity{}).Where("user_subject_id = ?", userSubjectID).Update("is_primary", false).Error; err != nil { + return err + } + if err := tx.Model(&models.UserAuthIdentity{}).Where("id = ?", best.ID).Update("is_primary", true).Error; err != nil { + return err + } + } + } + + return s.refreshUserProfileFromIdentitiesTx(tx, userSubjectID) + }) +} + // SearchUsers searches users by username or email keyword func (s *UserService) SearchUsers(keyword string, limit int) ([]*models.User, error) { var users []*models.User @@ -153,6 +245,16 @@ func (s *UserService) GetOrCreateUser(claims *JWTClaims) (*models.User, error) { var user models.User found := false if externalKey != "" { + var identity models.UserAuthIdentity + if err := s.db.Where("external_key = ?", externalKey).Take(&identity).Error; err == nil { + if err := s.db.Where("subject_id = ?", identity.UserSubjectID).Take(&user).Error; err == nil { + found = true + } + } else if err != gorm.ErrRecordNotFound { + return nil, fmt.Errorf("failed to query identity by external_key: %w", err) + } + } + if externalKey != "" && !found { err := s.db.Where("external_key = ?", externalKey).Take(&user).Error if err == nil { found = true @@ -264,6 +366,9 @@ func (s *UserService) GetOrCreateUser(claims *JWTClaims) (*models.User, error) { return nil, fmt.Errorf("failed to update user: %w", err) } } + if err := s.BindIdentityToUser(user.SubjectID, claims); err != nil && err.Error() != "identity_already_bound" { + return nil, err + } return &user, nil } @@ -307,6 +412,12 @@ func (s *UserService) GetOrCreateUser(claims *JWTClaims) (*models.User, error) { } return nil, fmt.Errorf("failed to create user: %w", err) } + if err := s.BindIdentityToUser(user.SubjectID, claims); err != nil && err.Error() != "identity_already_bound" { + return nil, err + } + if refreshed, err := s.GetUserByID(user.SubjectID); err == nil { + return refreshed, nil + } return &user, nil } @@ -483,6 +594,192 @@ func buildExternalKey(claims *JWTClaims) string { return "" } +func buildUserAuthIdentity(userSubjectID string, claims *JWTClaims) models.UserAuthIdentity { + now := time.Now() + externalKey := buildExternalKey(claims) + provider := strings.ToLower(strings.TrimSpace(claims.Provider)) + if provider == "" { + provider = "casdoor" + } + return models.UserAuthIdentity{ + UserSubjectID: userSubjectID, + Provider: provider, + ExternalKey: externalKey, + ExternalSubject: stringPtr(firstNonEmptyString(claims.UniversalID, claims.Sub)), + ExternalUserID: stringPtr(claims.ID), + ProviderUserID: stringPtr(claims.ProviderUserID), + DisplayName: stringPtr(claims.PreferredUsername), + Email: stringPtr(claims.Email), + Phone: stringPtr(claims.Phone), + AvatarURL: stringPtr(claims.Picture), + Organization: stringPtr(claims.Owner), + LastLoginAt: &now, + } +} + +func providerRank(provider string) int { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "idtrust": + return 300 + case "github": + return 200 + case "phone": + return 100 + default: + return 0 + } +} + +func selectBestPrimary(identities []*models.UserAuthIdentity) *models.UserAuthIdentity { + var best *models.UserAuthIdentity + for _, identity := range identities { + if identity == nil { + continue + } + if best == nil || providerRank(identity.Provider) > providerRank(best.Provider) || (providerRank(identity.Provider) == providerRank(best.Provider) && identity.ID < best.ID) { + best = identity + } + } + return best +} + +func (s *UserService) refreshUserProfileFromIdentitiesTx(tx *gorm.DB, userSubjectID string) error { + var user models.User + if err := tx.Where("subject_id = ?", userSubjectID).Take(&user).Error; err != nil { + return err + } + var identities []*models.UserAuthIdentity + if err := tx.Where("user_subject_id = ?", userSubjectID).Order("is_primary DESC, id ASC").Find(&identities).Error; err != nil { + return err + } + if len(identities) == 0 { + return nil + } + primary := selectBestPrimary(identities) + if primary == nil { + return nil + } + if !primary.IsPrimary { + if err := tx.Model(&models.UserAuthIdentity{}).Where("user_subject_id = ?", userSubjectID).Update("is_primary", false).Error; err != nil { + return err + } + if err := tx.Model(&models.UserAuthIdentity{}).Where("id = ?", primary.ID).Update("is_primary", true).Error; err != nil { + return err + } + } + + user.AuthProvider = stringPtr(primary.Provider) + user.ExternalKey = stringPtr(primary.ExternalKey) + user.ProviderUserID = primary.ProviderUserID + user.DisplayName = firstNonNilStringPtr(primary.DisplayName, bestIdentityString(identities, func(i *models.UserAuthIdentity) *string { return i.DisplayName })) + user.AvatarURL = firstNonNilStringPtr(primary.AvatarURL, githubAvatar(identities), bestIdentityString(identities, func(i *models.UserAuthIdentity) *string { return i.AvatarURL })) + user.Email = validEmailPtr(primary.Email, identities) + user.Phone = preferredPhonePtr(primary, identities) + user.Organization = firstNonNilStringPtr(primary.Organization, bestIdentityString(identities, func(i *models.UserAuthIdentity) *string { return i.Organization })) + if shouldUpgradeUsername(user.Username) { + if upgraded := firstNonEmptyString(ptrString(primary.ProviderUserID), ptrString(primary.DisplayName)); upgraded != "" { + user.Username = sanitizeUsernameCandidate(upgraded, user.Username) + } + } + now := time.Now() + user.LastSyncAt = &now + return tx.Save(&user).Error +} + +func firstNonEmptyString(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return strings.TrimSpace(v) + } + } + return "" +} + +func ptrString(v *string) string { + if v == nil { + return "" + } + return strings.TrimSpace(*v) +} + +func firstNonNilStringPtr(values ...*string) *string { + for _, v := range values { + if v != nil && strings.TrimSpace(*v) != "" { + trimmed := strings.TrimSpace(*v) + return &trimmed + } + } + return nil +} + +func bestIdentityString(identities []*models.UserAuthIdentity, getter func(*models.UserAuthIdentity) *string) *string { + var best *models.UserAuthIdentity + for _, identity := range identities { + candidate := getter(identity) + if candidate == nil || strings.TrimSpace(*candidate) == "" { + continue + } + if best == nil || providerRank(identity.Provider) > providerRank(best.Provider) { + best = identity + } + } + if best == nil { + return nil + } + return getter(best) +} + +func githubAvatar(identities []*models.UserAuthIdentity) *string { + for _, identity := range identities { + if strings.EqualFold(identity.Provider, "github") && identity.AvatarURL != nil && strings.TrimSpace(*identity.AvatarURL) != "" { + return identity.AvatarURL + } + } + return nil +} + +func validEmailPtr(primary *string, identities []*models.UserAuthIdentity) *string { + if primary != nil && strings.Contains(strings.TrimSpace(*primary), "@") { + return firstNonNilStringPtr(primary) + } + for _, identity := range identities { + if identity.Email != nil && strings.Contains(strings.TrimSpace(*identity.Email), "@") { + return firstNonNilStringPtr(identity.Email) + } + } + return nil +} + +func preferredPhonePtr(primary *models.UserAuthIdentity, identities []*models.UserAuthIdentity) *string { + for _, identity := range identities { + if strings.EqualFold(identity.Provider, "phone") && identity.Phone != nil && strings.TrimSpace(*identity.Phone) != "" { + return firstNonNilStringPtr(identity.Phone) + } + } + if primary != nil && primary.Phone != nil && strings.TrimSpace(*primary.Phone) != "" { + return firstNonNilStringPtr(primary.Phone) + } + for _, identity := range identities { + if identity.Phone != nil && strings.TrimSpace(*identity.Phone) != "" { + return firstNonNilStringPtr(identity.Phone) + } + } + return nil +} + +func shouldUpgradeUsername(username string) bool { + username = strings.TrimSpace(username) + return username == "" || strings.HasPrefix(username, "phone_") || strings.HasPrefix(username, "user_") +} + +func sanitizeUsernameCandidate(candidate, fallback string) string { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + return fallback + } + return candidate +} + // stringPtr returns a pointer to string if non-empty, otherwise nil func stringPtr(s string) *string { if s == "" { diff --git a/server/internal/user/service_test.go b/server/internal/user/service_test.go index 8ead451..c709db1 100644 --- a/server/internal/user/service_test.go +++ b/server/internal/user/service_test.go @@ -37,7 +37,7 @@ func setupUserTestDB(t *testing.T) *gorm.DB { t.Fatalf("failed to open test db: %v", err) } - if err := db.AutoMigrate(&models.User{}); err != nil { + if err := db.AutoMigrate(&models.User{}, &models.UserAuthIdentity{}); err != nil { t.Fatalf("failed to migrate user table: %v", err) } @@ -380,14 +380,14 @@ func TestParseJWTClaimsFromAccessTokenGithubProperties(t *testing.T) { "id": "18633160", "sub": "universal-gh-1", "universal_id": "universal-gh-1", - "name": "XDfield", - "displayName": "gh_XDfield", + "name": "acct_github_user", + "displayName": "gh_acct_github_user", "provider": "Github", "properties": map[string]any{ "oauth_GitHub_id": "18633160", - "oauth_GitHub_username": "XDfield", - "oauth_GitHub_displayName": "DoSun", - "oauth_GitHub_email": "chenxuan@example.com", + "oauth_GitHub_username": "acct_github_user", + "oauth_GitHub_displayName": "Display Github User", + "oauth_GitHub_email": "user_github@example.com", "oauth_GitHub_avatarUrl": "https://avatars.githubusercontent.com/u/18633160?v=4", }, "exp": time.Now().Add(time.Hour).Unix(), @@ -397,13 +397,13 @@ func TestParseJWTClaimsFromAccessTokenGithubProperties(t *testing.T) { if err != nil { t.Fatalf("ParseJWTClaimsFromAccessToken error: %v", err) } - if claims.Name != "XDfield" { + if claims.Name != "acct_github_user" { t.Fatalf("expected github username from properties, got %+v", claims) } - if claims.PreferredUsername != "DoSun" { + if claims.PreferredUsername != "Display Github User" { t.Fatalf("expected github display name from properties, got %+v", claims) } - if claims.Email != "chenxuan@example.com" { + if claims.Email != "user_github@example.com" { t.Fatalf("expected github email from properties, got %+v", claims) } if claims.Picture == "" || claims.ProviderUserID != "18633160" || claims.Provider != "Github" { @@ -413,17 +413,17 @@ func TestParseJWTClaimsFromAccessTokenGithubProperties(t *testing.T) { func TestParseJWTClaimsFromAccessTokenIDTrustUsesProperties(t *testing.T) { tokenString := signUserTestJWT(t, jwt.MapClaims{ - "id": "42766", + "id": "custom-user-001", "sub": "universal-custom-1", "universal_id": "universal-custom-1", "name": "random-generated-name", - "displayName": "陈烜42766", + "displayName": "display_custom_user_001", "provider": "idtrust", "properties": map[string]any{ - "oauth_Custom_id": "42766", - "oauth_Custom_username": "陈烜", - "oauth_Custom_displayName": "陈烜", - "oauth_Custom_email": "15986746954", + "oauth_Custom_id": "custom-user-001", + "oauth_Custom_username": "custom_user", + "oauth_Custom_displayName": "Display Custom User", + "oauth_Custom_email": "15500000001", }, "exp": time.Now().Add(time.Hour).Unix(), }) @@ -432,19 +432,19 @@ func TestParseJWTClaimsFromAccessTokenIDTrustUsesProperties(t *testing.T) { if err != nil { t.Fatalf("ParseJWTClaimsFromAccessToken error: %v", err) } - if claims.Name != "陈烜" { + if claims.Name != "custom_user" { t.Fatalf("expected idtrust username from properties, got %+v", claims) } - if claims.PreferredUsername != "陈烜" { + if claims.PreferredUsername != "Display Custom User" { t.Fatalf("expected idtrust display name from properties, got %+v", claims) } - if claims.ProviderUserID != "42766" { + if claims.ProviderUserID != "custom-user-001" { t.Fatalf("expected idtrust provider user id from properties, got %+v", claims) } if claims.Email != "" { t.Fatalf("expected invalid email-like phone not mapped to email, got %+v", claims) } - if claims.Phone != "15986746954" { + if claims.Phone != "15500000001" { t.Fatalf("expected phone inferred from custom email field, got %+v", claims) } } @@ -476,3 +476,76 @@ func TestCachedUserServiceGetUsersByIDsAndWarmup(t *testing.T) { t.Fatalf("expected 2 users, got %d", len(got)) } } + +func TestBindIdentityToUserCreatesSecondaryIdentityAndPromotesByRank(t *testing.T) { + db := setupUserTestDB(t) + svc := NewUserService(db) + + phoneClaims := &JWTClaims{ID: "phone-id", Sub: "phone-sub", UniversalID: "phone-uuid", Name: "phone_15500000001", PreferredUsername: "ph_15500000001", Provider: "phone", Phone: "15500000001"} + user, err := svc.GetOrCreateUser(phoneClaims) + if err != nil { + t.Fatalf("create phone user: %v", err) + } + + githubClaims := &JWTClaims{ID: "gh-id", Sub: "gh-sub", UniversalID: "gh-uuid", Name: "acct_github_user", PreferredUsername: "Display Github User", Provider: "github", ProviderUserID: "provider-gh-001", Picture: "https://avatars.example.com/a.png"} + if err := svc.BindIdentityToUser(user.SubjectID, githubClaims); err != nil { + t.Fatalf("bind github identity: %v", err) + } + + identities, err := svc.ListUserIdentities(user.SubjectID) + if err != nil { + t.Fatalf("list identities: %v", err) + } + if len(identities) != 2 { + t.Fatalf("expected 2 identities, got %d", len(identities)) + } + primaryCount := 0 + for _, identity := range identities { + if identity.IsPrimary { + primaryCount++ + if identity.Provider != "github" { + t.Fatalf("expected github to be promoted primary, got %+v", identity) + } + } + } + if primaryCount != 1 { + t.Fatalf("expected exactly 1 primary identity, got %d", primaryCount) + } + refreshed, err := svc.GetUserByID(user.SubjectID) + if err != nil { + t.Fatalf("reload user: %v", err) + } + if refreshed.AuthProvider == nil || *refreshed.AuthProvider != "github" { + t.Fatalf("expected user auth_provider upgraded to github, got %+v", refreshed) + } +} + +func TestUnbindIdentityReassignsPrimary(t *testing.T) { + db := setupUserTestDB(t) + svc := NewUserService(db) + + user, err := svc.GetOrCreateUser(&JWTClaims{ID: "gh-id", Sub: "gh-sub", UniversalID: "gh-uuid", Name: "acct_github_user", PreferredUsername: "Display Github User", Provider: "github", ProviderUserID: "provider-gh-001"}) + if err != nil { + t.Fatalf("create user: %v", err) + } + if err := svc.BindIdentityToUser(user.SubjectID, &JWTClaims{ID: "phone-id", Sub: "phone-sub", UniversalID: "phone-uuid", Name: "phone_15500000001", PreferredUsername: "ph_15500000001", Provider: "phone", Phone: "15500000001"}); err != nil { + t.Fatalf("bind phone identity: %v", err) + } + identities, _ := svc.ListUserIdentities(user.SubjectID) + var githubIdentityID uint + for _, identity := range identities { + if identity.Provider == "github" { + githubIdentityID = identity.ID + } + } + if githubIdentityID == 0 { + t.Fatal("expected github identity to exist") + } + if err := svc.UnbindIdentity(user.SubjectID, githubIdentityID); err != nil { + t.Fatalf("unbind github identity: %v", err) + } + identities, _ = svc.ListUserIdentities(user.SubjectID) + if len(identities) != 1 || !identities[0].IsPrimary || identities[0].Provider != "phone" { + t.Fatalf("expected remaining phone identity to become primary, got %+v", identities) + } +} diff --git a/todo/MULTI_PROVIDER_ACCOUNT_BINDING_PROGRESS.md b/todo/MULTI_PROVIDER_ACCOUNT_BINDING_PROGRESS.md new file mode 100644 index 0000000..7e63717 --- /dev/null +++ b/todo/MULTI_PROVIDER_ACCOUNT_BINDING_PROGRESS.md @@ -0,0 +1,342 @@ +# 多 Provider 账号绑定实施进度 + +基于 `docs/proposals/MULTI_PROVIDER_ACCOUNT_BINDING_DESIGN.md`,用于跟踪“显式绑定 + provider 优先级自动升级主资料源”方案的实施进度。 + +--- + +## 一、设计基线与前置条件 + +### 1. 已有基础能力确认 + +- [x] Casdoor token 归一化能力已建立 +- [x] `properties.oauth_*` provider 原始资料提取已建立 +- [x] `external_key` 稳定身份键已建立 +- [x] `users` 表已支持主身份摘要字段 +- [x] `/auth/me` 统一 user DTO 已落地 +- [x] richer auth claims 已贯通 middleware / resolver / authz + +### 2. 本方案明确范围 + +- [x] 采用 **方案 B:显式绑定** +- [x] 不做自动账号归并 +- [x] 不做“主资料源手工切换”功能 +- [x] 采用 provider 优先级自动升级 primary identity +- [x] 确认默认优先级:`idtrust > github > phone` + +--- + +## 二、数据模型与迁移(P0) + +### 3. 模型定义(`server/internal/models/models.go`) + +- [x] 新增 `UserAuthIdentity` 模型 +- [x] 定义字段:`UserSubjectID` +- [x] 定义字段:`Provider` +- [x] 定义字段:`Issuer` +- [x] 定义字段:`ExternalKey` +- [x] 定义字段:`ExternalSubject` +- [x] 定义字段:`ExternalUserID` +- [x] 定义字段:`ProviderUserID` +- [x] 定义字段:`DisplayName` +- [x] 定义字段:`Email` +- [x] 定义字段:`Phone` +- [x] 定义字段:`AvatarURL` +- [x] 定义字段:`Organization` +- [x] 定义字段:`IsPrimary` +- [x] 定义字段:`LastLoginAt` +- [x] 定义 `TableName()`(如需要) + +### 4. 数据库索引与约束 + +- [x] `unique index idx_user_auth_identities_external_key (external_key)` +- [x] `index idx_user_auth_identities_user_subject_id (user_subject_id)` +- [ ] 评估是否增加 provider / provider_user_id 组合索引 +- [x] 约束同一 user 仅允许一条 `is_primary=true`(业务保证或数据库约束) + +### 5. 数据库迁移 + +- [x] 新建 identity 表迁移方案 +- [x] 决定使用 `AutoMigrate` 还是显式 SQL 迁移 +- [x] 为历史库补充索引创建逻辑 +- [x] 本地验证迁移可重复执行 +- [x] PostgreSQL 环境验证迁移兼容性 + +### 6. 历史数据回填 + +- [x] 从 `users` 表回填首批 `user_auth_identities` +- [x] 回填来源:`auth_provider` +- [x] 回填来源:`external_key` +- [x] 回填来源:`provider_user_id` +- [x] 回填来源:`casdoor_universal_id` +- [x] 回填来源:`casdoor_id` +- [x] 回填来源:`casdoor_sub` +- [x] 回填来源:`display_name` +- [x] 回填来源:`email` +- [x] 回填来源:`phone` +- [x] 回填来源:`avatar_url` +- [x] 回填来源:`organization` +- [x] 将回填 identity 标记为 `is_primary=true` +- [ ] 支持 dry-run / summary 输出 + +--- + +## 三、Identity Service(P0) + +### 7. 新增服务模块 + +- [ ] 新增 `server/internal/authidentity/service.go` 或等价目录结构 +- [ ] 定义 `AuthIdentityService` +- [ ] 初始化依赖:`db *gorm.DB` + +### 8. 登录解析能力 + +- [ ] `ResolveOrCreateUserByIdentity(identity *NormalizedIdentity) (*models.User, error)` +- [ ] 优先按 `user_auth_identities.external_key` 查找 +- [ ] 命中后返回已绑定 user +- [ ] 未命中时创建新 user + 首条 identity +- [ ] 更新 identity `last_login_at` +- [ ] 更新 user `last_login_at` +- [ ] 与现有 `GetOrCreateUser` 兼容衔接 + +### 9. 绑定能力 + +- [ ] `BindIdentityToUser(userSubjectID string, identity *NormalizedIdentity) error` +- [ ] 检查 `external_key` 是否已存在 +- [ ] 未绑定时创建新 identity +- [ ] 已绑定当前 user 时幂等成功 +- [ ] 已绑定其他 user 时返回冲突错误 +- [ ] 根据 provider rank 判断是否切换 primary identity +- [ ] 完成后刷新 `users` 聚合资料 + +### 10. 查询与解绑能力 + +- [x] `ListUserIdentities(userSubjectID string) ([]models.UserAuthIdentity, error)` +- [x] `UnbindIdentity(userSubjectID string, identityID uint) error` +- [x] 禁止解绑最后一个 identity +- [x] 若解绑的是 primary identity,则自动重新选择新的 primary identity +- [x] 解绑后刷新 `users` 聚合资料 + +--- + +## 四、Primary Identity 与资料聚合(P0) + +### 11. Provider 优先级规则 + +- [x] 实现 `ProviderRank(provider string) int` +- [x] 支持 `idtrust` +- [x] 支持 `github` +- [x] 支持 `phone` +- [x] 未知 provider 统一降级为最低优先级 + +### 12. Primary Identity 自动升级 + +- [x] 首次 identity 自动设为 primary +- [x] 新绑定 identity 优先级更高时自动升级为 primary +- [x] 优先级相同不切换 primary +- [x] 解绑 primary 时自动重选 + +### 13. 聚合方法 + +- [x] `RefreshUserProfileFromIdentities(userSubjectID string) error` +- [x] 查询当前用户全部 identities +- [x] 找出 `is_primary=true` 的 identity +- [x] 若 primary 缺失则按 rank 选出新的 primary +- [x] 回写 `users.auth_provider` +- [x] 回写 `users.external_key` +- [x] 回写 `users.provider_user_id` + +### 14. 字段聚合规则 + +- [x] `display_name`:primary 优先 +- [x] `avatar_url`:primary 优先,Github 头像 fallback +- [x] `email`:仅合法邮箱参与聚合 +- [x] `phone`:phone provider 优先 +- [x] `organization`:primary 优先 +- [x] `username`:保持稳定,不频繁自动变更 +- [x] 明确低质量 username 升级条件(如 `phone_*` / UUID 风格) + +--- + +## 五、登录主流程切换(P0) + +### 15. Handler / Middleware / Resolver 接入 + +- [x] `AuthCallback` 切换为优先通过 identity 表查人 +- [ ] `RequireAuth` / `OptionalAuth` 链路优先通过 identity 表解析 user +- [ ] authz 的 token 校验链路兼容 identity 表 +- [x] `/auth/me` 在 identity 表模式下返回一致结果 + +### 16. 兼容旧逻辑 fallback + +- [x] 未命中 identity 表时 fallback `users.external_key` +- [ ] 未命中 identity 表时 fallback `casdoor_universal_id` +- [ ] 未命中 identity 表时 fallback `casdoor_sub` +- [ ] 未命中 identity 表时 fallback `casdoor_id` +- [x] 命中旧逻辑后自动补 identity 记录 + +--- + +## 六、绑定接口(P1) + +### 17. 发起绑定 + +- [x] `POST /api/auth/bind/start` +- [x] 参数校验:provider 必填 +- [x] 必须要求当前用户已登录 +- [x] 生成 bind state +- [x] 返回绑定 OAuth URL + +### 18. 绑定回调 + +- [x] `GET /api/auth/bind/callback` +- [x] 验证当前用户会话 +- [x] 验证 bind state +- [x] 用 code 换 token +- [x] 归一化新的 identity +- [x] 调用 `BindIdentityToUser` +- [x] 绑定成功后重定向账号设置页 + +### 19. 已绑定身份查询 + +- [x] `GET /api/auth/identities` +- [x] 返回 provider 列表 +- [x] 返回 `providerUserId` +- [x] 返回 `displayName` +- [x] 返回 `email` +- [x] 返回 `phone` +- [x] 返回 `externalKey` +- [x] 返回 `isPrimary` +- [x] 返回 `lastLoginAt` + +### 20. 解绑接口 + +- [x] `POST /api/auth/identities/:id/unbind` +- [x] 校验 identity 属于当前 user +- [x] 禁止解绑最后一个 identity +- [x] primary identity 解绑后自动重选 +- [ ] 返回更新后的 identity 列表或成功状态 + +--- + +## 七、安全与状态控制(P0) + +### 21. bind state 设计 + +- [x] `action=bind` 标识 +- [x] 包含 `userSubjectID` +- [x] 包含 `provider` +- [x] 包含 `nonce` +- [x] 包含过期时间 +- [ ] 与当前登录会话绑定 +- [x] 签名或加密 + +### 22. 安全规则 + +- [x] 绑定必须要求已登录 +- [ ] 严格防止跨会话串绑 +- [x] 禁止覆盖已绑定他人的 identity +- [x] 冲突时返回明确错误码 `identity_already_bound` +- [x] 解绑最后一个 identity 时返回明确错误 + +--- + +## 八、测试(P0) + +### 23. Service 单元测试 + +- [ ] `ResolveOrCreateUserByIdentity` 首次创建测试 +- [ ] `ResolveOrCreateUserByIdentity` 命中既有 identity 测试 +- [x] `BindIdentityToUser` 新绑定成功测试 +- [ ] `BindIdentityToUser` 幂等绑定测试 +- [ ] `BindIdentityToUser` 绑定冲突测试 +- [x] `UnbindIdentity` 成功测试 +- [x] `UnbindIdentity` 最后一个 identity 保护测试 + +### 24. 聚合策略测试 + +- [x] `phone -> github` 自动升级 primary 测试 +- [ ] `github -> idtrust` 自动升级 primary 测试 +- [ ] 同优先级不切换 primary 测试 +- [x] primary 解绑后自动重选测试 +- [ ] 头像 Github fallback 测试 +- [ ] 非法邮箱不写入 `users.email` 测试 +- [ ] phone provider 优先写入 `users.phone` 测试 + +### 25. Handler / API 测试 + +- [x] `bind/start` 测试 +- [ ] `bind/callback` 成功测试 +- [ ] `bind/callback` state 校验失败测试 +- [x] `GET /api/auth/identities` 测试 +- [x] `unbind` 测试 + +### 26. 集成测试 + +- [ ] 首次 phone 登录创建 user + identity +- [ ] 绑定 Github 后同账号登录命中同一 user +- [ ] 绑定 idtrust 后 primary 自动升级 +- [ ] 三种 provider 混合登录命中同一 `subject_id` + +--- + +## 九、文档与上线(P1) + +### 27. 文档更新 + +- [ ] 更新设计稿与实现差异说明 +- [ ] 补充迁移执行说明 +- [ ] 补充绑定接口使用说明 + +### 28. 上线准备 + +- [x] 先在测试环境验证迁移与回填 +- [ ] 观察 identity 冲突情况 +- [ ] 评估历史重复 user 是否需要人工治理 +- [ ] 确认回滚策略 + +--- + +## 进度概览 + +| 阶段 | 内容 | 状态 | +|------|------|------| +| 一 | 设计基线与前置条件 | 已完成 | +| 二 | 数据模型与迁移 | 大部分完成 | +| 三 | Identity Service | 大部分完成 | +| 四 | Primary Identity 与资料聚合 | 大部分完成 | +| 五 | 登录主流程切换 | 部分完成 | +| 六 | 绑定接口 | 大部分完成 | +| 七 | 安全与状态控制 | 部分完成 | +| 八 | 测试 | 部分完成 | +| 九 | 文档与上线 | 未开始 | + +--- + +## 实施说明 + +### 优先级说明 + +- **P0**:必须完成,决定绑定能力是否可用 +- **P1**:重要功能,决定绑定能力是否可上线 +- **P2**:后续优化 + +### 当前建议实施顺序 + +1. 先建 `user_auth_identities` 表 +2. 再做 identity service +3. 再切登录主流程 +4. 最后补绑定 / 解绑接口 + +### 关键注意事项 + +1. `users` 表在本方案中是聚合资料结果,不再是唯一身份真相来源 +2. `user_auth_identities.external_key` 是登录命中的主查找键 +3. `username` 应保持稳定,避免资料抖动造成业务副作用 +4. `idtrust > github > phone` 仅用于 primary identity 选择,不代表所有字段都由 primary 覆盖 + +--- + +## 参考文档 + +- [多 Provider 账号统一绑定设计稿](../docs/proposals/MULTI_PROVIDER_ACCOUNT_BINDING_DESIGN.md) +- [Casdoor 多 Provider 身份归一化设计提案](../docs/proposals/casdoor-identity-normalization/README.md)