Skip to content

Commit

Permalink
Make Server awaitable and cancellable (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevapple authored Aug 17, 2024
1 parent 9884483 commit ec668a7
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 36 deletions.
18 changes: 13 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,34 @@ import (
"os"
)

func sshmuxServer(configFile string) {
func sshmuxServer(configFile string) (*Server, error) {
var config Config
configFileBytes, err := os.ReadFile(configFile)
if err != nil {
log.Fatal(err)
return nil, err
}
err = json.Unmarshal(configFileBytes, &config)
if err != nil {
log.Fatal(err)
}
sshmux, err := makeServer(config)
if err != nil {
log.Fatal(err)
return nil, err
}
sshmux.ListenAddr(config.Address)
return sshmux, nil
}

func main() {
var configFile string
flag.StringVar(&configFile, "c", "/etc/sshmux/config.json", "config file")
flag.Parse()
sshmuxServer(configFile)
sshmux, err := sshmuxServer(configFile)
if err != nil {
log.Fatal(err)
}
err = sshmux.Start()
if err != nil {
log.Fatal(err)
}
sshmux.Wait()
}
112 changes: 82 additions & 30 deletions sshmux.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
package main

import (
"context"
"fmt"
"log"
"net"
"net/netip"
"os"
"slices"
"sync"
"time"

"github.com/pires/go-proxyproto"
"golang.org/x/crypto/ssh"
)

type Server struct {
listener net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
Address string
Banner string
SSHConfig *ssh.ServerConfig
ProxyUpstreams []netip.Prefix
Expand Down Expand Up @@ -48,6 +55,7 @@ func makeServer(config Config) (*Server, error) {
proxyUpstreams = append(proxyUpstreams, network)
}
sshmux := &Server{
Address: config.Address,
Banner: config.Banner,
SSHConfig: sshConfig,
ProxyUpstreams: proxyUpstreams,
Expand All @@ -65,6 +73,57 @@ func makeServer(config Config) (*Server, error) {
return sshmux, nil
}

func (s *Server) serve() {
defer s.wg.Done()
for {
select {
case <-s.ctx.Done():
return
default:
conn, err := s.listener.Accept()
if err != nil {
if s.ctx.Err() != nil {
// Context cancelled, stop accepting connections
return
}
log.Printf("Error on Accept: %s\n", err)
continue
}
s.wg.Add(1)
go s.handler(conn)
}
}
}

func (s *Server) handler(conn net.Conn) {
defer s.wg.Done()
defer conn.Close()

session, err := ssh.NewPipeSession(conn, s.SSHConfig)
if err != nil {
return
}
defer session.Close()

logMessage := LogMessage{
ConnectTime: time.Now().Unix(),
ClientIp: conn.RemoteAddr().String(),
Username: "", // should be provided by API server
ClientType: "SSH",
Authenticated: true,
}
defer s.Logger.SendLog(&logMessage)

select {
case <-s.ctx.Done():
return
default:
if err := s.RunPipeSession(session, &logMessage); err != nil {
log.Println("runPipeSession:", err)
}
}
}

func (s *Server) Handshake(session *ssh.PipeSession) error {
hasSetUser := false
var user string
Expand Down Expand Up @@ -223,11 +282,11 @@ func (s *Server) Handshake(session *ssh.PipeSession) error {
}
}

func (s *Server) ListenAddr(address string) error {
func (s *Server) Start() error {
// set up TCP listener
listener, err := net.Listen("tcp", address)
listener, err := net.Listen("tcp", s.Address)
if err != nil {
log.Fatal(err)
return err
}
if len(s.ProxyUpstreams) > 0 {
listener = &proxyproto.Listener{
Expand All @@ -250,36 +309,29 @@ func (s *Server) ListenAddr(address string) error {
},
}
}
defer listener.Close()

// set up server context
s.ctx, s.cancel = context.WithCancel(context.Background())
s.listener = listener
s.wg.Add(1)

// main handler loop
for {
conn, err := listener.Accept()
if err != nil {
log.Printf("Error on Accept: %s\n", err)
continue
}
go func() {
session, err := ssh.NewPipeSession(conn, s.SSHConfig)
logMessage := LogMessage{
ConnectTime: time.Now().Unix(),
ClientIp: conn.RemoteAddr().String(),
Username: "", // should be provided by API server
ClientType: "SSH",
Authenticated: true,
}
if err != nil {
return
}
defer func() {
session.Close()
s.Logger.SendLog(&logMessage)
}()
if err := s.RunPipeSession(session, &logMessage); err != nil {
log.Println("runPipeSession:", err)
}
}()
go s.serve()
return nil
}

func (s *Server) Wait() {
s.wg.Wait()
}

func (s *Server) Shutdown() {
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
s.listener.Close()
}
s.wg.Wait()
}

func (s *Server) RunPipeSession(session *ssh.PipeSession, logMessage *LogMessage) error {
Expand Down
12 changes: 11 additions & 1 deletion sshmux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,17 @@ func TestSSHClientConnection(t *testing.T) {

initEnv(t, baseDir)
privateKeyPath := filepath.Join(baseDir, "example_rsa")
go sshmuxServer("etc/config.example.json")

// start sshmux server
sshmux, err := sshmuxServer("etc/config.example.json")
if err != nil {
t.Fatal(err)
}
err = sshmux.Start()
if err != nil {
t.Fatal(err)
}
defer sshmux.Shutdown()

// sanity check
testWithSSHClient(t, sshdServerAddr, "sanity check", false, baseDir, privateKeyPath)
Expand Down

0 comments on commit ec668a7

Please sign in to comment.