From 7f98ea498b308aa900f6b94422fcddaeb365f799 Mon Sep 17 00:00:00 2001 From: NHAS Date: Mon, 22 Jan 2024 23:04:21 +1300 Subject: [PATCH] Start removing config.Values() --- internal/data/config.go | 24 ++++++++++++++-- internal/router/bpf.go | 34 ++++++++++++++++++----- internal/webserver/authenticators/init.go | 11 ++++++-- internal/webserver/authenticators/oidc.go | 17 ++++++++---- internal/webserver/authenticators/pam.go | 14 ++++++---- internal/webserver/web.go | 15 ++++++++-- ui/check_updates.go | 4 ++- 7 files changed, 92 insertions(+), 27 deletions(-) diff --git a/internal/data/config.go b/internal/data/config.go index fdde324c..2c64c8b4 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -16,7 +16,7 @@ type OIDC struct { IssuerURL string ClientSecret string ClientID string - GroupsClaimName string `json:",omitempty"` + GroupsClaimName string } type PAM struct { @@ -140,8 +140,26 @@ func SetWireguardConfigName(wgConfig string) error { return err } -func GetWireguardConfigName() (string, error) { - return getGeneric(defaultWGFileNameKey) +func GetWireguardConfigName() string { + k, err := getGeneric(defaultWGFileNameKey) + if err != nil { + return "wg0.conf" + } + + if k == "" { + return "wg0.conf" + } + + return k +} + +func SetDefaultMfaMethod(method string) error { + _, err := etcd.Put(context.Background(), defaultMFAMethodKey, method) + return err +} + +func GetDefaultMfaMethod() (string, error) { + return getGeneric(defaultMFAMethodKey) } func SetAuthenticationMethods(methods []string) error { diff --git a/internal/router/bpf.go b/internal/router/bpf.go index 28d5f85a..332784b9 100644 --- a/internal/router/bpf.go +++ b/internal/router/bpf.go @@ -102,8 +102,13 @@ func loadXDP() error { return fmt.Errorf("loading objects: %s", err) } - value := uint64(config.Values().SessionInactivityTimeoutMinutes) * 60000000000 - if config.Values().SessionInactivityTimeoutMinutes < 0 { + sessionInactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes() + if err != nil { + return err + } + + value := uint64(sessionInactivityTimeoutMinutes) * 60000000000 + if sessionInactivityTimeoutMinutes < 0 { value = math.MaxUint64 } @@ -218,11 +223,16 @@ func isAuthed(address string) bool { return false } + inactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes() + if err != nil { + return false + } + currentTime := GetTimeStamp() sessionValid := (deviceStruct.sessionExpiry > currentTime || deviceStruct.sessionExpiry == math.MaxUint64) - sessionActive := ((currentTime-deviceStruct.lastPacketTime) < uint64(config.Values().SessionInactivityTimeoutMinutes)*60000000000 || config.Values().SessionInactivityTimeoutMinutes < 0) + sessionActive := ((currentTime-deviceStruct.lastPacketTime) < uint64(inactivityTimeoutMinutes)*60000000000 || inactivityTimeoutMinutes < 0) return isAccountLocked == 0 && sessionValid && sessionActive } @@ -536,8 +546,13 @@ func RefreshConfiguration() []error { return []error{err} } - value := uint64(config.Values().SessionInactivityTimeoutMinutes) * 60000000000 - if config.Values().SessionInactivityTimeoutMinutes < 0 { + inactivityTimeoutMinutes, err := data.GetSessionInactivityTimeoutMinutes() + if err != nil { + return []error{err} + } + + value := uint64(inactivityTimeoutMinutes) * 60000000000 + if inactivityTimeoutMinutes < 0 { value = math.MaxUint64 } @@ -575,8 +590,13 @@ func SetAuthorized(internalAddress, username string) error { var deviceStruct fwentry deviceStruct.lastPacketTime = GetTimeStamp() - deviceStruct.sessionExpiry = GetTimeStamp() + uint64(config.Values().MaxSessionLifetimeMinutes)*60000000000 - if config.Values().MaxSessionLifetimeMinutes < 0 { + maxSession, err := data.GetSessionLifetimeMinutes() + if err != nil { + return err + } + + deviceStruct.sessionExpiry = GetTimeStamp() + uint64(maxSession)*60000000000 + if maxSession < 0 { deviceStruct.sessionExpiry = math.MaxUint64 // If the session timeout is disabled, (<0) then we set to max value } diff --git a/internal/webserver/authenticators/init.go b/internal/webserver/authenticators/init.go index 4298a79f..a31bfe1d 100644 --- a/internal/webserver/authenticators/init.go +++ b/internal/webserver/authenticators/init.go @@ -6,7 +6,7 @@ import ( "net/http" "strings" - "github.com/NHAS/wag/internal/config" + "github.com/NHAS/wag/internal/data" ) // from: https://github.com/duo-labs/webauthn.io/blob/3f03b482d21476f6b9fb82b2bf1458ff61a61d41/server/response.go#L15 @@ -25,11 +25,16 @@ func resultMessage(err error) (string, int) { return "Success", http.StatusOK } + mail, err := data.GetHelpMail() + if err != nil { + mail = "Server Error" + } + msg := "Validation failed" if strings.Contains(err.Error(), "account is locked") { - msg = "Account is locked contact: " + config.Values().HelpMail + msg = "Account is locked contact: " + mail } else if strings.Contains(err.Error(), "device is locked") { - msg = "Device is locked contact: " + config.Values().HelpMail + msg = "Device is locked contact: " + mail } return msg, http.StatusBadRequest } diff --git a/internal/webserver/authenticators/oidc.go b/internal/webserver/authenticators/oidc.go index 848f6f80..1e1ba5c8 100644 --- a/internal/webserver/authenticators/oidc.go +++ b/internal/webserver/authenticators/oidc.go @@ -31,6 +31,7 @@ type issuer struct { type Oidc struct { provider rp.RelyingParty + details data.OIDC } func (o Oidc) state() string { @@ -72,12 +73,12 @@ func (o *Oidc) Init() error { log.Println("Connecting to OIDC provider") - oidc, err := data.GetOidc() + o.details, err = data.GetOidc() if err != nil { return err } - o.provider, err = rp.NewRelyingPartyOIDC(oidc.IssuerURL, oidc.ClientID, oidc.ClientSecret, u.String(), []string{"openid"}, options...) + o.provider, err = rp.NewRelyingPartyOIDC(o.details.IssuerURL, o.details.ClientID, o.details.ClientSecret, u.String(), []string{"openid"}, options...) if err != nil { return err } @@ -156,7 +157,7 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) { marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty, info oidc.UserInfo) { - groupsIntf, ok := tokens.IDTokenClaims.GetClaim(config.Values().Authenticators.OIDC.GroupsClaimName).([]interface{}) + groupsIntf, ok := tokens.IDTokenClaims.GetClaim(o.details.GroupsClaimName).([]interface{}) if !ok { log.Println("Error, could not convert group claim to []string, probably error in oidc idP configuration") @@ -209,8 +210,14 @@ func (o *Oidc) AuthorisationAPI(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - err := resources.Render("oidc_error.html", w, &resources.Msg{ - HelpMail: config.Values().HelpMail, + mail, err := data.GetHelpMail() + if err != nil { + log.Println("Error getting help mail: ", err) + http.Error(w, "Server Error", http.StatusInternalServerError) + return + } + err = resources.Render("oidc_error.html", w, &resources.Msg{ + HelpMail: mail, NumMethods: NumberOfMethods(), Message: msg, URL: rp.GetEndSessionEndpoint(), diff --git a/internal/webserver/authenticators/pam.go b/internal/webserver/authenticators/pam.go index 179e203e..763475fd 100644 --- a/internal/webserver/authenticators/pam.go +++ b/internal/webserver/authenticators/pam.go @@ -135,16 +135,20 @@ func (t *Pam) AuthoriseFunc(w http.ResponseWriter, r *http.Request) types.Authen passwd := r.FormValue("password") - serviceName := config.Values().Authenticators.PAM.ServiceName + pamDetails, err := data.GetPAM() + if err != nil { + http.Error(w, "Unable to get pam details: "+err.Error(), 500) + return err + } - pamRulesFile := "config /etc/pam.d/" + serviceName - if serviceName == "" { - serviceName = "login" + pamRulesFile := "config /etc/pam.d/" + pamDetails.ServiceName + if pamDetails.ServiceName == "" { + pamDetails.ServiceName = "login" pamRulesFile = "default PAM /etc/pam.d/login" } log.Println(username, "attempting to authorise with PAM (using ", pamRulesFile, ")") - t, err := pam.StartFunc(serviceName, username, func(s pam.Style, msg string) (string, error) { + t, err := pam.StartFunc(pamDetails.ServiceName, username, func(s pam.Style, msg string) (string, error) { switch s { case pam.PromptEchoOff: diff --git a/internal/webserver/web.go b/internal/webserver/web.go index ba435447..4ce88bca 100644 --- a/internal/webserver/web.go +++ b/internal/webserver/web.go @@ -262,7 +262,10 @@ func registerMFA(w http.ResponseWriter, r *http.Request) { method := r.URL.Query().Get("method") if method == "" { - method = config.Values().Authenticators.DefaultMethod + method, err = data.GetDefaultMfaMethod() + if err != nil { + method = "" + } } if method == "" || method == "select" { @@ -465,7 +468,12 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { return } - dnsWithOutSubnet := config.Values().Wireguard.DNS + dnsWithOutSubnet, err := data.GetDNS() + if err != nil { + log.Println(username, remoteAddr, "unable get dns: ", err) + http.Error(w, "Server Error", 500) + return + } for i := 0; i < len(dnsWithOutSubnet); i++ { dnsWithOutSubnet[i] = strings.TrimSuffix(dnsWithOutSubnet[i], "/32") @@ -543,7 +551,8 @@ func registerDevice(w http.ResponseWriter, r *http.Request) { } } else { - w.Header().Set("Content-Disposition", "attachment; filename="+config.Values().DownloadConfigFileName) + + w.Header().Set("Content-Disposition", "attachment; filename="+data.GetWireguardConfigName()) err = resources.RenderWithFuncs("interface.tmpl", w, &wireguardInterface, template.FuncMap{ "StringsJoin": strings.Join, diff --git a/ui/check_updates.go b/ui/check_updates.go index b2abdfc5..c1615c2a 100644 --- a/ui/check_updates.go +++ b/ui/check_updates.go @@ -7,6 +7,7 @@ import ( "time" "github.com/NHAS/wag/internal/config" + "github.com/NHAS/wag/internal/data" ) type githubResponse struct { @@ -32,7 +33,8 @@ var ( func getUpdate() Update { - if !config.Values().CheckUpdates { + should, err := data.CheckUpdates() + if err != nil || !should { return Update{} }