From 3d6656f9c941d1b80150e613fb389c10500e47c3 Mon Sep 17 00:00:00 2001 From: Tom Fenech Date: Fri, 14 Jun 2024 06:31:14 +0200 Subject: [PATCH] refactor: use functional options pattern to reduce passing nil --- selfservice/flow/login/handler.go | 2 +- selfservice/flow/login/hook.go | 17 +++++++++++++++-- selfservice/flow/login/hook_test.go | 2 +- selfservice/strategy/oidc/strategy_login.go | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/selfservice/flow/login/handler.go b/selfservice/flow/login/handler.go index db1645ec5ddf..fe76240b5221 100644 --- a/selfservice/flow/login/handler.go +++ b/selfservice/flow/login/handler.go @@ -888,7 +888,7 @@ continueLogin: return } - if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, nil, ""); err != nil { + if err := h.d.LoginHookExecutor().PostLoginHook(w, r, group, f, i, sess, ""); err != nil { if errors.Is(err, ErrAddressNotVerified) { h.d.LoginFlowErrorHandler().WriteFlowError(w, r, f, node.DefaultGroup, errors.WithStack(schema.NewAddressNotVerifiedError())) return diff --git a/selfservice/flow/login/hook.go b/selfservice/flow/login/hook.go index 73b713979fc2..f3cc43c1941a 100644 --- a/selfservice/flow/login/hook.go +++ b/selfservice/flow/login/hook.go @@ -64,6 +64,7 @@ type ( } HookExecutor struct { d executorDependencies + c *claims.Claims } HookExecutorProvider interface { LoginHookExecutor() *HookExecutor @@ -120,6 +121,14 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request, return flowError } +type PostLoginHookOpt func(*HookExecutor) + +func WithClaims(c *claims.Claims) PostLoginHookOpt { + return func(h *HookExecutor) { + h.c = c + } +} + func (e *HookExecutor) PostLoginHook( w http.ResponseWriter, r *http.Request, @@ -127,8 +136,8 @@ func (e *HookExecutor) PostLoginHook( f *Flow, i *identity.Identity, s *session.Session, - c *claims.Claims, provider string, + opts ...PostLoginHookOpt, ) (err error) { ctx := r.Context() ctx, span := e.d.Tracer(ctx).Tracer().Start(ctx, "HookExecutor.PostLoginHook") @@ -169,13 +178,17 @@ func (e *HookExecutor) PostLoginHook( classified := s s = s.Declassified() + for _, o := range opts { + o(e) + } + e.d.Logger(). WithRequest(r). WithField("identity_id", i.ID). WithField("flow_method", f.Active). Debug("Running ExecuteLoginPostHook.") for k, executor := range e.d.PostLoginHooks(r.Context(), f.Active) { - if err := executor.ExecuteLoginPostHook(w, r, g, f, s, c); err != nil { + if err := executor.ExecuteLoginPostHook(w, r, g, f, s, e.c); err != nil { if errors.Is(err, ErrHookAbortFlow) { e.d.Logger(). WithRequest(r). diff --git a/selfservice/flow/login/hook_test.go b/selfservice/flow/login/hook_test.go index 351f7bd217fe..7d7f8e158174 100644 --- a/selfservice/flow/login/hook_test.go +++ b/selfservice/flow/login/hook_test.go @@ -72,7 +72,7 @@ func TestLoginExecutor(t *testing.T) { } testhelpers.SelfServiceHookLoginErrorHandler(t, w, r, - reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, nil, "")) + reg.LoginHookExecutor().PostLoginHook(w, r, strategy.ToUiNodeGroup(), loginFlow, useIdentity, sess, "")) }) ts := httptest.NewServer(router) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 4e50d8cbd061..1fe7de03291e 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -172,7 +172,7 @@ func (s *Strategy) processLogin(ctx context.Context, w http.ResponseWriter, r *h sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID, provider.Config().OrganizationID) for _, c := range oidcCredentials.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { - if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, claims, provider.Config().ID); err != nil { + if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, loginFlow, i, sess, provider.Config().ID, login.WithClaims(claims)); err != nil { return nil, s.handleError(ctx, w, r, loginFlow, provider.Config().ID, nil, err) } return nil, nil