Skip to content

Commit

Permalink
First working and clustering version
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Jan 24, 2024
1 parent 41e035e commit f863db9
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 67 deletions.
5 changes: 5 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type Config struct {
DatabaseLocation string
ETCDLogLevel string
Witness bool
ClusterState string
}

Authenticators struct {
Expand Down Expand Up @@ -223,6 +224,10 @@ func load(path string) (c Config, err error) {
c.Clustering.ListenAddresses = []string{"http://localhost:2380"}
}

if c.Clustering.ClusterState == "" {
c.Clustering.ClusterState = "new"
}

if c.NAT == nil {
c.NAT = new(bool)
*c.NAT = true
Expand Down
23 changes: 23 additions & 0 deletions internal/data/clustering.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package data

import "go.etcd.io/etcd/server/v3/etcdserver/api/membership"

func GetServerID() string {
return etcdServer.Server.ID().String()
}

func HasLeader() bool {
return etcdServer.Server.Leader() != 0
}

func IsLearner() bool {
return etcdServer.Server.IsLearner()
}

func IsLeader() bool {
return etcdServer.Server.Leader() == etcdServer.Server.ID()
}

func GetMembers() []*membership.Member {
return etcdServer.Server.Cluster().Members()
}
14 changes: 7 additions & 7 deletions internal/data/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,19 @@ func checkClusterHealth() {

select {
case <-etcdServer.Server.LeaderChangedNotify():
notfyHealthy()

execWatchers(clusterHealthWatchers, "changed", 0)
case <-time.After(1 * time.Second):
leader := etcdServer.Server.Leader()
if leader == 0 {
execWatchers(clusterHealthWatchers, "electing", 0)
<-time.After(etcdServer.Server.Cfg.ElectionTimeout() * 2)
leader = etcdServer.Server.Leader()
}

if leader != 0 {
notfyHealthy()
} else {
execWatchers(clusterHealthWatchers, "dead", 0)
if leader == 0 {
execWatchers(clusterHealthWatchers, "dead", 0)
} else {
notfyHealthy()
}
}

}
Expand Down
14 changes: 9 additions & 5 deletions internal/data/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log"
"net/url"
"os"
"path/filepath"
"regexp"
"strconv"
Expand Down Expand Up @@ -49,12 +50,13 @@ func parseUrls(values ...string) []url.URL {
func Load(path string) error {

doMigration := true
db, err := sql.Open("sqlite3", path)
if err != nil {
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
doMigration = false
}

if err == nil {
var db *sql.DB
if doMigration {
db, _ = sql.Open("sqlite3", path)

defer db.Close()

Expand All @@ -77,7 +79,6 @@ func Load(path string) error {
if err != nil {
return err
}

}

part, err := generateRandomBytes(10)
Expand All @@ -88,6 +89,7 @@ func Load(path string) error {

cfg := embed.NewConfig()
cfg.Name = config.Values.Clustering.Name
cfg.ClusterState = config.Values.Clustering.ClusterState
cfg.InitialClusterToken = "wag-test"
cfg.LogLevel = config.Values.Clustering.ETCDLogLevel
cfg.ListenPeerUrls = parseUrls(config.Values.Clustering.ListenAddresses...)
Expand All @@ -103,9 +105,11 @@ func Load(path string) error {

cfg.InitialCluster = ""
for tag, addresses := range peers {
cfg.InitialCluster += fmt.Sprintf("%s=%s", tag, strings.Join(addresses, ","))
cfg.InitialCluster += fmt.Sprintf("%s=%s", tag, strings.Join(addresses, ",")) + ","
}

cfg.InitialCluster = cfg.InitialCluster[:len(cfg.InitialCluster)-1]

cfg.Dir = filepath.Join(config.Values.Clustering.DatabaseLocation, config.Values.Clustering.Name+".wag-node.etcd")
etcdServer, err = embed.StartEtcd(cfg)
if err != nil {
Expand Down
102 changes: 58 additions & 44 deletions internal/router/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ import (
"golang.org/x/sys/unix"
)

var lock sync.RWMutex
var (
lock sync.RWMutex
cancel = make(chan bool)
)

func Setup(errorChan chan<- error, iptables bool) (err error) {

Expand Down Expand Up @@ -64,52 +67,56 @@ func Setup(errorChan chan<- error, iptables bool) (err error) {

for {

dev, err := ctrl.Device(config.Values.Wireguard.DevName)
if err != nil {
errorChan <- fmt.Errorf("endpoint watcher: %s", err)
select {
case <-cancel:
return
}
case <-time.After(100 * time.Millisecond):
dev, err := ctrl.Device(config.Values.Wireguard.DevName)
if err != nil {
errorChan <- fmt.Errorf("endpoint watcher: %s", err)
return
}

for _, p := range dev.Peers {
for _, p := range dev.Peers {

if len(p.AllowedIPs) != 1 {
log.Println("Warning, peer ", p.PublicKey.String(), " len(p.AllowedIPs) != 1, which is not supported")
continue
}
if len(p.AllowedIPs) != 1 {
log.Println("Warning, peer ", p.PublicKey.String(), " len(p.AllowedIPs) != 1, which is not supported")
continue
}

ip := p.AllowedIPs[0].IP.String()
ip := p.AllowedIPs[0].IP.String()

if cache[ip] != p.Endpoint.String() {
cache[ip] = p.Endpoint.String()
if cache[ip] != p.Endpoint.String() {
cache[ip] = p.Endpoint.String()

d, err := data.GetDeviceByAddress(ip)
if err != nil {
log.Println("unable to get previous device endpoint for ", ip, err)
if err := Deauthenticate(ip); err != nil {
log.Println(ip, "unable to remove forwards for device: ", err)
d, err := data.GetDeviceByAddress(ip)
if err != nil {
log.Println("unable to get previous device endpoint for ", ip, err)
if err := Deauthenticate(ip); err != nil {
log.Println(ip, "unable to remove forwards for device: ", err)
}
continue
}
continue
}

err = data.UpdateDeviceEndpoint(p.AllowedIPs[0].IP.String(), p.Endpoint)
if err != nil {
log.Println(ip, "unable to update device endpoint: ", err)
}
err = data.UpdateDeviceEndpoint(p.AllowedIPs[0].IP.String(), p.Endpoint)
if err != nil {
log.Println(ip, "unable to update device endpoint: ", err)
}

//Dont try and remove rules, if we've just started
if !startup {
log.Println(ip, "endpoint changed", d.Endpoint.String(), "->", p.Endpoint.String())
if err := Deauthenticate(ip); err != nil {
log.Println(ip, "unable to remove forwards for device: ", err)
//Dont try and remove rules, if we've just started
if !startup {
log.Println(ip, "endpoint changed", d.Endpoint.String(), "->", p.Endpoint.String())
if err := Deauthenticate(ip); err != nil {
log.Println(ip, "unable to remove forwards for device: ", err)
}
}
}

}

startup = false
}

startup = false

time.Sleep(100 * time.Millisecond)
}
}()

Expand All @@ -133,6 +140,24 @@ func Setup(errorChan chan<- error, iptables bool) (err error) {

func TearDown() {

cancel <- true

log.Println("Removing wireguard device")
conn, err := netlink.Dial(unix.NETLINK_ROUTE, nil)
if err != nil {
log.Println("Unable to remove wireguard device, netlink connection failed: ", err.Error())
return
}
defer conn.Close()

err = delWg(conn, config.Values.Wireguard.DevName)
if err != nil {
log.Println("Unable to remove wireguard device, delete failed: ", err.Error())
return
}

log.Println("Wireguard device removed")

log.Println("Removing Firewall rules...")

ipt, err := iptables.New()
Expand Down Expand Up @@ -200,17 +225,6 @@ func TearDown() {
log.Println("Unable to clean up firewall rules: ", err)
}

conn, err := netlink.Dial(unix.NETLINK_ROUTE, nil)
if err != nil {
log.Println("Unable to remove wireguard device, netlink connection failed: ", err.Error())
return
}
defer conn.Close()

err = delWg(conn, config.Values.Wireguard.DevName)
if err != nil {
log.Println("Unable to remove wireguard device, delete failed: ", err.Error())
return
}
log.Println("Firewall rules removed.")

}
28 changes: 17 additions & 11 deletions internal/router/statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) {
}

if device.CurrentValue.Authorised != device.Previous.Authorised {
log.Println("authorisation state changed on device")

if !device.CurrentValue.Authorised.IsZero() && device.CurrentValue.Attempts <= lockout {

log.Println("authorising device")
err := SetAuthorized(device.CurrentValue.Address, device.CurrentValue.Username)
if err != nil {
log.Println(err)
Expand Down Expand Up @@ -142,17 +138,27 @@ func groupChanges(groupChange data.TargettedEvent[[]string], state int) {

func clusterState(errorsChan chan<- error) data.ClusterHealthFunc {

hasDied := false
return func(stateText string, state int) {
log.Println("entered state: ", stateText)

switch stateText {
case "dead":
log.Println("Cluster has entered dead state, tearing down")
TearDown()
if !hasDied {
hasDied = true
log.Println("Cluster has entered dead state, tearing down: ", hasDied)
TearDown()
}
case "healthy":
err := Setup(errorsChan, true)
if err != nil {
errorsChan <- err
log.Println("was unable to return wag member to healthy state, dying: ", err)
return
if hasDied {
err := Setup(errorsChan, true)
if err != nil {
log.Println("was unable to return wag member to healthy state, dying: ", err)
errorsChan <- err
return
}

hasDied = false
}
}
}
Expand Down

0 comments on commit f863db9

Please sign in to comment.