diff --git a/internal/config/config.go b/internal/config/config.go index 64406fe3..db4d0a2f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,19 +48,6 @@ type Acls struct { Policies map[string]*acls.Acl } -func (a Acls) GetUserGroups(username string) (result []string) { - if values.Acls.rGroupLookup == nil { - return []string{} - } - - result = make([]string, 0, len(values.Acls.rGroupLookup[username])) - for group := range values.Acls.rGroupLookup[username] { - result = append(result, group) - } - - return -} - type Config struct { path string Socket string `json:",omitempty"` diff --git a/internal/data/config.go b/internal/data/config.go index dc782f5c..17945dfa 100644 --- a/internal/data/config.go +++ b/internal/data/config.go @@ -158,11 +158,11 @@ func GetAllSettings() (s Settings, err error) { } if response.Responses[6].GetResponseRange().Count == 1 { - s.HelpMail = string(response.Responses[6].GetResponseRange().Kvs[0].Value) + s.Issuer = string(response.Responses[6].GetResponseRange().Kvs[0].Value) } if response.Responses[7].GetResponseRange().Count == 1 { - s.ExternalAddress = string(response.Responses[7].GetResponseRange().Kvs[0].Value) + s.Domain = string(response.Responses[7].GetResponseRange().Kvs[0].Value) } return diff --git a/internal/data/devices.go b/internal/data/devices.go index d077063e..71f03cfa 100644 --- a/internal/data/devices.go +++ b/internal/data/devices.go @@ -6,12 +6,9 @@ import ( "errors" "fmt" "net" - "strconv" - "strings" "time" "github.com/NHAS/wag/internal/config" - "github.com/NHAS/wag/internal/utils" clientv3 "go.etcd.io/etcd/client/v3" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -28,25 +25,6 @@ type Device struct { Authorised time.Time } -func stringToUDPaddr(address string) (r *net.UDPAddr) { - parts := strings.Split(address, ":") - if len(parts) < 2 { - return nil - } - - port, err := strconv.Atoi(parts[len(parts)-1]) - if err != nil { - return nil - } - - r = &net.UDPAddr{ - IP: net.ParseIP(utils.GetIP(address)), - Port: port, - } - - return -} - func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { realKey, err := etcd.Get(context.Background(), "deviceref-"+address) @@ -58,22 +36,22 @@ func UpdateDeviceEndpoint(address string, endpoint *net.UDPAddr) error { return errors.New("device was not found") } - return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("user device has multiple keys") + return "", errors.New("user device has multiple keys") } var device Device err := json.Unmarshal(gr.Kvs[0].Value, &device) if err != nil { - return "", false, err + return "", err } device.Endpoint = endpoint b, _ := json.Marshal(device) - return string(b), false, err + return string(b), err }) } @@ -99,25 +77,25 @@ func GetDevice(username, id string) (device Device, err error) { // Set device as authorized and clear authentication attempts func AuthoriseDevice(username, address string) error { - return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("user device has multiple keys") + return "", errors.New("user device has multiple keys") } var device Device err := json.Unmarshal(gr.Kvs[0].Value, &device) if err != nil { - return "", false, err + return "", err } u, err := GetUserData(device.Username) if err != nil { // We may want to make this lock the device if the user is not found. At the moment settle with doing nothing - return "", false, err + return "", err } if u.Locked { - return "", false, errors.New("account is locked") + return "", errors.New("account is locked") } device.Authorised = time.Now() @@ -125,7 +103,7 @@ func AuthoriseDevice(username, address string) error { b, _ := json.Marshal(device) - return string(b), false, err + return string(b), err }) } @@ -140,42 +118,42 @@ func DeauthenticateDevice(address string) error { return errors.New("device was not found") } - return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), string(realKey.Kvs[0].Value), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("user device has multiple keys") + return "", errors.New("user device has multiple keys") } var device Device err := json.Unmarshal(gr.Kvs[0].Value, &device) if err != nil { - return "", false, err + return "", err } device.Authorised = time.Time{} b, _ := json.Marshal(device) - return string(b), false, err + return string(b), err }) } func SetDeviceAuthenticationAttempts(username, address string, attempts int) error { - return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("user device has multiple keys") + return "", errors.New("user device has multiple keys") } var device Device err := json.Unmarshal(gr.Kvs[0].Value, &device) if err != nil { - return "", false, err + return "", err } device.Attempts = attempts b, _ := json.Marshal(device) - return string(b), false, err + return string(b), err }) } @@ -326,22 +304,22 @@ func UpdateDevicePublicKey(username, address string, publicKey wgtypes.Key) erro return err } - err = doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, bool, error) { + err = doSafeUpdate(context.Background(), deviceKey(username, address), func(gr *clientv3.GetResponse) (string, error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("user device has multiple keys") + return "", errors.New("user device has multiple keys") } var device Device err := json.Unmarshal(gr.Kvs[0].Value, &device) if err != nil { - return "", false, err + return "", err } device.Publickey = publicKey.String() b, _ := json.Marshal(device) - return string(b), false, err + return string(b), err }) if err != nil { diff --git a/internal/data/groups.go b/internal/data/groups.go index 61597194..98b17a5f 100644 --- a/internal/data/groups.go +++ b/internal/data/groups.go @@ -9,6 +9,7 @@ import ( "github.com/NHAS/wag/pkg/control" clientv3 "go.etcd.io/etcd/client/v3" + "golang.org/x/exp/maps" ) func SetGroup(group string, members []string, overwrite bool) error { @@ -36,16 +37,16 @@ func SetGroup(group string, members []string, overwrite bool) error { } } - err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("bad number of membership keys") + return "", errors.New("bad number of membership keys") } var rGroupLookup map[string]map[string]bool err = json.Unmarshal(gr.Kvs[0].Value, &rGroupLookup) if err != nil { - return "", false, err + return "", err } for _, member := range oldMembers { @@ -62,7 +63,7 @@ func SetGroup(group string, members []string, overwrite bool) error { reverseMappingJson, _ := json.Marshal(rGroupLookup) - return string(reverseMappingJson), false, nil + return string(reverseMappingJson), nil }) return err @@ -111,16 +112,16 @@ func RemoveGroup(groupName string) error { } } - err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + err = doSafeUpdate(context.Background(), "wag-membership", func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("bad number of membership keys") + return "", errors.New("bad number of membership keys") } var rGroupLookup map[string]map[string]bool err = json.Unmarshal(gr.Kvs[0].Value, &rGroupLookup) if err != nil { - return "", false, err + return "", err } for _, member := range oldMembers { @@ -129,8 +130,29 @@ func RemoveGroup(groupName string) error { reverseMappingJson, _ := json.Marshal(rGroupLookup) - return string(reverseMappingJson), false, nil + return string(reverseMappingJson), nil }) return err } + +func GetUserGroupMembership(username string) ([]string, error) { + + response, err := etcd.Get(context.Background(), "wag-membership") + if err != nil { + return nil, err + } + + var rGroupLookup map[string]map[string]bool + + err = json.Unmarshal(response.Kvs[0].Value, &rGroupLookup) + if err != nil { + return nil, err + } + + if rGroupLookup[username] == nil { + return []string{}, nil + } + + return maps.Keys(rGroupLookup[username]), nil +} diff --git a/internal/data/init.go b/internal/data/init.go index 8ff49315..374b29a1 100644 --- a/internal/data/init.go +++ b/internal/data/init.go @@ -369,7 +369,7 @@ func TearDown() { } } -func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.GetResponse) (value string, onErrwrite bool, err error)) error { +func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.GetResponse) (value string, err error)) error { //https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/apiserver/pkg/storage/etcd3/store.go#L382 opts := []clientv3.OpOption{} @@ -387,8 +387,8 @@ func doSafeUpdate(ctx context.Context, key string, mutateFunc func(*clientv3.Get return errors.New("no record found") } - newValue, onErrwrite, err := mutateFunc(origState) - if err != nil && !onErrwrite { + newValue, err := mutateFunc(origState) + if err != nil { return err } diff --git a/internal/data/registration.go b/internal/data/registration.go index 0128b09e..43f2b992 100644 --- a/internal/data/registration.go +++ b/internal/data/registration.go @@ -78,12 +78,12 @@ func DeleteRegistrationToken(identifier string) error { func FinaliseRegistration(token string) error { errVal := errors.New("registration token has expired") - err := doSafeUpdate(context.Background(), "tokens-"+token, func(gr *clientv3.GetResponse) (string, bool, error) { + err := doSafeUpdate(context.Background(), "tokens-"+token, func(gr *clientv3.GetResponse) (string, error) { var result control.RegistrationResult err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.NumUses-- @@ -94,7 +94,7 @@ func FinaliseRegistration(token string) error { b, _ := json.Marshal(result) - return string(b), false, err + return string(b), err }) if err == errVal { diff --git a/internal/data/sql_compat.go b/internal/data/sql_compat.go index 60d7b2fa..0e21aaa3 100644 --- a/internal/data/sql_compat.go +++ b/internal/data/sql_compat.go @@ -3,7 +3,11 @@ package data import ( "database/sql" "encoding/json" + "net" + "strconv" + "strings" + "github.com/NHAS/wag/internal/utils" "github.com/NHAS/wag/pkg/control" ) @@ -119,3 +123,22 @@ func sqlGetAllDevices() (devices []Device, err error) { return devices, nil } + +func stringToUDPaddr(address string) (r *net.UDPAddr) { + parts := strings.Split(address, ":") + if len(parts) < 2 { + return nil + } + + port, err := strconv.Atoi(parts[len(parts)-1]) + if err != nil { + return nil + } + + r = &net.UDPAddr{ + IP: net.ParseIP(utils.GetIP(address)), + Port: port, + } + + return +} diff --git a/internal/data/ui.go b/internal/data/ui.go index f648cc03..7f2cdd14 100644 --- a/internal/data/ui.go +++ b/internal/data/ui.go @@ -41,6 +41,35 @@ func generateSalt() ([]byte, error) { return randomData, nil } +func IncrementAdminAuthenticationAttempt(username string) error { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, err error) { + + if len(gr.Kvs) != 1 { + return "", errors.New("invalid number of admin keys") + } + + var admin admin + err = json.Unmarshal(gr.Kvs[0].Value, &admin) + if err != nil { + return "", err + } + + l, err := GetLockout() + if err != nil { + return "", err + } + + if admin.Attempts < l { + admin.Attempts++ + } + + b, _ := json.Marshal(admin) + + return string(b), nil + + }) +} + func CreateAdminUser(username, password string, changeOnFirstUse bool) error { if len(password) < minPasswordLength { return fmt.Errorf("password is too short for administrative console (must be greater than %d characters)", minPasswordLength) @@ -80,39 +109,38 @@ func CompareAdminKeys(username, password string) error { subtle.ConstantTimeCompare(hash, hash) } - err := doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, bool, error) { + err := doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, error) { var result admin err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } - if result.Attempts >= 5 { + lockout, err := GetLockout() + if err != nil { + return "", err + } + if result.Attempts >= lockout { wasteTime() - return "", false, errors.New("account locked") + return "", errors.New("account locked") } rawHashSalt, err := base64.RawStdEncoding.DecodeString(result.Hash) if err != nil { - return "", false, err + return "", err } thisHash := argon2.IDKey([]byte(password), rawHashSalt[len(rawHashSalt)-16:], 1, 10*1024, 4, 32) if subtle.ConstantTimeCompare(thisHash, rawHashSalt[:len(rawHashSalt)-16]) != 1 { - result.Attempts++ - - b, _ := json.Marshal(result) - - // For this specific error we need to write the attempts to the entry - return string(b), true, errors.New("passwords did not match") + return "", errors.New("passwords did not match") } result.Attempts = 0 b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) return err @@ -121,19 +149,20 @@ func CompareAdminKeys(username, password string) error { // Lock admin account and make them unable to login func SetAdminUserLock(username string) error { - return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, error) { var result admin err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } - result.Attempts = 6 - - result.Attempts = 0 + result.Attempts, err = GetLockout() + if err != nil { + return "", err + } b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) } @@ -141,19 +170,17 @@ func SetAdminUserLock(username string) error { // Unlock admin account func SetAdminUserUnlock(username string) error { - return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (string, error) { var result admin err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } - result.Attempts = 0 - result.Attempts = 0 b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) } @@ -215,16 +242,16 @@ func SetAdminPassword(username, password string) error { hash := argon2.IDKey([]byte(password), salt, 1, 10*1024, 4, 32) - return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("invalid number of admin users") + return "", errors.New("invalid number of admin users") } var admin admin err = json.Unmarshal(gr.Kvs[0].Value, &admin) if err != nil { - return "", false, err + return "", err } admin.Change = false @@ -232,23 +259,23 @@ func SetAdminPassword(username, password string) error { b, _ := json.Marshal(admin) - return string(b), false, nil + return string(b), nil }) } func setAdminHash(username, hash string) error { - return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("invalid number of admin users") + return "", errors.New("invalid number of admin users") } var admin admin err = json.Unmarshal(gr.Kvs[0].Value, &admin) if err != nil { - return "", false, err + return "", err } admin.Change = false @@ -256,22 +283,22 @@ func setAdminHash(username, hash string) error { b, _ := json.Marshal(admin) - return string(b), false, nil + return string(b), nil }) } func SetLastLoginInformation(username, ip string) error { - return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + return doSafeUpdate(context.Background(), "admin-users-"+username, func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("invalid number of admin users") + return "", errors.New("invalid number of admin users") } var admin admin err = json.Unmarshal(gr.Kvs[0].Value, &admin) if err != nil { - return "", false, err + return "", err } admin.LastLogin = time.Now().Format(time.RFC3339) @@ -279,7 +306,7 @@ func SetLastLoginInformation(username, ip string) error { b, _ := json.Marshal(admin) - return string(b), false, nil + return string(b), nil }) diff --git a/internal/data/user.go b/internal/data/user.go index d71d48c1..6d8b19fc 100644 --- a/internal/data/user.go +++ b/internal/data/user.go @@ -24,21 +24,21 @@ func (um *UserModel) GetID() [20]byte { // Make sure that the attempts is always incremented first to stop race condition attacks func IncrementAuthenticationAttempt(username, device string) error { - return doSafeUpdate(context.Background(), deviceKey(username, device), func(gr *clientv3.GetResponse) (value string, onErrwrite bool, err error) { + return doSafeUpdate(context.Background(), deviceKey(username, device), func(gr *clientv3.GetResponse) (value string, err error) { if len(gr.Kvs) != 1 { - return "", false, errors.New("invalid number of users") + return "", errors.New("invalid number of users") } var userDevice Device err = json.Unmarshal(gr.Kvs[0].Value, &userDevice) if err != nil { - return "", false, err + return "", err } l, err := GetLockout() if err != nil { - return "", false, err + return "", err } if userDevice.Attempts < l { @@ -47,7 +47,7 @@ func IncrementAuthenticationAttempt(username, device string) error { b, _ := json.Marshal(userDevice) - return string(b), false, nil + return string(b), nil }) } @@ -97,18 +97,18 @@ func GetAuthenticationDetails(username, device string) (mfa, mfaType string, att // Disable authentication for user func SetUserLock(username string) error { - err := doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, bool, error) { + err := doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, error) { var result UserModel err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.Locked = true b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) if err != nil { @@ -119,18 +119,18 @@ func SetUserLock(username string) error { } func SetUserUnlock(username string) error { - err := doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, bool, error) { + err := doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, error) { var result UserModel err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.Locked = false b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) if err != nil { @@ -164,36 +164,36 @@ func IsEnforcingMFA(username string) bool { // Stop displaying MFA secrets for user func SetEnforceMFAOn(username string) error { - return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, error) { var result UserModel err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.Enforcing = true b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) } func SetEnforceMFAOff(username string) error { - return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, error) { var result UserModel err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.Enforcing = false b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) } @@ -302,11 +302,11 @@ func GetUserDataFromAddress(address string) (u UserModel, err error) { func SetUserMfa(username, value, mfaType string) error { - return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, bool, error) { + return doSafeUpdate(context.Background(), "users-"+username+"-", func(gr *clientv3.GetResponse) (string, error) { var result UserModel err := json.Unmarshal(gr.Kvs[0].Value, &result) if err != nil { - return "", false, err + return "", err } result.Mfa = value @@ -314,7 +314,7 @@ func SetUserMfa(username, value, mfaType string) error { b, _ := json.Marshal(result) - return string(b), false, nil + return string(b), nil }) } diff --git a/internal/router/bpf.go b/internal/router/bpf.go index 63c0b74a..28d5f85a 100644 --- a/internal/router/bpf.go +++ b/internal/router/bpf.go @@ -300,6 +300,7 @@ func SetLockAccount(username string, locked uint32) error { // Takes the LPM table and associates a route to a policy func xdpAddRoute(usersRouteTable *ebpf.Map, userAcls acls.Acl) error { + rules, err := routetypes.ParseRules(userAcls.Mfa, userAcls.Allow, userAcls.Deny) if err != nil { return err @@ -359,6 +360,7 @@ func AddUser(username string, acls acls.Acl) error { func setSingleUserMap(userid [20]byte, acls acls.Acl) error { // Adds LPM trie to existing map (hashmap to map) + // Or if we have an existing map, update it if _, ok := userPolicyMaps[userid]; !ok { policiesInnerTable, err := ebpf.NewMap(routesMapSpec) @@ -366,7 +368,7 @@ func setSingleUserMap(userid [20]byte, acls acls.Acl) error { return fmt.Errorf("%s creating new map: %s", xdpObjects.PoliciesTable.String(), err) } - err = xdpObjects.PoliciesTable.Put(userid, uint32(policiesInnerTable.FD())) + err = xdpObjects.PoliciesTable.Update(userid, uint32(policiesInnerTable.FD()), ebpf.UpdateNoExist) if err != nil { return fmt.Errorf("%s adding new map to table: %s", xdpObjects.PoliciesTable.String(), err) } @@ -374,13 +376,48 @@ func setSingleUserMap(userid [20]byte, acls acls.Acl) error { userPolicyMaps[userid] = policiesInnerTable } - if err := xdpAddRoute(userPolicyMaps[userid], acls); err != nil { + mapRef := userPolicyMaps[userid] + if err := clearPolicyMap(mapRef); err != nil { + return err + } + + if err := xdpAddRoute(mapRef, acls); err != nil { return err } return nil } +func clearPolicyMap(toClear *ebpf.Map) error { + var ( + lastKey []byte + err error + ) + + // Due to type inference we cant just set lastKey to nil to get the first key + lastKey, err = toClear.NextKeyBytes(nil) + if err != nil { + return err + } + + for { + + if lastKey == nil { + return nil + } + + err = toClear.Delete(lastKey) + if err != nil && err != ebpf.ErrKeyNotExist { + return err + } + + lastKey, err = toClear.NextKeyBytes(lastKey) + if err != nil { + return err + } + } +} + // I've tried my hardest not to make this stateful. But alas we must cache the user policy maps or things become unreasonbly slow // If someone has a better way of doing this. Please for the love of god pipe up // https://github.com/cilium/ebpf/discussions/1297 @@ -394,8 +431,6 @@ func bulkCreateUserMaps(users []data.UserModel) []error { maps = map[string]*ebpf.Map{} ) - x := 0 - for _, user := range users { userid := sha1.Sum([]byte(user.Username)) @@ -403,13 +438,15 @@ func bulkCreateUserMaps(users []data.UserModel) []error { // This speeds up things like refresh acls, but not wag start up if policiesInnerTable, ok := userPolicyMaps[userid]; ok { + if err := clearPolicyMap(policiesInnerTable); err != nil { + errors = append(errors, err) + continue + } + err := xdpAddRoute(policiesInnerTable, data.GetEffectiveAcl(user.Username)) if err != nil { errors = append(errors, err) - } else { - x++ - continue } } @@ -437,7 +474,7 @@ func bulkCreateUserMaps(users []data.UserModel) []error { } n, err := xdpObjects.PoliciesTable.BatchUpdate(keys, values, &ebpf.BatchOptions{ - Flags: uint64(ebpf.UpdateAny), + Flags: uint64(ebpf.UpdateNoExist), }) if err != nil { @@ -448,6 +485,7 @@ func bulkCreateUserMaps(users []data.UserModel) []error { return []error{fmt.Errorf("batch update could not write all keys to map: expected %d got %d", len(keys), n)} } + // As we created maps for this, we dont need to clear things for username, m := range maps { err := xdpAddRoute(m, data.GetEffectiveAcl(username)) if err != nil { diff --git a/ui/ui_webserver.go b/ui/ui_webserver.go index b0a04849..6583d7b5 100644 --- a/ui/ui_webserver.go +++ b/ui/ui_webserver.go @@ -91,6 +91,14 @@ func doLogin(w http.ResponseWriter, r *http.Request) { return } + err = data.IncrementAdminAuthenticationAttempt(r.Form.Get("username")) + if err != nil { + log.Println("admin login failed for user", r.Form.Get("username"), ": ", err) + + render(w, r, Login{ErrorMessage: "Unable to login"}, "templates/login.html") + return + } + err = data.CompareAdminKeys(r.Form.Get("username"), r.Form.Get("password")) if err != nil { log.Println("admin login failed for user", r.Form.Get("username"), ": ", err) @@ -199,6 +207,15 @@ func populateDashboard(w http.ResponseWriter, r *http.Request) { return } + s, err := data.GetAllSettings() + if err != nil { + log.Println("error getting server settings: ", err) + + w.WriteHeader(http.StatusInternalServerError) + renderDefaults(w, r, nil, "error.html") + return + } + d := Dashboard{ Page: Page{ Update: getUpdate(), @@ -210,7 +227,7 @@ func populateDashboard(w http.ResponseWriter, r *http.Request) { Port: port, PublicKey: pubkey.String(), - ExternalAddress: config.Values().ExternalAddress, + ExternalAddress: s.ExternalAddress, Subnet: config.Values().Wireguard.Range.String(), NumUsers: len(allUsers), @@ -1197,14 +1214,19 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { return } - data := []UsersData{} + usersData := []UsersData{} for _, u := range users { devices, _ := ctrl.ListDevice(u.Username) - groups := append([]string{"*"}, config.Values().Acls.GetUserGroups(u.Username)...) + groups, err := data.GetUserGroupMembership(u.Username) + if err != nil { + log.Println("unable to get users groups: ", err) + http.Error(w, "Server error", 500) + return + } - data = append(data, UsersData{ + usersData = append(usersData, UsersData{ Username: u.Username, Locked: u.Locked, Devices: len(devices), @@ -1213,7 +1235,7 @@ func manageUsers(w http.ResponseWriter, r *http.Request) { }) } - b, err := json.Marshal(data) + b, err := json.Marshal(usersData) if err != nil { log.Println("unable to marshal users data: ", err) http.Error(w, "Server error", 500)