diff --git a/adminui/frontend/src/api/settings.ts b/adminui/frontend/src/api/settings.ts
index dd86d063..3106bee0 100644
--- a/adminui/frontend/src/api/settings.ts
+++ b/adminui/frontend/src/api/settings.ts
@@ -1,4 +1,11 @@
-import type { GenericResponseDTO, MFAMethodDTO, LoginSettingsResponseDTO, GeneralSettingsResponseDTO } from './types'
+import type {
+ GenericResponseDTO,
+ MFAMethodDTO,
+ LoginSettingsResponseDTO,
+ GeneralSettingsResponseDTO,
+ AcmeDetailsDTO,
+ WebServerConfigDTO
+} from './types'
import { client } from '.'
@@ -21,3 +28,27 @@ export function updateLoginSettings(settings: LoginSettingsResponseDTO): Promise
export function getMFAMethods(): Promise {
return client.get('/api/settings/all_mfa_methods').then(res => res.data)
}
+
+export function getWebservers(): Promise {
+ return client.get('/api/settings/webservers').then(res => res.data)
+}
+
+export function editWebserver(webserver: WebServerConfigDTO): Promise {
+ return client.put('/api/settings/webserver', webserver).then(res => res.data)
+}
+
+export function getAcmeDetails(): Promise {
+ return client.get('/api/settings/acme').then(res => res.data)
+}
+
+export function setAcmeEmail(email: string): Promise {
+ return client.put('/api/settings/acme/email', { data: email }).then(res => res.data)
+}
+
+export function setAcmeProvider(url: string): Promise {
+ return client.put('/api/settings/acme/provider_url', { data: url }).then(res => res.data)
+}
+
+export function setAcmeCloudflareDNSKey(cloudflare_api_key: string): Promise {
+ return client.put('/api/settings/acme/cloudflare_api_key', { data: cloudflare_api_key }).then(res => res.data)
+}
diff --git a/adminui/frontend/src/api/types.ts b/adminui/frontend/src/api/types.ts
index bc8358e4..f2439628 100644
--- a/adminui/frontend/src/api/types.ts
+++ b/adminui/frontend/src/api/types.ts
@@ -177,6 +177,12 @@ export interface LoginSettingsResponseDTO {
pam: PamResponseDTO
}
+export interface AcmeDetailsDTO {
+ provider_url: string
+ email: string
+ api_token_set: boolean
+}
+
export interface MFAMethodDTO {
friendly_name: string
method: string
@@ -286,3 +292,10 @@ export enum NodeControlActions {
Stepdown = 'stepdown',
Remove = 'remove'
}
+
+export interface WebServerConfigDTO {
+ server_name: string
+ listen_address: string
+ domain: string
+ tls: boolean
+}
diff --git a/adminui/frontend/src/layouts/default.vue b/adminui/frontend/src/layouts/default.vue
index 963f34de..45c84f47 100644
--- a/adminui/frontend/src/layouts/default.vue
+++ b/adminui/frontend/src/layouts/default.vue
@@ -2,8 +2,6 @@
import { useRoute, useRouter } from 'vue-router'
import { storeToRefs } from 'pinia'
-import logo from '../../public/WagLogo.png'
-
import { useAuthStore } from '@/stores/auth'
import { useInstanceDetailsStore } from '@/stores/serverInfo'
@@ -70,7 +68,7 @@ async function logout() {
-
+
diff --git a/adminui/frontend/src/pages/Settings.vue b/adminui/frontend/src/pages/Settings.vue
index 4c963d1a..c5ddb5ee 100644
--- a/adminui/frontend/src/pages/Settings.vue
+++ b/adminui/frontend/src/pages/Settings.vue
@@ -15,19 +15,108 @@ import {
updateLoginSettings,
type GeneralSettingsResponseDTO as GeneralSettingsDTO,
type LoginSettingsResponseDTO as LoginSettingsDTO,
- type MFAMethodDTO
+ type WebServerConfigDTO,
+ type MFAMethodDTO,
+ getAcmeDetails,
+ type AcmeDetailsDTO,
+ setAcmeCloudflareDNSKey,
+ setAcmeEmail,
+ setAcmeProvider,
+ getWebservers,
+ editWebserver
} from '@/api'
const toast = useToast()
const { catcher } = useToastError()
+const apiTokenSetValue = '**********'
+
+const { data: acme, isLoading: isLoadingAcmeSettings, silentlyRefresh: refreshAcme } = useApi(() => getAcmeDetails())
const { data: general, isLoading: isLoadingGeneralSettings, silentlyRefresh: refreshGeneral } = useApi(() => getGeneralSettings())
const { data: loginSettings, isLoading: isLoadingLoginSettings, silentlyRefresh: refreshLoginSettings } = useApi(() => getLoginSettings())
+const { data: webservers, isLoading: isLoadingWebserverSettings, silentlyRefresh: refreshWebservers } = useApi(() => getWebservers())
const { data: mfaTypes, isLoading: isLoadingMFATypes } = useApi(() => getMFAMethods())
-const generalData = computed(() => general.value ?? ({} as GeneralSettingsDTO))
+const originalAcmeStates = ref({} as AcmeDetailsDTO)
+
+watch(
+ acme,
+ newAcme => {
+ if (newAcme) {
+ originalAcmeStates.value = {
+ api_token_set: newAcme.api_token_set,
+ email: newAcme.email,
+ provider_url: newAcme.provider_url
+ }
+ }
+ },
+ { immediate: true }
+)
+
+const originalServerStates = ref>({})
+
+watch(
+ webservers,
+ newServers => {
+ if (newServers) {
+ originalServerStates.value = newServers.reduce(
+ (acc, server) => {
+ acc[server.server_name] = {
+ server_name: server.server_name,
+ domain: server.domain,
+ listen_address: server.listen_address,
+ tls: server.tls
+ }
+ return acc
+ },
+ {} as Record
+ )
+ }
+ },
+ { immediate: true }
+)
+
+const getModifiedServers = () => {
+ if (!webservers.value) return []
+
+ return webservers.value.filter(server => {
+ const original = originalServerStates.value[server.server_name]
+ if (!original) return true // New server
+
+ return original.domain !== server.domain || original.listen_address !== server.listen_address || original.tls !== server.tls
+ })
+}
+
+async function saveServerSettings() {
+ try {
+ const updateResults = await Promise.all(getModifiedServers().map(server => editWebserver(server)))
+
+ const allSuccessful = updateResults.every(result => result.success) // Adjust based on your API response structure
+ if (allSuccessful) {
+ toast.success('updated servers!')
+ } else {
+ const failedServers = updateResults.filter(resp => resp.success)
+ toast.error('failed to save server settings' + failedServers.map(s => s.message))
+ }
+ } catch (e) {
+ catcher(e, 'failed to save acme settings: ')
+ } finally {
+ refreshWebservers()
+ }
+}
+
+const generalData = computed(
+ () =>
+ general.value ??
+ ({
+ dns: [] as string[]
+ } as GeneralSettingsDTO)
+)
const loginSettingsData = computed(() => loginSettings.value ?? ({} as LoginSettingsDTO))
+const acmeSettingsData = computed(() => acme.value ?? ({} as AcmeDetailsDTO))
+
+const webserversSettingsData = computed(() => webservers.value ?? ([] as WebServerConfigDTO[]))
const textValue = ref(general.value?.dns.join('\n') ?? '')
@@ -39,10 +128,57 @@ watch(textValue, newValue => {
watch(general, newValue => {
if (newValue) {
- textValue.value = newValue.dns.join('\n')
+ textValue.value = newValue.dns?.join('\n') ?? ''
+ }
+})
+
+const cloudflareApiTokenRef = ref('')
+watch(acme, newVal => {
+ if (newVal?.api_token_set) {
+ cloudflareApiTokenRef.value = apiTokenSetValue
+ } else {
+ cloudflareApiTokenRef.value = ''
}
})
+async function saveAcmeSettings() {
+ try {
+ let failed = false
+
+ if (cloudflareApiTokenRef.value !== apiTokenSetValue) {
+ const resp = await setAcmeCloudflareDNSKey(cloudflareApiTokenRef.value)
+ if (!resp.success) {
+ toast.error('Failed to save cloudflare api token:' + (resp.message ?? 'Unknown Error'))
+ failed = true
+ }
+ }
+
+ if (acmeSettingsData.value.email != originalAcmeStates.value.email) {
+ const resp = await setAcmeEmail(acmeSettingsData.value.email)
+ if (!resp.success) {
+ toast.error('Failed to save acme email:' + (resp.message ?? 'Unknown Error'))
+ failed = true
+ }
+ }
+
+ if (acmeSettingsData.value.provider_url != originalAcmeStates.value.provider_url) {
+ const resp = await setAcmeProvider(acmeSettingsData.value.provider_url)
+ if (!resp.success) {
+ toast.error('Failed to save acme provider url:' + (resp.message ?? 'Unknown Error'))
+ failed = true
+ }
+ }
+
+ if (!failed) {
+ toast.success('Saved acme settings')
+ }
+ } catch (e) {
+ catcher(e, 'failed to save acme settings: ')
+ } finally {
+ refreshAcme()
+ }
+}
+
async function saveGeneralSettings() {
try {
const resp = await updateGeneralSettings(generalData.value as GeneralSettingsDTO)
@@ -82,7 +218,9 @@ function filterMfaMethods(enabledMethods: string[], allMethods: MFAMethodDTO[]):
-
+
+
+
+
ACME
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Web Servers
+
+
TLS Method:
+
+
+ {{
+ acmeSettingsData.email.length == 0 || acmeSettingsData.provider_url.length == 0
+ ? 'Disabled'
+ : acmeSettingsData.api_token_set
+ ? 'DNS-01'
+ : 'HTTP-01'
+ }}
+
+
+
+
+
+
+
diff --git a/adminui/settings.go b/adminui/settings.go
index 817be56b..4f5abb2f 100644
--- a/adminui/settings.go
+++ b/adminui/settings.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"log"
"net/http"
+ "sort"
"github.com/NHAS/wag/internal/data"
"github.com/NHAS/wag/internal/mfaportal/authenticators"
@@ -103,3 +104,134 @@ func (au *AdminUI) getAllMfaMethods(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
json.NewEncoder(w).Encode(resp)
}
+
+func (au *AdminUI) getAllWebserverConfigs(w http.ResponseWriter, _ *http.Request) {
+
+ confs, err := data.GetAllWebserverConfigs()
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+ var results []WebServerConfigDTO
+
+ for name, conf := range confs {
+ results = append(results, WebServerConfigDTO{ServerName: name, WebserverConfiguration: conf})
+ }
+
+ sort.Slice(results, func(i, j int) bool {
+ return results[i].ServerName < results[j].ServerName
+ })
+
+ w.Header().Set("content-type", "application/json")
+ json.NewEncoder(w).Encode(results)
+}
+
+func (au *AdminUI) editWebserverConfig(w http.ResponseWriter, r *http.Request) {
+
+ var (
+ s WebServerConfigDTO
+ err error
+ )
+ defer func() { au.respond(err, w) }()
+
+ err = json.NewDecoder(r.Body).Decode(&s)
+ r.Body.Close()
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+
+ err = data.SetWebserverConfig(data.Webserver(s.ServerName), s.WebserverConfiguration)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+}
+
+func (au *AdminUI) getAcmeDetails(w http.ResponseWriter, _ *http.Request) {
+
+ var (
+ results AcmeDetailsResponseDTO
+ err error
+ )
+
+ cfToken, err := data.GetAcmeDNS01CloudflareToken()
+ results.CloudflareToken = (err == nil && cfToken.APIToken != "")
+
+ results.ProviderURL, _ = data.GetAcmeProvider()
+
+ results.Email, _ = data.GetAcmeEmail()
+
+ w.Header().Set("content-type", "application/json")
+ json.NewEncoder(w).Encode(results)
+}
+
+func (au *AdminUI) editAcmeEmail(w http.ResponseWriter, r *http.Request) {
+
+ var (
+ email StringDTO
+ err error
+ )
+
+ defer func() { au.respond(err, w) }()
+
+ err = json.NewDecoder(r.Body).Decode(&email)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+
+ err = data.SetAcmeEmail(email.Data)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+}
+
+func (au *AdminUI) editAcmeProvider(w http.ResponseWriter, r *http.Request) {
+
+ var (
+ provider StringDTO
+ err error
+ )
+
+ defer func() { au.respond(err, w) }()
+
+ err = json.NewDecoder(r.Body).Decode(&provider)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+
+ err = data.SetAcmeProvider(provider.Data)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+}
+
+func (au *AdminUI) editCloudflareApiToken(w http.ResponseWriter, r *http.Request) {
+
+ var (
+ token StringDTO
+ err error
+ )
+
+ defer func() { au.respond(err, w) }()
+
+ err = json.NewDecoder(r.Body).Decode(&token)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ return
+ }
+
+ err = data.SetAcmeDNS01CloudflareToken(token.Data)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+
+}
diff --git a/adminui/structs.go b/adminui/structs.go
index 66cef7a9..df09364c 100644
--- a/adminui/structs.go
+++ b/adminui/structs.go
@@ -198,3 +198,18 @@ type ConfigResponseDTO struct {
SSO bool `json:"sso"`
Password bool `json:"password"`
}
+
+type WebServerConfigDTO struct {
+ ServerName string `json:"server_name"`
+ data.WebserverConfiguration
+}
+
+type AcmeDetailsResponseDTO struct {
+ ProviderURL string `json:"provider_url"`
+ Email string `json:"email"`
+ CloudflareToken bool `json:"api_token_set"`
+}
+
+type StringDTO struct {
+ Data string `json:"data"`
+}
diff --git a/adminui/ui_webserver.go b/adminui/ui_webserver.go
index 7e1cc41a..6c34e1d2 100644
--- a/adminui/ui_webserver.go
+++ b/adminui/ui_webserver.go
@@ -253,6 +253,14 @@ func New(firewall *router.Firewall, errs chan<- error) (ui *AdminUI, err error)
protectedRoutes.HandleFunc("GET /api/settings/login", adminUI.getLoginSettings)
protectedRoutes.HandleFunc("GET /api/settings/all_mfa_methods", adminUI.getAllMfaMethods)
+ protectedRoutes.HandleFunc("GET /api/settings/webservers", adminUI.getAllWebserverConfigs)
+ protectedRoutes.HandleFunc("PUT /api/settings/webserver", adminUI.editWebserverConfig)
+
+ protectedRoutes.HandleFunc("GET /api/settings/acme", adminUI.getAcmeDetails)
+ protectedRoutes.HandleFunc("PUT /api/settings/acme/email", adminUI.editAcmeEmail)
+ protectedRoutes.HandleFunc("PUT /api/settings/acme/provider_url", adminUI.editAcmeProvider)
+ protectedRoutes.HandleFunc("PUT /api/settings/acme/cloudflare_api_key", adminUI.editCloudflareApiToken)
+
notifications := make(chan NotificationDTO, 1)
protectedRoutes.HandleFunc("GET /api/notifications", adminUI.notificationsWS(notifications))
data.RegisterEventListener(data.NodeErrors, true, adminUI.receiveErrorNotifications(notifications))
diff --git a/docker-test-config.json b/docker-test-config.json
index 40a000d8..9678dcd3 100644
--- a/docker-test-config.json
+++ b/docker-test-config.json
@@ -16,12 +16,10 @@
"ManagementUI": {
"ListenAddress": "127.0.0.1:4433",
"Enabled": true,
- "Debug": false,
"Password": {
"Enabled": true
},
"OIDC": {
- "AdminDomainURL": "",
"IssuerURL": "",
"ClientSecret": "",
"ClientID": "",
@@ -33,6 +31,7 @@
"ListenAddress": ":8081"
},
"Tunnel": {
+ "Domain": "http://vpn.test:8080",
"Port": "8080"
}
},
@@ -42,7 +41,6 @@
"Methods": [
"totp"
],
- "DomainURL": "https://vpn.test:8080",
"OIDC": {
"IssuerURL": "",
"ClientSecret": "",
diff --git a/internal/autotls/certmagic.go b/internal/autotls/certmagic.go
index 647c15e1..759afdb9 100644
--- a/internal/autotls/certmagic.go
+++ b/internal/autotls/certmagic.go
@@ -73,9 +73,7 @@ func Initialise() error {
}
}
- if provider != "" && email != "" {
- config.Issuers = []certmagic.Issuer{issuer}
- }
+ config.Issuers = []certmagic.Issuer{issuer}
ret := &AutoTLS{
Config: config,
@@ -150,11 +148,7 @@ func (a *AutoTLS) registerEventListeners() error {
}
}
- errs := []error{a.refreshListeners(data.Tunnel, nil, nil),
- a.refreshListeners(data.Management, nil, nil),
- a.refreshListeners(data.Public, nil, nil)}
-
- return errors.Join(errs...)
+ return nil
})
if err != nil {
return err
@@ -167,17 +161,7 @@ func (a *AutoTLS) registerEventListeners() error {
a.issuer.Email = ""
}
- if a.issuer.CA == "" || a.issuer.Email == "" {
- a.Config.Issuers = []certmagic.Issuer{}
- } else {
- a.Config.Issuers = []certmagic.Issuer{a.issuer}
- }
-
- errs := []error{a.refreshListeners(data.Tunnel, nil, nil),
- a.refreshListeners(data.Management, nil, nil),
- a.refreshListeners(data.Public, nil, nil)}
-
- return errors.Join(errs...)
+ return nil
})
if err != nil {
return err
@@ -190,17 +174,7 @@ func (a *AutoTLS) registerEventListeners() error {
a.issuer.CA = ""
}
- if a.issuer.CA == "" || a.issuer.Email == "" {
- a.Config.Issuers = []certmagic.Issuer{}
- } else {
- a.Config.Issuers = []certmagic.Issuer{a.issuer}
- }
-
- errs := []error{a.refreshListeners(data.Tunnel, nil, nil),
- a.refreshListeners(data.Management, nil, nil),
- a.refreshListeners(data.Public, nil, nil)}
-
- return errors.Join(errs...)
+ return nil
})
if err != nil {
return err
@@ -273,10 +247,6 @@ func (a *AutoTLS) refreshListeners(forWhat data.Webserver, mux http.Handler, det
// if we have no domain, or tls is explicitly disabled ( or acme provider hasnt been configured )
// open an http only port on whatever the listen address is
if w.details.Domain == "" || !w.details.TLS || len(a.Issuers) == 0 {
- httpListener, err := net.Listen("tcp", w.details.ListenAddress)
- if err != nil {
- return err
- }
httpServer := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
@@ -292,6 +262,11 @@ func (a *AutoTLS) refreshListeners(forWhat data.Webserver, mux http.Handler, det
}
w.listeners = []*http.Server{httpServer}
+ httpListener, err := net.Listen("tcp", w.details.ListenAddress)
+ if err != nil {
+ return err
+ }
+
go httpServer.Serve(httpListener)
} else {
err := a.Config.ManageSync(ctx, []string{w.details.Domain})
@@ -302,11 +277,6 @@ func (a *AutoTLS) refreshListeners(forWhat data.Webserver, mux http.Handler, det
tlsConfig := a.Config.TLSConfig()
tlsConfig.NextProtos = append([]string{"h2", "http/1.1"}, tlsConfig.NextProtos...)
- httpsLn, err := tls.Listen("tcp", fmt.Sprintf(w.details.ListenAddress), tlsConfig)
- if err != nil {
- return err
- }
-
httpsServer := &http.Server{
ReadHeaderTimeout: 10 * time.Second,
ReadTimeout: 30 * time.Second,
@@ -329,6 +299,10 @@ func (a *AutoTLS) refreshListeners(forWhat data.Webserver, mux http.Handler, det
w.listeners = append(w.listeners, httpsServer)
+ httpsLn, err := tls.Listen("tcp", fmt.Sprintf(w.details.ListenAddress), tlsConfig)
+ if err != nil {
+ return err
+ }
go httpsServer.Serve(httpsLn)
}
@@ -345,7 +319,9 @@ func (a *AutoTLS) autoRedirector(httpsServerListenAddr, domain string) (*http.Se
port = "443"
}
- httpRedirectListener, err := net.Listen("tcp", fmt.Sprintf("%s:80", host))
+ listenAddr := fmt.Sprintf("%s:80", host)
+
+ httpRedirectListener, err := net.Listen("tcp", listenAddr)
if err != nil {
return nil, err
}
diff --git a/internal/data/config.go b/internal/data/config.go
index d836423e..8b495461 100644
--- a/internal/data/config.go
+++ b/internal/data/config.go
@@ -79,6 +79,27 @@ type WebserverConfiguration struct {
TLS bool `json:"tls"`
}
+func GetAllWebserverConfigs() (details map[string]WebserverConfiguration, err error) {
+
+ details = make(map[string]WebserverConfiguration)
+ response, err := etcd.Get(context.Background(), WebServerConfigKey, clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortDescend))
+ if err != nil {
+ return nil, err
+ }
+
+ for _, res := range response.Kvs {
+ var conf WebserverConfiguration
+ err := json.Unmarshal(res.Value, &conf)
+ if err != nil {
+ return nil, err
+ }
+
+ details[strings.TrimPrefix(string(res.Key), WebServerConfigKey)] = conf
+ }
+
+ return details, nil
+}
+
func GetWebserverConfig(forWhat Webserver) (details WebserverConfiguration, err error) {
response, err := etcd.Get(context.Background(), WebServerConfigKey+string(forWhat))
@@ -96,6 +117,12 @@ func GetWebserverConfig(forWhat Webserver) (details WebserverConfiguration, err
func SetWebserverConfig(forWhat Webserver, details WebserverConfiguration) (err error) {
+ switch forWhat {
+ case Tunnel, Management, Public:
+ default:
+ return errors.New("unsupported webserver")
+ }
+
b, err := json.Marshal(details)
if err != nil {
return err
diff --git a/internal/data/events.go b/internal/data/events.go
index a958bfd5..c58f107b 100644
--- a/internal/data/events.go
+++ b/internal/data/events.go
@@ -164,14 +164,19 @@ func redact[T any](input T) (redacted []byte) {
}
values := reflect.ValueOf(current)
+ if values.Kind() == reflect.Pointer {
+ values = values.Elem()
+ }
if current.Kind() == reflect.Struct {
for i := 0; i < current.NumField(); i++ {
_, isSensitive := current.Field(i).Tag.Lookup("sensitive")
- if isSensitive && values.Field(i).CanSet() {
- values.Field(i).SetZero()
- } else {
- log.Println("cannot remove value for field, as cannot set")
+ if isSensitive {
+ if values.Field(i).CanSet() {
+ values.Field(i).SetZero()
+ } else {
+ log.Println("cannot remove value for field, as cannot set")
+ }
}
}
}
diff --git a/internal/data/init.go b/internal/data/init.go
index a511d5bf..7fb84bf9 100644
--- a/internal/data/init.go
+++ b/internal/data/init.go
@@ -324,7 +324,9 @@ func loadInitialSettings() error {
return err
}
- err = putIfNotFound(AcmeDNS01CloudflareAPIToken, config.Values.Acme.CloudflareDNSToken, "acme cloudflare dns api token")
+ var token CloudflareToken
+ token.APIToken = config.Values.Acme.CloudflareDNSToken
+ err = putIfNotFound(AcmeDNS01CloudflareAPIToken, token, "acme cloudflare dns api token")
if err != nil {
return err
}
diff --git a/internal/data/tls.go b/internal/data/tls.go
index 20247d94..e0eac78a 100644
--- a/internal/data/tls.go
+++ b/internal/data/tls.go
@@ -7,6 +7,7 @@ import (
"io/fs"
"path"
"strings"
+ "sync"
"github.com/caddyserver/certmagic"
clientv3 "go.etcd.io/etcd/client/v3"
@@ -39,6 +40,14 @@ func GetAcmeEmail() (string, error) {
return getString(AcmeEmailKey)
}
+func SetAcmeEmail(email string) error {
+ data, _ := json.Marshal(email)
+
+ _, err := etcd.Put(context.Background(), AcmeEmailKey, string(data))
+
+ return err
+}
+
func SetAcmeProvider(providerURL string) error {
if !strings.HasPrefix(providerURL, "https://") {
return errors.New("acme provider must start with https://")
@@ -56,11 +65,16 @@ func GetAcmeProvider() (string, error) {
type CertMagicStore struct {
basePath string
+
+ locks map[string]*concurrency.Mutex
+ mapMutex *sync.RWMutex
}
func NewCertStore(basePath string) *CertMagicStore {
return &CertMagicStore{
basePath: basePath,
+ locks: make(map[string]*concurrency.Mutex),
+ mapMutex: &sync.RWMutex{},
}
}
@@ -74,23 +88,56 @@ func (cms *CertMagicStore) Exists(ctx context.Context, key string) bool {
return res.Count > 1
}
+func (cms *CertMagicStore) lockPath(name string) string {
+ return path.Join(cms.basePath, "locks", certmagic.StorageKeys.Safe(name)+"-lock")
+}
+
func (cms *CertMagicStore) Lock(ctx context.Context, name string) error {
+
+ lockKey := cms.lockPath(name)
+ cms.mapMutex.RLock()
+ _, lockExists := cms.locks[lockKey]
+ cms.mapMutex.RUnlock()
+ if lockExists {
+ return nil
+ }
+
session, err := concurrency.NewSession(etcd, concurrency.WithContext(ctx))
if err != nil {
return err
}
- return concurrency.NewMutex(session, name).Lock(ctx)
+ mutex := concurrency.NewMutex(session, lockKey)
+ err = mutex.Lock(session.Client().Ctx())
+ if err != nil {
+ return err
+ }
+
+ cms.mapMutex.Lock()
+ cms.locks[lockKey] = mutex
+ cms.mapMutex.Unlock()
+
+ return nil
}
func (cms *CertMagicStore) Unlock(ctx context.Context, name string) error {
- session, err := concurrency.NewSession(etcd, concurrency.WithContext(ctx))
- if err != nil {
- return err
+
+ lockKey := cms.lockPath(name)
+
+ cms.mapMutex.RLock()
+ mutex, ok := cms.locks[lockKey]
+ cms.mapMutex.RUnlock()
+ if !ok {
+ return errors.New("mutex is not held")
}
- return concurrency.NewMutex(session, name).Unlock(ctx)
+ defer func() {
+ cms.mapMutex.Lock()
+ delete(cms.locks, lockKey)
+ cms.mapMutex.Unlock()
+ }()
+ return mutex.Unlock(ctx)
}
func (cms *CertMagicStore) Store(ctx context.Context, key string, value []byte) error {