From ac66d6b33034626ecdb651ce1aabb0d3b7fb0f42 Mon Sep 17 00:00:00 2001 From: Krzysztof Bogacki Date: Wed, 2 Aug 2023 15:32:19 +0200 Subject: [PATCH] feat: redirect to OIDC providers only once in registration flows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit test(e2e): ensure there is only one OIDC redirect Co-authored-by: Jakub FijaƂkowski --- .../strategy/oidc/strategy_registration.go | 39 ++++++++++++++++ .../oidc/registration/success.spec.ts | 46 +++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index 82737df36a9d..293767759fce 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -43,6 +43,12 @@ var jsonnetCache, _ = ristretto.NewCache(&ristretto.Config[[]byte, []byte]{ type MetadataType string +type OIDCProviderData struct { + Provider string `json:"provider"` + Tokens *identity.CredentialsOIDCEncryptedTokens `json:"tokens"` + Claims Claims `json:"claims"` +} + type VerifiedAddress struct { Value string `json:"value"` Via identity.VerifiableAddressType `json:"via"` @@ -53,6 +59,8 @@ const ( PublicMetadata MetadataType = "identity.metadata_public" AdminMetadata MetadataType = "identity.metadata_admin" + + InternalContextKeyProviderData = "provider_data" ) func (s *Strategy) RegisterRegistrationRoutes(r *x.RouterPublic) { @@ -216,6 +224,26 @@ func (s *Strategy) Register(w http.ResponseWriter, r *http.Request, f *registrat return errors.WithStack(flow.ErrCompletedByStrategy) } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if oidcProviderData := gjson.GetBytes(f.InternalContext, providerDataKey); oidcProviderData.IsObject() { + var providerData OIDCProviderData + if err = json.Unmarshal([]byte(oidcProviderData.Raw), &providerData); err != nil { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to be an object but got: %w", err))) + } + if pid != providerData.Provider { + return s.handleError(ctx, w, r, f, pid, nil, errors.WithStack(herodot.ErrInternalServerError.WithReasonf("Expected OIDC provider data in internal context to have matching provider but got: %s", providerData.Provider))) + } + _, err = s.processRegistration(ctx, w, r, f, providerData.Tokens, &providerData.Claims, provider, &AuthCodeContainer{ + FlowID: f.ID.String(), + Traits: p.Traits, + TransientPayload: f.TransientPayload, + }) + if err != nil { + return s.handleError(ctx, w, r, f, pid, nil, err) + } + return errors.WithStack(flow.ErrCompletedByStrategy) + } + state, pkce, err := s.GenerateState(ctx, provider, f.ID) if err != nil { return s.handleError(ctx, w, r, f, pid, nil, err) @@ -313,6 +341,13 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, nil } + providerDataKey := flow.PrefixInternalContextKey(s.ID(), InternalContextKeyProviderData) + if hasOIDCProviderData := gjson.GetBytes(rf.InternalContext, providerDataKey).IsObject(); !hasOIDCProviderData { + if internalContext, err := sjson.SetBytes(rf.InternalContext, providerDataKey, &OIDCProviderData{Provider: provider.Config().ID, Tokens: token, Claims: *claims}); err == nil { + rf.InternalContext = internalContext + } + } + fetch := fetcher.NewFetcher(fetcher.WithClient(s.d.HTTPClient(ctx)), fetcher.WithCache(jsonnetCache, 60*time.Minute)) jsonnetMapperSnippet, err := fetch.FetchContext(ctx, provider.Config().Mapper) if err != nil { @@ -351,6 +386,10 @@ func (s *Strategy) processRegistration(ctx context.Context, w http.ResponseWrite return nil, s.handleError(ctx, w, r, rf, provider.Config().ID, i.Traits, err) } + if internalContext, err := sjson.DeleteBytes(rf.InternalContext, providerDataKey); err == nil { + rf.InternalContext = internalContext + } + return nil, nil } diff --git a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts index 132845623f31..6fa26dba2779 100644 --- a/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts +++ b/test/e2e/cypress/integration/profiles/oidc/registration/success.spec.ts @@ -103,6 +103,52 @@ context("Social Sign Up Successes", () => { }) }) + it("should redirect to oidc provider only once", () => { + const email = gen.email() + + cy.registerOidc({ + app, + email, + expectSession: false, + route: registration, + }) + + cy.get(appPrefix(app) + '[name="traits.email"]').should( + "have.value", + email, + ) + + cy.get('[name="traits.consent"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.newsletter"][type="checkbox"]') + .siblings("label") + .click() + cy.get('[name="traits.website"]').type(website) + + cy.intercept("GET", "http://*/oauth2/auth*", { + forceNetworkError: true, + }).as("additionalRedirect") + + cy.triggerOidc(app) + + cy.get("@additionalRedirect").should("not.exist") + + cy.location("pathname").should((loc) => { + expect(loc).to.be.oneOf([ + "/welcome", + "/", + "/sessions", + "/verification", + ]) + }) + + cy.getSession().should((session) => { + shouldSession(email)(session) + expect(session.identity.traits.consent).to.equal(true) + }) + }) + it("should pass transient_payload to webhook", () => { testFlowWebhook( (hooks) =>