From e94435967a1439c4e7533e53923f1b19c66fddad Mon Sep 17 00:00:00 2001 From: nhas Date: Tue, 26 Nov 2024 21:58:51 +1300 Subject: [PATCH] Begin work on dynamic https/http --- adminui/ui_webserver.go | 271 ++++++++++++----------------- commands/start.go | 7 +- go.mod | 6 + go.sum | 13 ++ internal/autotls/certmagic.go | 318 ++++++++++++++++++++++++++++++++++ internal/autotls/dns-01.go | 106 ------------ internal/autotls/http-01.go | 73 -------- internal/autotls/static.go | 1 - internal/config/config.go | 16 +- internal/data/config.go | 76 ++++++-- internal/data/init.go | 55 +++++- internal/data/tls.go | 241 ++++++++++++++++++-------- internal/enrolment/web.go | 97 +---------- internal/mfaportal/web.go | 96 +--------- internal/router/firewall.go | 2 - internal/router/iptables.go | 23 +-- 16 files changed, 752 insertions(+), 649 deletions(-) create mode 100644 internal/autotls/certmagic.go delete mode 100644 internal/autotls/dns-01.go delete mode 100644 internal/autotls/http-01.go delete mode 100644 internal/autotls/static.go diff --git a/adminui/ui_webserver.go b/adminui/ui_webserver.go index 87496435..0810665b 100644 --- a/adminui/ui_webserver.go +++ b/adminui/ui_webserver.go @@ -2,7 +2,6 @@ package adminui import ( "context" - "crypto/tls" "encoding/json" "errors" "fmt" @@ -16,6 +15,7 @@ import ( "github.com/NHAS/session" "github.com/NHAS/wag/adminui/frontend" + "github.com/NHAS/wag/internal/autotls" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/router" @@ -37,8 +37,6 @@ type AdminUI struct { logQueue *queue.Queue[[]byte] - https, http *http.Server - listenerEvents struct { clusterHealth string } @@ -160,176 +158,125 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error) log.SetOutput(io.MultiWriter(os.Stdout, adminUI.logQueue)) - //https://blog.cloudflare.com/exposing-go-on-the-internet/ - tlsConfig := &tls.Config{ - // Only use curves which have assembly implementations - CurvePreferences: []tls.CurveID{ - tls.CurveP256, - tls.X25519, // Go 1.8 only - }, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, - } - - go func() { - - protectedRoutes := http.NewServeMux() - allRoutes := http.NewServeMux() + protectedRoutes := http.NewServeMux() + allRoutes := http.NewServeMux() - allRoutes.HandleFunc("/", frontend.Index) - allRoutes.HandleFunc("GET /index.html", frontend.Index) + allRoutes.HandleFunc("/", frontend.Index) + allRoutes.HandleFunc("GET /index.html", frontend.Index) - allRoutes.HandleFunc("GET /favicon.ico", frontend.Favicon) - allRoutes.HandleFunc("GET /logo.png", frontend.Logo) - allRoutes.HandleFunc("GET /assets/", frontend.Assets) + allRoutes.HandleFunc("GET /favicon.ico", frontend.Favicon) + allRoutes.HandleFunc("GET /logo.png", frontend.Logo) + allRoutes.HandleFunc("GET /assets/", frontend.Assets) - allRoutes.HandleFunc("POST /api/login", adminUI.doLogin) - allRoutes.HandleFunc("GET /api/config", adminUI.uiConfig) - allRoutes.HandleFunc("POST /api/refresh", adminUI.doAuthRefresh) + allRoutes.HandleFunc("POST /api/login", adminUI.doLogin) + allRoutes.HandleFunc("GET /api/config", adminUI.uiConfig) + allRoutes.HandleFunc("POST /api/refresh", adminUI.doAuthRefresh) - if config.Values.ManagementUI.OIDC.Enabled { - allRoutes.HandleFunc("GET /login/oidc", func(w http.ResponseWriter, r *http.Request) { - rp.AuthURLHandler(func() string { - r, _ := utils.GenerateRandomHex(32) - return r - }, adminUI.oidcProvider)(w, r) - }) + if config.Values.ManagementUI.OIDC.Enabled { + allRoutes.HandleFunc("GET /login/oidc", func(w http.ResponseWriter, r *http.Request) { + rp.AuthURLHandler(func() string { + r, _ := utils.GenerateRandomHex(32) + return r + }, adminUI.oidcProvider)(w, r) + }) - allRoutes.HandleFunc("GET /login/oidc/callback", adminUI.oidcCallback) - } + allRoutes.HandleFunc("GET /login/oidc/callback", adminUI.oidcCallback) + } - allRoutes.Handle("/api/", adminUI.sessionManager.AuthorisationChecks(protectedRoutes, - func(w http.ResponseWriter, r *http.Request) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) - }, - func(w http.ResponseWriter, r *http.Request, dAdmin data.AdminUserDTO) bool { - - key, adminDetails := adminUI.sessionManager.GetSessionFromRequest(r) - if adminDetails != nil { - if adminDetails.Type == "" || adminDetails.Type == data.LocalUser { - d, err := data.GetAdminUser(dAdmin.Username) - if err != nil { - adminUI.sessionManager.DeleteSession(w, r) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return false - } - - adminUI.sessionManager.UpdateSession(key, d) + allRoutes.Handle("/api/", adminUI.sessionManager.AuthorisationChecks(protectedRoutes, + func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + }, + func(w http.ResponseWriter, r *http.Request, dAdmin data.AdminUserDTO) bool { + + key, adminDetails := adminUI.sessionManager.GetSessionFromRequest(r) + if adminDetails != nil { + if adminDetails.Type == "" || adminDetails.Type == data.LocalUser { + d, err := data.GetAdminUser(dAdmin.Username) + if err != nil { + adminUI.sessionManager.DeleteSession(w, r) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return false } - // Otherwise the admin type is OIDC, and will no be in the local db + adminUI.sessionManager.UpdateSession(key, d) } - return true - })) - - protectedRoutes.HandleFunc("GET /api/info", adminUI.serverInfo) - protectedRoutes.HandleFunc("GET /api/console_log", adminUI.consoleLog) - - protectedRoutes.HandleFunc("GET /api/cluster/members", adminUI.members) - protectedRoutes.HandleFunc("POST /api/cluster/members", adminUI.newNode) - protectedRoutes.HandleFunc("PUT /api/cluster/members", adminUI.nodeControl) - - protectedRoutes.HandleFunc("GET /api/cluster/events", adminUI.getClusterEvents) - protectedRoutes.HandleFunc("PUT /api/cluster/events", adminUI.clusterEventsAcknowledge) - - protectedRoutes.HandleFunc("GET /api/diag/wg", adminUI.wgDiagnositicsData) - protectedRoutes.HandleFunc("GET /api/diag/firewall", adminUI.getFirewallState) - protectedRoutes.HandleFunc("POST /api/diag/check", adminUI.firewallCheckTest) - protectedRoutes.HandleFunc("POST /api/diag/acls", adminUI.aclsTest) - protectedRoutes.HandleFunc("POST /api/diag/notifications", adminUI.testNotifications) - - protectedRoutes.HandleFunc("GET /api/management/users", adminUI.getUsers) - protectedRoutes.HandleFunc("PUT /api/management/users", adminUI.editUser) - protectedRoutes.HandleFunc("DELETE /api/management/users", adminUI.removeUsers) - protectedRoutes.HandleFunc("GET /api/management/admin_users", adminUI.adminUsersData) - - protectedRoutes.HandleFunc("GET /api/management/devices", adminUI.getAllDevices) - protectedRoutes.HandleFunc("PUT /api/management/devices", adminUI.editDevice) - protectedRoutes.HandleFunc("DELETE /api/management/devices", adminUI.deleteDevice) - - protectedRoutes.HandleFunc("GET /api/management/registration_tokens", adminUI.getAllRegistrationTokens) - protectedRoutes.HandleFunc("POST /api/management/registration_tokens", adminUI.createRegistrationToken) - protectedRoutes.HandleFunc("DELETE /api/management/registration_tokens", adminUI.deleteRegistrationTokens) - - protectedRoutes.HandleFunc("GET /api/policy/rules", adminUI.getAllPolicies) - protectedRoutes.HandleFunc("PUT /api/policy/rules", adminUI.editPolicy) - protectedRoutes.HandleFunc("POST /api/policy/rules", adminUI.createPolicy) - protectedRoutes.HandleFunc("DELETE /api/policy/rules", adminUI.deletePolices) - - protectedRoutes.HandleFunc("GET /api/policy/groups", adminUI.getAllGroups) - protectedRoutes.HandleFunc("PUT /api/policy/groups", adminUI.editGroup) - protectedRoutes.HandleFunc("POST /api/policy/groups", adminUI.createGroup) - protectedRoutes.HandleFunc("DELETE /api/policy/groups", adminUI.deleteGroups) - - protectedRoutes.HandleFunc("PUT /api/settings/general", adminUI.updateGeneralSettings) - protectedRoutes.HandleFunc("PUT /api/settings/login", adminUI.updateLoginSettings) - protectedRoutes.HandleFunc("GET /api/settings/general", adminUI.getGeneralSettings) - protectedRoutes.HandleFunc("GET /api/settings/login", adminUI.getLoginSettings) - protectedRoutes.HandleFunc("GET /api/settings/all_mfa_methods", adminUI.getAllMfaMethods) - - notifications := make(chan NotificationDTO, 1) - protectedRoutes.HandleFunc("GET /api/notifications", adminUI.notificationsWS(notifications)) - data.RegisterEventListener(data.NodeErrors, true, adminUI.receiveErrorNotifications(notifications)) - go adminUI.monitorClusterMembers(notifications) - - should, err := data.ShouldCheckUpdates() - if err == nil && should { - adminUI.startUpdateChecker(notifications) - } - - protectedRoutes.HandleFunc("PUT /api/change_password", adminUI.changePassword) - - protectedRoutes.HandleFunc("GET /api/logout", func(w http.ResponseWriter, r *http.Request) { - adminUI.sessionManager.DeleteSession(w, r) - w.WriteHeader(http.StatusNoContent) - }) - - protectedRoutes.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - http.NotFound(w, r) - }) - - if data.SupportsTLS(data.ManagementUI) { + // Otherwise the admin type is OIDC, and will no be in the local db + } - go func() { + return true + })) + + protectedRoutes.HandleFunc("GET /api/info", adminUI.serverInfo) + protectedRoutes.HandleFunc("GET /api/console_log", adminUI.consoleLog) + + protectedRoutes.HandleFunc("GET /api/cluster/members", adminUI.members) + protectedRoutes.HandleFunc("POST /api/cluster/members", adminUI.newNode) + protectedRoutes.HandleFunc("PUT /api/cluster/members", adminUI.nodeControl) + + protectedRoutes.HandleFunc("GET /api/cluster/events", adminUI.getClusterEvents) + protectedRoutes.HandleFunc("PUT /api/cluster/events", adminUI.clusterEventsAcknowledge) + + protectedRoutes.HandleFunc("GET /api/diag/wg", adminUI.wgDiagnositicsData) + protectedRoutes.HandleFunc("GET /api/diag/firewall", adminUI.getFirewallState) + protectedRoutes.HandleFunc("POST /api/diag/check", adminUI.firewallCheckTest) + protectedRoutes.HandleFunc("POST /api/diag/acls", adminUI.aclsTest) + protectedRoutes.HandleFunc("POST /api/diag/notifications", adminUI.testNotifications) + + protectedRoutes.HandleFunc("GET /api/management/users", adminUI.getUsers) + protectedRoutes.HandleFunc("PUT /api/management/users", adminUI.editUser) + protectedRoutes.HandleFunc("DELETE /api/management/users", adminUI.removeUsers) + protectedRoutes.HandleFunc("GET /api/management/admin_users", adminUI.adminUsersData) + + protectedRoutes.HandleFunc("GET /api/management/devices", adminUI.getAllDevices) + protectedRoutes.HandleFunc("PUT /api/management/devices", adminUI.editDevice) + protectedRoutes.HandleFunc("DELETE /api/management/devices", adminUI.deleteDevice) + + protectedRoutes.HandleFunc("GET /api/management/registration_tokens", adminUI.getAllRegistrationTokens) + protectedRoutes.HandleFunc("POST /api/management/registration_tokens", adminUI.createRegistrationToken) + protectedRoutes.HandleFunc("DELETE /api/management/registration_tokens", adminUI.deleteRegistrationTokens) + + protectedRoutes.HandleFunc("GET /api/policy/rules", adminUI.getAllPolicies) + protectedRoutes.HandleFunc("PUT /api/policy/rules", adminUI.editPolicy) + protectedRoutes.HandleFunc("POST /api/policy/rules", adminUI.createPolicy) + protectedRoutes.HandleFunc("DELETE /api/policy/rules", adminUI.deletePolices) + + protectedRoutes.HandleFunc("GET /api/policy/groups", adminUI.getAllGroups) + protectedRoutes.HandleFunc("PUT /api/policy/groups", adminUI.editGroup) + protectedRoutes.HandleFunc("POST /api/policy/groups", adminUI.createGroup) + protectedRoutes.HandleFunc("DELETE /api/policy/groups", adminUI.deleteGroups) + + protectedRoutes.HandleFunc("PUT /api/settings/general", adminUI.updateGeneralSettings) + protectedRoutes.HandleFunc("PUT /api/settings/login", adminUI.updateLoginSettings) + protectedRoutes.HandleFunc("GET /api/settings/general", adminUI.getGeneralSettings) + protectedRoutes.HandleFunc("GET /api/settings/login", adminUI.getLoginSettings) + protectedRoutes.HandleFunc("GET /api/settings/all_mfa_methods", adminUI.getAllMfaMethods) + + notifications := make(chan NotificationDTO, 1) + protectedRoutes.HandleFunc("GET /api/notifications", adminUI.notificationsWS(notifications)) + data.RegisterEventListener(data.NodeErrors, true, adminUI.receiveErrorNotifications(notifications)) + go adminUI.monitorClusterMembers(notifications) + + should, err := data.ShouldCheckUpdates() + if err == nil && should { + adminUI.startUpdateChecker(notifications) + } - adminUI.https = &http.Server{ - Addr: config.Values.ManagementUI.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - TLSConfig: tlsConfig, - Handler: utils.SetSecurityHeaders(allRoutes), - } + protectedRoutes.HandleFunc("PUT /api/change_password", adminUI.changePassword) - if err := adminUI.https.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { - errs <- fmt.Errorf("TLS management listener failed: %v", err) - } + protectedRoutes.HandleFunc("GET /api/logout", func(w http.ResponseWriter, r *http.Request) { + adminUI.sessionManager.DeleteSession(w, r) + w.WriteHeader(http.StatusNoContent) + }) - }() - } else { - go func() { - adminUI.http = &http.Server{ - Addr: config.Values.ManagementUI.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: utils.SetSecurityHeaders(allRoutes), - } - if err := adminUI.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errs <- fmt.Errorf("webserver management listener failed: %v", adminUI.http.ListenAndServe()) - } + protectedRoutes.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + }) - }() - } - }() + if err := autotls.Do.DynamicListener(data.Management, utils.SetSecurityHeaders(allRoutes)); err != nil { + return nil, err + } log.Println("[ADMINUI] Started Managemnt UI listening:", config.Values.ManagementUI.ListenAddress) @@ -471,13 +418,7 @@ func (au *AdminUI) oidcCallback(w http.ResponseWriter, r *http.Request) { func (au *AdminUI) Close() { - if au.http != nil { - au.http.Close() - } - - if au.https != nil { - au.https.Close() - } + autotls.Do.Close(data.Management) if config.Values.ManagementUI.Enabled { log.Println("Stopped Management UI") diff --git a/commands/start.go b/commands/start.go index 915f50b5..7c78ac90 100644 --- a/commands/start.go +++ b/commands/start.go @@ -12,6 +12,7 @@ import ( "syscall" "github.com/NHAS/wag/adminui" + "github.com/NHAS/wag/internal/autotls" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/enrolment" @@ -73,9 +74,13 @@ func (g *start) Check() error { err := data.Load(config.Values.DatabaseLocation, g.clusterJoinToken, false) if err != nil { - return fmt.Errorf("cannot load database: %v", err) + return fmt.Errorf("cannot load database: %w", err) } + err = autotls.Initialise() + if err != nil { + return fmt.Errorf("failed to initialise auto tls module: %w", err) + } return nil } diff --git a/go.mod b/go.mod index dfb3bcdc..30fabb35 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,8 @@ require ( require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bits-and-blooms/bitset v1.15.0 // indirect + github.com/caddyserver/certmagic v0.21.4 // indirect + github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/cloudflare-go v0.108.0 // indirect @@ -68,9 +70,12 @@ require ( github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.11 // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/libdns/libdns v0.2.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/socket v0.5.1 // indirect + github.com/mholt/acmez/v2 v2.0.3 // indirect github.com/miekg/dns v1.1.62 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -87,6 +92,7 @@ require ( github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 // indirect + github.com/zeebo/blake3 v0.2.4 // indirect github.com/zitadel/logging v0.6.1 // indirect github.com/zitadel/schema v1.3.0 // indirect go.etcd.io/bbolt v1.3.11 // indirect diff --git a/go.sum b/go.sum index 1a245652..a2fb2190 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,10 @@ github.com/bmatcuk/doublestar/v4 v4.7.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTS github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.2 h1:79yrbttoZrLGkL/oOI8hBrUKucwOL0oOjUgEguGMcJ4= github.com/boombuler/barcode v1.0.2/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/caddyserver/certmagic v0.21.4 h1:e7VobB8rffHv8ZZpSiZtEwnLDHUwLVYLWzWSa1FfKI0= +github.com/caddyserver/certmagic v0.21.4/go.mod h1:swUXjQ1T9ZtMv95qj7/InJvWLXURU85r+CfG0T+ZbDE= +github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= +github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -141,6 +145,8 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -153,6 +159,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= +github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= @@ -161,6 +169,8 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mholt/acmez/v2 v2.0.3 h1:CgDBlEwg3QBp6s45tPQmFIBrkRIkBT4rW4orMM6p4sw= +github.com/mholt/acmez/v2 v2.0.3/go.mod h1:pQ1ysaDeGrIMvJ9dfJMk5kJNkn7L2sb3UhyrX6Q91cw= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= @@ -225,6 +235,8 @@ github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510 h1:S2dVYn90KE98chq github.com/xiang90/probing v0.0.0-20221125231312-a49e3df8f510/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/zeebo/blake3 v0.2.4 h1:KYQPkhpRtcqh0ssGYcKLG1JYvddkEA8QwCM/yBqhaZI= +github.com/zeebo/blake3 v0.2.4/go.mod h1:7eeQ6d2iXWRGF6npfaxl2CU+xy2Fjo2gxeyZGCRUjcE= github.com/zitadel/logging v0.6.1 h1:Vyzk1rl9Kq9RCevcpX6ujUaTYFX43aa4LkvV1TvUk+Y= github.com/zitadel/logging v0.6.1/go.mod h1:Y4CyAXHpl3Mig6JOszcV5Rqqsojj+3n7y2F591Mp/ow= github.com/zitadel/oidc/v3 v3.33.1 h1:e3w9PDV0Mh50/ZiJWtzyT0E4uxJ6RXll+hqVDnqGbTU= @@ -330,6 +342,7 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/internal/autotls/certmagic.go b/internal/autotls/certmagic.go new file mode 100644 index 00000000..1d8dc80b --- /dev/null +++ b/internal/autotls/certmagic.go @@ -0,0 +1,318 @@ +package autotls + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/NHAS/wag/internal/data" + "github.com/caddyserver/certmagic" +) + +type webserver struct { + listeners []*http.Server + mux http.Handler + close chan interface{} + isClosed bool + details data.WebserverConfiguration +} + +type AutoTLS struct { + *certmagic.Config + + sync.RWMutex + + webServers map[data.Webserver]*webserver + + issuer *certmagic.ACMEIssuer +} + +var Do *AutoTLS + +func Initialise() error { + + email, err := data.GetAcmeEmail() + if err != nil { + email = "" + } + + // Defaults to lets encrypt production if nothing is set + provider, err := data.GetAcmeProvider() + if err != nil { + provider = "" + } + + config := certmagic.NewDefault() + config.Storage = data.NewCertStore("wag-certificates") + + issuer := certmagic.NewACMEIssuer(&certmagic.Default, certmagic.ACMEIssuer{ + CA: provider, + Email: email, + Agreed: true, + }) + + if provider != "" && email != "" { + config.Issuers = []certmagic.Issuer{issuer} + } + + ret := &AutoTLS{ + Config: config, + webServers: make(map[data.Webserver]*webserver), + issuer: issuer} + ret.registerEventListeners() + + if Do != nil { + panic("should not occur") + } + + Do = ret + return nil +} + +func (a *AutoTLS) DynamicListener(forWhat data.Webserver, mux http.Handler) error { + + if mux == nil { + panic("no handler provided") + } + + initialDetails, err := data.GetWebserverConfig(forWhat) + if err != nil { + return err + } + + return a.refreshListeners(forWhat, mux, initialDetails) +} + +func (a *AutoTLS) Close(what data.Webserver) { + a.Lock() + defer a.Unlock() + w, ok := a.webServers[what] + if !ok { + return + } + + for _, s := range w.listeners { + s.Close() + } + + w.isClosed = true + + delete(a.webServers, what) + + close(w.close) +} + +func (a *AutoTLS) registerEventListeners() error { + + _, err := data.RegisterEventListener(data.AcmeEmailKey, false, func(_, current, previous string, ev data.EventType) error { + + a.issuer.Email = current + if ev == data.DELETED { + a.issuer.Email = "" + } + + if a.issuer.CA == "" || a.issuer.Email == "" { + a.Config.Issuers = []certmagic.Issuer{} + } else { + a.Config.Issuers = []certmagic.Issuer{a.issuer} + } + + // todo refesh with stored details & mux + + return nil + }) + if err != nil { + return err + } + + _, err = data.RegisterEventListener(data.AcmeProviderKey, false, func(_, current, previous string, ev data.EventType) error { + + a.issuer.CA = current + if ev == data.DELETED { + a.issuer.CA = "" + } + + if a.issuer.CA == "" || a.issuer.Email == "" { + a.Config.Issuers = []certmagic.Issuer{} + } else { + a.Config.Issuers = []certmagic.Issuer{a.issuer} + } + + // todo refesh with stored details & mux + + return nil + }) + if err != nil { + return err + } + + webserverEventsFunc := func(key string, current, _ data.WebserverConfiguration, ev data.EventType) error { + + webserverTarget := data.Webserver(strings.TrimPrefix(key, data.WebServerConfigKey)) + a.RLock() + _, ok := a.webServers[webserverTarget] + a.RUnlock() + + if !ok { + return nil + } + + if ev == data.DELETED { + a.Close(webserverTarget) + return nil + } + + // todo reopen after close, thus rethink about how close fully works + + // nil means we keep the established mux + return a.refreshListeners(webserverTarget, nil, current) + } + + _, err = data.RegisterEventListener(data.TunnelWebServerConfigKey, false, webserverEventsFunc) + if err != nil { + return err + } + + _, err = data.RegisterEventListener(data.PublicWebServerConfigKey, false, webserverEventsFunc) + if err != nil { + return err + } + + _, err = data.RegisterEventListener(data.ManagementWebServerConfigKey, false, webserverEventsFunc) + if err != nil { + return err + } + + return nil +} + +func (a *AutoTLS) refreshListeners(forWhat data.Webserver, mux http.Handler, details data.WebserverConfiguration) error { + ctx := context.Background() + + a.Lock() + defer a.Unlock() + + w, ok := a.webServers[forWhat] + if !ok { + if mux == nil { + return errors.New("refresh called from events while web server doesnt exist") + } + w = &webserver{ + mux: mux, + isClosed: false, + close: make(chan interface{}), + } + a.webServers[forWhat] = w + } + w.details = details + + if w.details.Domain == "" || !w.details.TLS || len(a.Issuers) == 0 { + httpListener, err := net.Listen("tcp", w.details.ListenAddress) + if err != nil { + return err + } + + httpServer := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 2 * time.Minute, + IdleTimeout: 5 * time.Minute, + Handler: w.mux, + BaseContext: func(listener net.Listener) context.Context { return ctx }, + } + + for _, s := range w.listeners { + s.Close() + } + w.listeners = []*http.Server{httpServer} + + go httpServer.Serve(httpListener) + } else { + err := a.Config.ManageSync(ctx, []string{w.details.Domain}) + if err != nil { + return err + } + + tlsConfig := a.Config.TLSConfig() + tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...) + + httpsLn, err := tls.Listen("tcp", fmt.Sprintf(w.details.ListenAddress), tlsConfig) + if err != nil { + return err + } + + httpsServer := &http.Server{ + ReadHeaderTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 2 * time.Minute, + IdleTimeout: 5 * time.Minute, + Handler: w.mux, + BaseContext: func(listener net.Listener) context.Context { return ctx }, + } + for _, s := range w.listeners { + s.Close() + } + w.listeners = []*http.Server{} + + httpRedirectServer, err := a.autoRedirector(w.details.ListenAddress, w.details.Domain) + if err != nil { + log.Println("WARNING could start acme tls listener on", w.details.ListenAddress, err, " auto provisioning certificate may fail") + } else { + w.listeners = append(w.listeners, httpRedirectServer) + } + + w.listeners = append(w.listeners, httpsServer) + + go httpsServer.Serve(httpsLn) + + } + + return nil +} + +func (a *AutoTLS) autoRedirector(httpsServerListenAddr, domain string) (*http.Server, error) { + ctx := context.Background() + + host, port, err := net.SplitHostPort(httpsServerListenAddr) + if err != nil { + host = httpsServerListenAddr + port = "443" + } + + httpRedirectListener, err := net.Listen("tcp", fmt.Sprintf("%s:80", host)) + if err != nil { + return nil, err + } + + httpServer := &http.Server{ + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + IdleTimeout: 5 * time.Second, + BaseContext: func(listener net.Listener) context.Context { return ctx }, + } + + if am, ok := a.Issuers[0].(*certmagic.ACMEIssuer); ok { + // todo dns-01 + httpServer.Handler = am.HTTPChallengeHandler(http.HandlerFunc(a.httpRedirectHandler(domain + ":" + port))) + } + + go httpServer.Serve(httpRedirectListener) + + return httpServer, nil +} + +func (a *AutoTLS) httpRedirectHandler(redirectTo string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + // get rid of this disgusting unencrypted HTTP connection 🤢 + w.Header().Set("Connection", "close") + http.Redirect(w, r, "https://"+redirectTo, http.StatusMovedPermanently) + } +} diff --git a/internal/autotls/dns-01.go b/internal/autotls/dns-01.go deleted file mode 100644 index ef949963..00000000 --- a/internal/autotls/dns-01.go +++ /dev/null @@ -1,106 +0,0 @@ -package autotls - -import ( - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "log" - "os" - - "github.com/go-acme/lego/v4/certificate" - "github.com/go-acme/lego/v4/challenge/dns01" - "github.com/go-acme/lego/v4/lego" - "github.com/go-acme/lego/v4/providers/dns/cloudflare" - "github.com/go-acme/lego/v4/registration" -) - -// You need to implement this interface for the ACME client -type MyUser struct { - Email string - Registration *registration.Resource - key *ecdsa.PrivateKey -} - -func (u *MyUser) GetEmail() string { - return u.Email -} -func (u MyUser) GetRegistration() *registration.Resource { - return u.Registration -} -func (u *MyUser) GetPrivateKey() crypto.PrivateKey { - return u.key -} - -func main() { - // Replace these with your values - acmeEmail := "your-email@example.com" - - cfEmail := "some@email.com" - - domain := "your-domain.com" - cfAPIToken := os.Getenv("CF_API_TOKEN") - - // Create a user - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - log.Fatal(err) - } - - myUser := MyUser{ - Email: acmeEmail, - key: privateKey, - } - - acmeConfig := lego.NewConfig(&myUser) - - // This will help you get staging certificates while testing - acmeConfig.CADirURL = lego.LEDirectoryStaging - // Use this for production - // config.CADirURL = lego.LEDirectoryProduction - - // Create a new ACME client - client, err := lego.NewClient(acmeConfig) - if err != nil { - log.Fatal(err) - } - - cfConfig := cloudflare.NewDefaultConfig() - cfConfig.AuthEmail = cfEmail - cfConfig.AuthKey = cfAPIToken - - // Configure Cloudflare provider - cfProvider, err := cloudflare.NewDNSProviderConfig(cfConfig) - if err != nil { - log.Fatal(err) - } - - // Set Cloudflare as the DNS provider - err = client.Challenge.SetDNS01Provider(cfProvider, - dns01.AddRecursiveNameservers([]string{"1.1.1.1:53", "8.8.8.8:53"}), - ) - if err != nil { - log.Fatal(err) - } - - // Register user - reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) - if err != nil { - log.Fatal(err) - } - myUser.Registration = reg - - // Request certificate - request := certificate.ObtainRequest{ - Domains: []string{domain}, - Bundle: true, - } - - _, err = client.Certificate.Obtain(request) - if err != nil { - log.Fatal(err) - } - - fmt.Println("Successfully obtained certificates!") -} diff --git a/internal/autotls/http-01.go b/internal/autotls/http-01.go deleted file mode 100644 index ebbf2eab..00000000 --- a/internal/autotls/http-01.go +++ /dev/null @@ -1,73 +0,0 @@ -package autotls - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "log" - - "github.com/go-acme/lego/v4/certificate" - "github.com/go-acme/lego/v4/challenge/http01" - "github.com/go-acme/lego/v4/lego" - "github.com/go-acme/lego/v4/registration" -) - -func test() { - // Replace these with your values - email := "your-email@example.com" - domain := "your-domain.com" - - // Create a user - privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - log.Fatal(err) - } - - myUser := MyUser{ - Email: email, - key: privateKey, - } - - config := lego.NewConfig(&myUser) - - // This will help you get staging certificates while testing - config.CADirURL = lego.LEDirectoryStaging - // Use this for production - // config.CADirURL = lego.LEDirectoryProduction - - // Create a new ACME client - client, err := lego.NewClient(config) - if err != nil { - log.Fatal(err) - } - - // Create HTTP-01 provider - httpProvider := http01.NewProviderServer("", "80") - - // Set HTTP-01 as the challenge provider - err = client.Challenge.SetHTTP01Provider(httpProvider) - if err != nil { - log.Fatal(err) - } - - // Register user - reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) - if err != nil { - log.Fatal(err) - } - myUser.Registration = reg - - // Request certificate - request := certificate.ObtainRequest{ - Domains: []string{domain}, - Bundle: true, - } - - _, err = client.Certificate.Obtain(request) - if err != nil { - log.Fatal(err) - } - - fmt.Println("Successfully obtained certificates!") -} diff --git a/internal/autotls/static.go b/internal/autotls/static.go deleted file mode 100644 index 12f22d86..00000000 --- a/internal/autotls/static.go +++ /dev/null @@ -1 +0,0 @@ -package autotls diff --git a/internal/config/config.go b/internal/config/config.go index 76759335..2d31bb39 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,6 +23,7 @@ var Version string type webserverDetails struct { ListenAddress string Domain string + TLS bool } type Acls struct { @@ -64,15 +65,9 @@ type Config struct { DownloadConfigFileName string `json:",omitempty"` - TLSStrategy struct { - DefaultDomain string `json:",omitempty"` - DefaultStrat string - - Tunnel string `json:",omitempty"` - Public string `json:",omitempty"` - ManagementUI string `json:",omitempty"` - - StaticCertsDirectory string `json:",omitempty"` + Acme struct { + Email string + CAProvider string } ManagementUI struct { @@ -94,9 +89,11 @@ type Config struct { Webserver struct { Public webserverDetails + Tunnel struct { Port string Domain string + TLS bool } } @@ -106,7 +103,6 @@ type Config struct { DefaultMethod string `json:",omitempty"` Issuer string Methods []string `json:",omitempty"` - DomainURL string OIDC struct { IssuerURL string diff --git a/internal/data/config.go b/internal/data/config.go index 356ab714..0008f33e 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -30,33 +30,79 @@ type Webauthn struct { Origin string } +type Webserver string + +const ( + Tunnel = Webserver("tunnel") + Management = Webserver("management") + Public = Webserver("public") +) + const ( - fullJsonConfigKey = "wag-config-full" + ConfigKey = "wag-config-" - helpMailKey = "wag-config-general-help-mail" - defaultWGFileNameKey = "wag-config-general-wg-filename" - checkUpdatesKey = "wag-config-general-check-updates" + fullJsonConfigKey = ConfigKey + "full" - InactivityTimeoutKey = "wag-config-authentication-inactivity-timeout" - SessionLifetimeKey = "wag-config-authentication-max-session-lifetime" + helpMailKey = ConfigKey + "general-help-mail" + defaultWGFileNameKey = ConfigKey + "general-wg-filename" + checkUpdatesKey = ConfigKey + "general-check-updates" - LockoutKey = "wag-config-authentication-lockout" - IssuerKey = "wag-config-authentication-issuer" - DomainKey = "wag-config-authentication-domain" - MFAMethodsEnabledKey = "wag-config-authentication-methods" - DefaultMFAMethodKey = "wag-config-authentication-default-method" + InactivityTimeoutKey = ConfigKey + "authentication-inactivity-timeout" + SessionLifetimeKey = ConfigKey + "authentication-max-session-lifetime" - OidcDetailsKey = "wag-config-authentication-oidc" - PamDetailsKey = "wag-config-authentication-pam" + LockoutKey = ConfigKey + "authentication-lockout" + IssuerKey = ConfigKey + "authentication-issuer" + DomainKey = ConfigKey + "authentication-domain" + MFAMethodsEnabledKey = ConfigKey + "authentication-methods" + DefaultMFAMethodKey = ConfigKey + "authentication-default-method" - externalAddressKey = "wag-config-network-external-address" - dnsKey = "wag-config-network-dns" + OidcDetailsKey = ConfigKey + "authentication-oidc" + PamDetailsKey = ConfigKey + "authentication-pam" + + externalAddressKey = ConfigKey + "network-external-address" + dnsKey = ConfigKey + "network-dns" MembershipKey = "wag-membership" deviceRef = "deviceref-" + + WebServerConfigKey = ConfigKey + "webserver-" + TunnelWebServerConfigKey = WebServerConfigKey + string(Tunnel) + PublicWebServerConfigKey = WebServerConfigKey + string(Public) + ManagementWebServerConfigKey = WebServerConfigKey + string(Management) ) +type WebserverConfiguration struct { + ListenAddress string `json:"listen_address"` + Domain string `json:"domain"` + TLS bool `json:"tls"` +} + +func GetWebserverConfig(forWhat Webserver) (details WebserverConfiguration, err error) { + + response, err := etcd.Get(context.Background(), WebServerConfigKey+string(forWhat)) + if err != nil { + return WebserverConfiguration{}, err + } + + if len(response.Kvs) == 0 { + return WebserverConfiguration{}, errors.New("no web server settings found") + } + + err = json.Unmarshal(response.Kvs[0].Value, &details) + return +} + +func SetWebserverConfig(forWhat Webserver, details WebserverConfiguration) (err error) { + + b, err := json.Marshal(details) + if err != nil { + return err + } + _, err = etcd.Put(context.Background(), WebServerConfigKey+string(forWhat), string(b)) + return err +} + func getString(key string) (ret string, err error) { resp, err := etcd.Get(context.Background(), key) if err != nil { diff --git a/internal/data/init.go b/internal/data/init.go index 37ada75d..01fbfe37 100644 --- a/internal/data/init.go +++ b/internal/data/init.go @@ -274,7 +274,7 @@ func loadInitialSettings() error { return err } - err = putIfNotFound(DomainKey, config.Values.Authenticators.DomainURL, "domain url") + err = putIfNotFound(DomainKey, config.Values.Webserver.Tunnel.Domain, "domain url") if err != nil { return err } @@ -309,6 +309,59 @@ func loadInitialSettings() error { return err } + err = putIfNotFound(PamDetailsKey, config.Values.Authenticators.PAM, "pam settings") + if err != nil { + return err + } + + err = putIfNotFound(AcmeEmailKey, config.Values.Acme.Email, "acme email") + if err != nil { + return err + } + + err = putIfNotFound(AcmeEmailKey, config.Values.Acme.CAProvider, "acme provider") + if err != nil { + return err + } + + tunnelWebserverConfig := WebserverConfiguration{ + Domain: config.Values.Webserver.Tunnel.Domain, + TLS: config.Values.Webserver.Tunnel.TLS, + } + + tunnelWebserverConfig.ListenAddress = config.Values.Wireguard.ServerAddress.String() + if config.Values.Wireguard.ServerAddress.To4() == nil && config.Values.Wireguard.ServerAddress.To16() != nil { + tunnelWebserverConfig.ListenAddress = "[" + tunnelWebserverConfig.ListenAddress + "]" + } + tunnelWebserverConfig.ListenAddress += ":" + config.Values.Webserver.Tunnel.Port + + err = putIfNotFound(TunnelWebServerConfigKey, tunnelWebserverConfig, "tunnel web server config") + if err != nil { + return err + } + + publicWebserverConfig := WebserverConfiguration{ + Domain: config.Values.Webserver.Public.Domain, + TLS: config.Values.Webserver.Public.TLS, + ListenAddress: config.Values.Webserver.Public.ListenAddress, + } + + err = putIfNotFound(PublicWebServerConfigKey, publicWebserverConfig, "public/enrolment web server config") + if err != nil { + return err + } + + managementWebserverConfig := WebserverConfiguration{ + Domain: config.Values.ManagementUI.Domain, + TLS: config.Values.ManagementUI.TLS, + ListenAddress: config.Values.ManagementUI.ListenAddress, + } + + err = putIfNotFound(ManagementWebServerConfigKey, managementWebserverConfig, "management web server config") + if err != nil { + return err + } + return nil } diff --git a/internal/data/tls.go b/internal/data/tls.go index a33f9743..0d57854e 100644 --- a/internal/data/tls.go +++ b/internal/data/tls.go @@ -2,139 +2,226 @@ package data import ( "context" - "encoding/base64" "encoding/json" - "fmt" + "errors" + "io/fs" + "path" + "strings" - "go.etcd.io/etcd/client/pkg/v3/types" + "github.com/caddyserver/certmagic" clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/clientv3util" + "go.etcd.io/etcd/client/v3/concurrency" ) -type WebServer string - const ( - Tunnel = WebServer("tunnel") - ManagementUI = WebServer("tunnel") - Public = WebServer("tunnel") - - TLSPrefix = "wag-tls-" - CertificatesKey = TLSPrefix + "certificates-" - UpdateCertHoldKey = TLSPrefix + "hold" - PinAcmeQuerierKey = TLSPrefix + "force-acme-from-node" - - AcmeTime = 2 * 60 + AcmeEmailKey = "wag-acme-email" + AcmeProviderKey = "wag-acme-provider" ) -func PinNodeToAcmeDuties(node types.ID) error { - _, err := etcd.Put(context.Background(), PinAcmeQuerierKey, node.String()) - return err +func GetAcmeEmail() (string, error) { + return getString(AcmeEmailKey) } -func UnpinAcmeDuties() error { - _, err := etcd.Delete(context.Background(), PinAcmeQuerierKey) +func SetAcmeProvider(providerURL string) error { + if !strings.HasPrefix(providerURL, "https://") { + return errors.New("acme provider must start with https://") + } + + data, _ := json.Marshal(providerURL) + + _, err := etcd.Put(context.Background(), AcmeProviderKey, string(data)) return err } -type Certificate struct { - Certificate []byte - PrivateKey []byte `sensitive:"true"` +func GetAcmeProvider() (string, error) { + return getString(AcmeProviderKey) } -type certificateJSON struct { - Certificate string `json:"certificate"` - PrivateKey string `json:"private_key"` +type CertMagicStore struct { + basePath string } -// MarshalJSON implements the json.Marshaler interface -func (c Certificate) MarshalJSON() ([]byte, error) { - return json.Marshal(certificateJSON{ - Certificate: base64.StdEncoding.EncodeToString(c.Certificate), - PrivateKey: base64.StdEncoding.EncodeToString(c.PrivateKey), - }) +func NewCertStore(basePath string) *CertMagicStore { + return &CertMagicStore{ + basePath: basePath, + } } -// UnmarshalJSON implements the json.Unmarshaler interface -func (c *Certificate) UnmarshalJSON(data []byte) error { - var jsonData certificateJSON - if err := json.Unmarshal(data, &jsonData); err != nil { - return err +func (cms *CertMagicStore) Exists(ctx context.Context, key string) bool { + + res, err := etcd.Get(ctx, cms.basePath+"/"+key, clientv3.WithCountOnly()) + if err != nil { + return false } - cert, err := base64.StdEncoding.DecodeString(jsonData.Certificate) + return res.Count > 1 +} + +func (cms *CertMagicStore) Lock(ctx context.Context, name string) error { + session, err := concurrency.NewSession(etcd, concurrency.WithContext(ctx)) if err != nil { return err } - key, err := base64.StdEncoding.DecodeString(jsonData.PrivateKey) + return concurrency.NewMutex(session, name).Lock(ctx) +} + +func (cms *CertMagicStore) Unlock(ctx context.Context, name string) error { + session, err := concurrency.NewSession(etcd, concurrency.WithContext(ctx)) if err != nil { return err } - c.Certificate = cert - c.PrivateKey = key - return nil + return concurrency.NewMutex(session, name).Unlock(ctx) + } -func AllowToRenew() (bool, error) { - lease, err := clientv3.NewLease(etcd).Grant(context.Background(), AcmeTime) +func (cms *CertMagicStore) Store(ctx context.Context, key string, value []byte) error { + keyPath := cms.basePath + "/" + key + + _, err := etcd.Put(ctx, keyPath, string(value)) + return err +} + +func (cms *CertMagicStore) Load(ctx context.Context, key string) ([]byte, error) { + + keyPath := cms.basePath + "/" + key + + res, err := etcd.Get(ctx, keyPath) if err != nil { - return false, err + return nil, err } - txn := etcd.Txn(context.Background()) - txn.If( - clientv3util.KeyMissing(UpdateCertHoldKey), - ).Then( - clientv3.OpPut(UpdateCertHoldKey, GetServerID().String(), clientv3.WithLease(lease.ID)), - ) + if res.Count == 0 { + return nil, fs.ErrNotExist + } - resp, err := txn.Commit() - if err != nil { - return false, err + if len(res.Kvs) == 0 { + return nil, fs.ErrNotExist } - // This node won the race, so now it can do acme (and will not be stomped on for seconds) - return resp.Succeeded, nil + return res.Kvs[0].Value, nil } -func SetCertificate(forWhat WebServer, certificate, privateKey []byte) error { +func (cms *CertMagicStore) Delete(ctx context.Context, key string) error { + + keyPath := cms.basePath + "/" + key - newCert := Certificate{ - Certificate: certificate, - PrivateKey: privateKey, + opts := []clientv3.OpOption{} + + res, err := etcd.Get(ctx, keyPath, clientv3.WithCountOnly()) + if err != nil { + return err } - data, err := json.Marshal(newCert) + if res.Count == 0 { + + if !strings.HasSuffix(keyPath, "/") { + keyPath = keyPath + "/" + } + + res, err = etcd.Get(ctx, keyPath, clientv3.WithCountOnly(), clientv3.WithPrefix()) + if err != nil { + return err + } + + if res.Count == 0 { + return fs.ErrNotExist + } + + // intentional fall through + } + + //A "directory" is a key with no value, but which may be the prefix of other keys. + if res.Count > 1 { + opts = append(opts, clientv3.WithPrefix()) + } + + delRes, err := etcd.Delete(ctx, key, opts...) if err != nil { - return fmt.Errorf("failed to marshal new certificate: %w", err) + return err } - _, err = etcd.Put(context.Background(), CertificatesKey+string(forWhat), string(data)) + if delRes.Deleted != res.Count { + return errors.New("short delete") + } - return err + return nil } -// deliberately no getter, so that we force the user to use the events system to update their certificates +func (cms *CertMagicStore) List(ctx context.Context, pathPrefix string, recursive bool) ([]string, error) { -// SupportsTLS Should only be used on startup, everywhere else use the events system to watch if certificates have been updated/created -func SupportsTLS(web WebServer) bool { + keyPath := cms.basePath + "/" + pathPrefix - certificates, err := etcd.Get(context.Background(), CertificatesKey+string(web)) + response, err := etcd.Get(context.Background(), keyPath, clientv3.WithPrefix(), clientv3.WithKeysOnly()) if err != nil { - return false + return nil, err } - if certificates.Count == 0 { - return false + if response.Count == 0 { + return nil, fs.ErrNotExist } - if len(certificates.Kvs) != 1 { - return false + var keys []string + for _, res := range response.Kvs { + + key := strings.TrimPrefix(string(res.Key), cms.basePath+"/") + keys = append(keys, key) + } + + if recursive { + return keys, nil + } + + // stolen from: https://github.com/SUNET/knubbis-fleetlock/blob/main/certmagic/etcd3storage/etcd3storage.go + + combinedKeys := map[string]struct{}{} + for _, key := range keys { + // prefix/dir1/file1 -> dir1/file1 + noPrefixKey := strings.TrimPrefix(key, pathPrefix+"/") + // dir1/file1 -> dir1 + part := strings.Split(noPrefixKey, "/")[0] + + combinedKeys[part] = struct{}{} } - var jsonCert certificateJSON - err = json.Unmarshal(certificates.Kvs[0].Value, &jsonCert) + cKeys := []string{} + for key := range combinedKeys { + cKeys = append(cKeys, path.Join(pathPrefix, key)) + } + + return cKeys, nil +} + +func (cms *CertMagicStore) Stat(ctx context.Context, key string) (certmagic.KeyInfo, error) { + res, err := etcd.Get(ctx, key) + if err != nil { + return certmagic.KeyInfo{}, err + } + + r := certmagic.KeyInfo{ + Key: key, + } + + if len(res.Kvs) > 1 { + + r.Size = int64(len(res.Kvs[0].Value)) + r.IsTerminal = true + + return r, nil + } + + // look for directory + res, err = etcd.Get(ctx, key+"/", clientv3.WithPrefix(), clientv3.WithCountOnly(), clientv3.WithKeysOnly()) + if err != nil { + return certmagic.KeyInfo{}, err + } + + if res.Count > 0 { + r.IsTerminal = false + } else { + return certmagic.KeyInfo{}, fs.ErrNotExist + } - return err == nil + return r, nil } diff --git a/internal/enrolment/web.go b/internal/enrolment/web.go index 327ae3a2..7bec076e 100644 --- a/internal/enrolment/web.go +++ b/internal/enrolment/web.go @@ -2,9 +2,7 @@ package enrolment import ( "bytes" - "crypto/tls" "encoding/base64" - "errors" "fmt" "html/template" "image/png" @@ -13,8 +11,8 @@ import ( "net/http" "net/url" "strings" - "time" + "github.com/NHAS/wag/internal/autotls" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/enrolment/resources" @@ -30,20 +28,11 @@ import ( ) type EnrolmentServer struct { - publicHTTPServ *http.Server - publicTLSServ *http.Server - firewall *router.Firewall + firewall *router.Firewall } func (es *EnrolmentServer) Close() { - - if es.publicHTTPServ != nil { - es.publicHTTPServ.Close() - } - - if es.publicTLSServ != nil { - es.publicTLSServ.Close() - } + autotls.Do.Close(data.Public) } func (es *EnrolmentServer) registerDevice(w http.ResponseWriter, r *http.Request) { @@ -321,89 +310,13 @@ func New(firewall *router.Firewall, errChan chan<- error) (*EnrolmentServer, err var es EnrolmentServer es.firewall = firewall - //https://blog.cloudflare.com/exposing-go-on-the-internet/ - tlsConfig := &tls.Config{ - // Only use curves which have assembly implementations - CurvePreferences: []tls.CurveID{ - tls.CurveP256, - tls.X25519, // Go 1.8 only - }, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, - } - public := http.NewServeMux() public.HandleFunc("GET /static/", utils.EmbeddedStatic(styling.Static)) public.HandleFunc("GET /reachability", es.reachability) public.HandleFunc("GET /register_device", es.registerDevice) - if data.SupportsTLS(data.Public) { - - go func() { - - es.publicTLSServ = &http.Server{ - Addr: config.Values.Webserver.Public.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - TLSConfig: tlsConfig, - Handler: utils.SetSecurityHeaders(public), - } - - if err := es.publicTLSServ.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("TLS webserver enrolment listener failed: %v", err) - } - }() - - if config.Values.NumberProxies == 0 { - go func() { - - address, port, err := net.SplitHostPort(config.Values.Webserver.Public.ListenAddress) - - if err != nil { - errChan <- fmt.Errorf("malformed listen address for enrolment listener: %v", err) - return - } - - // If we're supporting tls, add a redirection handler from 80 -> tls - port += ":" + port - if port == "443" { - port = "" - } - - es.publicHTTPServ = &http.Server{ - Addr: address + ":80", - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: utils.SetSecurityHeaders(utils.SetRedirectHandler(port)), - } - - log.Printf("Creating redirection from 80/tcp to TLS webserver enrolment listener failed: %v", es.publicHTTPServ.ListenAndServe()) - }() - } - - } else { - go func() { - es.publicHTTPServ = &http.Server{ - Addr: config.Values.Webserver.Public.ListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: utils.SetSecurityHeaders(public), - } - - if err := es.publicHTTPServ.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("HTTP webserver enrolment listener failed: %v", err) - } - }() + if err := autotls.Do.DynamicListener(data.Public, public); err != nil { + return nil, err } log.Println("[ENROLMENT] Public enrolment listening: ", config.Values.Webserver.Public.ListenAddress) diff --git a/internal/mfaportal/web.go b/internal/mfaportal/web.go index d565ccb9..2044684c 100644 --- a/internal/mfaportal/web.go +++ b/internal/mfaportal/web.go @@ -1,16 +1,14 @@ package mfaportal import ( - "crypto/tls" "encoding/json" - "errors" "fmt" "log" "net/http" "path" "strings" - "time" + "github.com/NHAS/wag/internal/autotls" "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "github.com/NHAS/wag/internal/mfaportal/authenticators" @@ -21,9 +19,6 @@ import ( ) type MfaPortal struct { - tunnelHTTPServ *http.Server - tunnelTLSServ *http.Server - firewall *router.Firewall listenerKeys struct { @@ -36,13 +31,7 @@ type MfaPortal struct { func (mp *MfaPortal) Close() { - if mp.tunnelHTTPServ != nil { - mp.tunnelHTTPServ.Close() - } - - if mp.tunnelTLSServ != nil { - mp.tunnelTLSServ.Close() - } + autotls.Do.Close(data.Public) mp.deregisterListeners() @@ -57,24 +46,6 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err var mfaPortal MfaPortal mfaPortal.firewall = firewall - //https://blog.cloudflare.com/exposing-go-on-the-internet/ - tlsConfig := &tls.Config{ - // Only use curves which have assembly implementations - CurvePreferences: []tls.CurveID{ - tls.CurveP256, - tls.X25519, // Go 1.8 only - }, - MinVersion: tls.VersionTLS12, - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, // Go 1.8 only - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - }, - } - tunnel := http.NewServeMux() tunnel.HandleFunc("GET /status/", mfaPortal.status) @@ -113,68 +84,11 @@ func New(firewall *router.Firewall, errChan chan<- error) (m *MfaPortal, err err tunnel.HandleFunc("/", mfaPortal.index) - address := config.Values.Wireguard.ServerAddress.String() - if config.Values.Wireguard.ServerAddress.To4() == nil && config.Values.Wireguard.ServerAddress.To16() != nil { - address = "[" + address + "]" - } - - tunnelListenAddress := address + ":" + config.Values.Webserver.Tunnel.Port - if data.SupportsTLS(data.Tunnel) { - - go func() { - - mfaPortal.tunnelTLSServ = &http.Server{ - Addr: tunnelListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - TLSConfig: tlsConfig, - Handler: utils.SetSecurityHeaders(tunnel), - } - if err := mfaPortal.tunnelTLSServ.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("TLS webserver tunnel listener failed: %v", err) - } - - }() - - if config.Values.NumberProxies == 0 { - go func() { - - port := ":" + config.Values.Webserver.Tunnel.Port - if port == "443" { - port = "" - } - - mfaPortal.tunnelHTTPServ = &http.Server{ - Addr: address + ":80", - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: utils.SetSecurityHeaders(utils.SetRedirectHandler(port)), - } - - log.Printf("HTTP redirect to TLS webserver tunnel listener failed: %v", mfaPortal.tunnelHTTPServ.ListenAndServe()) - }() - } - } else { - go func() { - mfaPortal.tunnelHTTPServ = &http.Server{ - Addr: tunnelListenAddress, - ReadTimeout: 5 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 120 * time.Second, - Handler: utils.SetSecurityHeaders(tunnel), - } - - if err := mfaPortal.tunnelHTTPServ.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - errChan <- fmt.Errorf("webserver tunnel listener failed: %v", err) - } - - }() + if err := autotls.Do.DynamicListener(data.Tunnel, utils.SetSecurityHeaders(tunnel)); err != nil { + return nil, err } - //Group the print statement so that multithreading won't disorder them - log.Println("[PORTAL] Captive portal started listening: ", tunnelListenAddress) + log.Println("[PORTAL] Captive portal started listening") return m, nil } diff --git a/internal/router/firewall.go b/internal/router/firewall.go index 3d135836..98e5f51b 100644 --- a/internal/router/firewall.go +++ b/internal/router/firewall.go @@ -62,8 +62,6 @@ type Firewall struct { connectedPeersLck sync.RWMutex currentlyConnectedPeers map[string]string - - tunnelInitallySupportedTLS bool } func (f *Firewall) GetRoutes(username string) ([]string, error) { diff --git a/internal/router/iptables.go b/internal/router/iptables.go index b59217c0..63d0f9a6 100644 --- a/internal/router/iptables.go +++ b/internal/router/iptables.go @@ -6,7 +6,6 @@ import ( "strings" "github.com/NHAS/wag/internal/config" - "github.com/NHAS/wag/internal/data" "github.com/coreos/go-iptables/iptables" ) @@ -70,14 +69,10 @@ func (f *Firewall) setupIptables() error { return err } - // Open port 80 to allow http redirection - if data.SupportsTLS(data.Tunnel) { - f.tunnelInitallySupportedTLS = true - //Allow input to authorize web server on the tunnel (http -> https redirect), if we're not behind a proxy - err = ipt.Append("filter", "INPUT", "-m", "tcp", "-p", "tcp", "-i", devName, "--dport", "80", "-j", "ACCEPT") - if err != nil { - return err - } + //Allow input to authorize web server on the tunnel (http -> https redirect), if we're not behind a proxy + err = ipt.Append("filter", "INPUT", "-m", "tcp", "-p", "tcp", "-i", devName, "--dport", "80", "-j", "ACCEPT") + if err != nil { + return err } } @@ -167,12 +162,10 @@ func (f *Firewall) teardownIptables() { } // Open port 80 to allow http redirection - if f.tunnelInitallySupportedTLS { - //Allow input to authorize web server on the tunnel (http -> https redirect), if we're not behind a proxy - err = ipt.Delete("filter", "INPUT", "-m", "tcp", "-p", "tcp", "-i", config.Values.Wireguard.DevName, "--dport", "80", "-j", "ACCEPT") - if err != nil { - log.Println("Unable to clean up firewall rules: ", err) - } + //Allow input to authorize web server on the tunnel (http -> https redirect), if we're not behind a proxy + err = ipt.Delete("filter", "INPUT", "-m", "tcp", "-p", "tcp", "-i", config.Values.Wireguard.DevName, "--dport", "80", "-j", "ACCEPT") + if err != nil { + log.Println("Unable to clean up firewall rules: ", err) } }