diff --git a/internal/cmd/bruteforce.go b/internal/cmd/bruteforce.go index 066c5209..7d776e24 100644 --- a/internal/cmd/bruteforce.go +++ b/internal/cmd/bruteforce.go @@ -2,6 +2,7 @@ package cmd import ( "bufio" + "context" "fmt" "os" "strings" @@ -43,7 +44,7 @@ func (opts BruteforceOpts) Validate() error { return nil } -func BruteForce(opts BruteforceOpts) error { +func BruteForce(ctx context.Context, opts BruteforceOpts) error { if err := opts.Validate(); err != nil { return err } @@ -56,7 +57,7 @@ func BruteForce(opts BruteforceOpts) error { scanner := bufio.NewScanner(pfile) for scanner.Scan() { - if err := testPassword(opts, scanner.Text()); err != nil { + if err := testPassword(ctx, opts, scanner.Text()); err != nil { return err } } @@ -67,15 +68,15 @@ func BruteForce(opts BruteforceOpts) error { return nil } -func testPassword(opts BruteforceOpts, password string) error { - remote, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testPassword(ctx context.Context, opts BruteforceOpts, password string) error { + remote, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return err } addressFamily := internal.AllocateProtocolIgnore allocateRequest := internal.AllocateRequest(internal.RequestedTransportUDP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending AllocateRequest: %w", err) } @@ -87,7 +88,7 @@ func testPassword(opts BruteforceOpts, password string) error { nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, password, nonce, realm, internal.RequestedTransportUDP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending AllocateRequest Auth: %w", err) } diff --git a/internal/cmd/brutetransports.go b/internal/cmd/brutetransports.go index 600459b8..24b1d522 100644 --- a/internal/cmd/brutetransports.go +++ b/internal/cmd/brutetransports.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "strings" "time" @@ -42,20 +43,20 @@ func (opts BruteTransportOpts) Validate() error { return nil } -func BruteTransports(opts BruteTransportOpts) error { +func BruteTransports(ctx context.Context, opts BruteTransportOpts) error { if err := opts.Validate(); err != nil { return err } for i := 0; i <= 255; i++ { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return err } x := internal.RequestedTransport(uint32(i)) allocateRequest := internal.AllocateRequest(x, internal.AllocateProtocolIgnore) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return fmt.Errorf("error on sending allocate request: %w", err) } @@ -64,7 +65,7 @@ func BruteTransports(opts BruteTransportOpts) error { nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, opts.Password, nonce, realm, x, internal.AllocateProtocolIgnore) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return fmt.Errorf("error on sending allocate request auth: %w", err) } diff --git a/internal/cmd/info.go b/internal/cmd/info.go index cc867a4d..0470409b 100644 --- a/internal/cmd/info.go +++ b/internal/cmd/info.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "strings" "time" @@ -35,12 +36,12 @@ func (opts InfoOpts) Validate() error { return nil } -func Info(opts InfoOpts) error { +func Info(ctx context.Context, opts InfoOpts) error { if err := opts.Validate(); err != nil { return err } - if attr, err := testStun(opts); err != nil { + if attr, err := testStun(ctx, opts); err != nil { opts.Log.Debugf("STUN error: %v", err) opts.Log.Error("this server does not support the STUN protocol") } else { @@ -48,7 +49,7 @@ func Info(opts InfoOpts) error { printAttributes(opts, attr) } - if attr, err := testTurn(opts, internal.RequestedTransportUDP); err != nil { + if attr, err := testTurn(ctx, opts, internal.RequestedTransportUDP); err != nil { opts.Log.Debugf("TURN UDP error: %v", err) opts.Log.Error("this server does not support the TURN UDP protocol") } else { @@ -56,7 +57,7 @@ func Info(opts InfoOpts) error { printAttributes(opts, attr) } - if attr, err := testTurn(opts, internal.RequestedTransportTCP); err != nil { + if attr, err := testTurn(ctx, opts, internal.RequestedTransportTCP); err != nil { opts.Log.Debugf("TURN TCP error: %v", err) opts.Log.Error("this server does not support the TURN TCP protocol") } else { @@ -67,15 +68,15 @@ func Info(opts InfoOpts) error { return nil } -func testStun(opts InfoOpts) ([]internal.Attribute, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testStun(ctx context.Context, opts InfoOpts) ([]internal.Attribute, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return nil, err } defer conn.Close() bindingRequest := internal.BindingRequest() - bindingResponse, err := bindingRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + bindingResponse, err := bindingRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return nil, fmt.Errorf("error on sending binding request: %w", err) } @@ -86,15 +87,15 @@ func testStun(opts InfoOpts) ([]internal.Attribute, error) { return bindingResponse.Attributes, nil } -func testTurn(opts InfoOpts, proto internal.RequestedTransport) ([]internal.Attribute, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func testTurn(ctx context.Context, opts InfoOpts, proto internal.RequestedTransport) ([]internal.Attribute, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return nil, err } defer conn.Close() allocateRequest := internal.AllocateRequest(proto, internal.AllocateProtocolIgnore) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return nil, fmt.Errorf("error on sending allocate request: %w", err) } diff --git a/internal/cmd/memoryleak.go b/internal/cmd/memoryleak.go index c67a3040..5fed900f 100644 --- a/internal/cmd/memoryleak.go +++ b/internal/cmd/memoryleak.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "net/netip" "strings" @@ -56,12 +57,12 @@ func (opts MemoryleakOpts) Validate() error { return nil } -func MemoryLeak(opts MemoryleakOpts) error { +func MemoryLeak(ctx context.Context, opts MemoryleakOpts) error { if err := opts.Validate(); err != nil { return err } - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, opts.TargetHost, opts.TargetPort, opts.Username, opts.Password) + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, opts.TargetHost, opts.TargetPort, opts.Username, opts.Password) if err != nil { return err } @@ -76,7 +77,7 @@ func MemoryLeak(opts MemoryleakOpts) error { return fmt.Errorf("error on generating ChannelBind request: %w", err) } opts.Log.Debugf("ChannelBind Request:\n%s", channelBindRequest.String()) - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBind request: %w", err) } @@ -91,7 +92,7 @@ func MemoryLeak(opts MemoryleakOpts) error { toSend = append(toSend, helper.PutUint16(opts.Size)...) toSend = append(toSend, []byte("xxx")...) toSend = internal.Padding(toSend) - err := helper.ConnectionWrite(remote, toSend, opts.Timeout) + err := helper.ConnectionWrite(ctx, remote, toSend, opts.Timeout) if err != nil { return fmt.Errorf("error on sending data: %w", err) } diff --git a/internal/cmd/rangescan.go b/internal/cmd/rangescan.go index a2b21450..914a7aa6 100644 --- a/internal/cmd/rangescan.go +++ b/internal/cmd/rangescan.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "fmt" "net/netip" @@ -45,7 +46,7 @@ func (opts RangeScanOpts) Validate() error { return nil } -func RangeScan(opts RangeScanOpts) error { +func RangeScan(ctx context.Context, opts RangeScanOpts) error { if err := opts.Validate(); err != nil { return err } @@ -105,7 +106,7 @@ func RangeScan(opts RangeScanOpts) error { return fmt.Errorf("target is no valid ip address: %w", err) } - suc, err := scanUDP(opts, ip, 80) + suc, err := scanUDP(ctx, opts, ip, 80) if err != nil { opts.Log.Errorf("UDP %s: %v", ip, err) } @@ -121,7 +122,7 @@ func RangeScan(opts RangeScanOpts) error { return fmt.Errorf("target is no valid ip address: %w", err) } - suc, err := scanTCP(opts, ip, 80) + suc, err := scanTCP(ctx, opts, ip, 80) if err != nil { opts.Log.Errorf("TCP %s: %v", ip, err) } @@ -132,8 +133,8 @@ func RangeScan(opts RangeScanOpts) error { return nil } -func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { - conn, err := internal.Connect(opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) +func scanTCP(ctx context.Context, opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { + conn, err := internal.Connect(ctx, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout) if err != nil { return false, err } @@ -145,7 +146,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool } allocateRequest := internal.AllocateRequest(internal.RequestedTransportTCP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return false, fmt.Errorf("error on sending allocate request 1: %w", err) } @@ -157,7 +158,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool nonce := string(allocateResponse.GetAttribute(internal.AttrNonce).Value) allocateRequest = internal.AllocateRequestAuth(opts.Username, opts.Password, nonce, realm, internal.RequestedTransportTCP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { return false, fmt.Errorf("error on sending allocate request 2: %w", err) } @@ -169,7 +170,7 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool if err != nil { return false, fmt.Errorf("error on generating Connect request: %w", err) } - connectResponse, err := connectRequest.SendAndReceive(opts.Log, conn, opts.Timeout) + connectResponse, err := connectRequest.SendAndReceive(ctx, opts.Log, conn, opts.Timeout) if err != nil { // ignore timeouts, a timeout means open port if errors.Is(err, helper.ErrTimeout) { @@ -184,8 +185,8 @@ func scanTCP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool return true, nil } -func scanUDP(opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { - remote, _, _, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, targetHost, targetPort, opts.Username, opts.Password) +func scanUDP(ctx context.Context, opts RangeScanOpts, targetHost netip.Addr, targetPort uint16) (bool, error) { + remote, _, _, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, targetHost, targetPort, opts.Username, opts.Password) if err != nil { return false, err } diff --git a/internal/cmd/socks.go b/internal/cmd/socks.go index e4df00bf..aeaaba29 100644 --- a/internal/cmd/socks.go +++ b/internal/cmd/socks.go @@ -52,13 +52,12 @@ func (opts SocksOpts) Validate() error { return nil } -func Socks(opts SocksOpts) error { +func Socks(ctx context.Context, opts SocksOpts) error { if err := opts.Validate(); err != nil { return err } handler := &socksimplementations.SocksTurnTCPHandler{ - Ctx: context.Background(), Server: opts.TurnServer, TURNUsername: opts.Username, TURNPassword: opts.Password, @@ -74,7 +73,7 @@ func Socks(opts SocksOpts) error { Log: opts.Log, } opts.Log.Infof("starting SOCKS server on %s", opts.Listen) - if err := p.Start(); err != nil { + if err := p.Start(ctx); err != nil { return err } <-p.Done diff --git a/internal/cmd/tcpscanner.go b/internal/cmd/tcpscanner.go index bf417767..e7ba4849 100644 --- a/internal/cmd/tcpscanner.go +++ b/internal/cmd/tcpscanner.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "crypto/tls" "encoding/hex" "fmt" @@ -55,7 +56,7 @@ func (opts TCPScannerOpts) Validate() error { return nil } -func TCPScanner(opts TCPScannerOpts) error { +func TCPScanner(ctx context.Context, opts TCPScannerOpts) error { if err := opts.Validate(); err != nil { return err } @@ -79,7 +80,7 @@ func TCPScanner(opts TCPScannerOpts) error { return fmt.Errorf("Invalid port %s: %w", port, err) } opts.Log.Debugf("Scanning %s:%d", ip.IP.String(), portI) - if err := httpScan(opts, ip.IP, uint16(portI)); err != nil { + if err := httpScan(ctx, opts, ip.IP, uint16(portI)); err != nil { opts.Log.Errorf("error on running HTTP Scan for %s:%d: %v", ip.IP.String(), portI, err) } } @@ -88,8 +89,8 @@ func TCPScanner(opts TCPScannerOpts) error { return nil } -func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { - controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(opts.Log, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func httpScan(ctx context.Context, opts TCPScannerOpts, ip netip.Addr, port uint16) error { + _, _, controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(ctx, opts.Log, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { return err } @@ -103,10 +104,10 @@ func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { if useTLS { tlsConn := tls.Client(dataConnection, &tls.Config{InsecureSkipVerify: true}) - if err := helper.ConnectionWrite(tlsConn, []byte(httpRequest), opts.Timeout); err != nil { + if err := helper.ConnectionWrite(ctx, tlsConn, []byte(httpRequest), opts.Timeout); err != nil { return fmt.Errorf("error on sending TLS data: %w", err) } - data, err := helper.ConnectionRead(tlsConn, opts.Timeout) + data, err := helper.ConnectionRead(ctx, tlsConn, opts.Timeout) if err != nil { return fmt.Errorf("error on reading after sending TLS data: %w", err) } @@ -116,10 +117,10 @@ func httpScan(opts TCPScannerOpts, ip netip.Addr, port uint16) error { } // plain text connection - if err := helper.ConnectionWrite(dataConnection, []byte(httpRequest), opts.Timeout); err != nil { + if err := helper.ConnectionWrite(ctx, dataConnection, []byte(httpRequest), opts.Timeout); err != nil { return fmt.Errorf("error on sending data: %w", err) } - data, err := helper.ConnectionRead(dataConnection, opts.Timeout) + data, err := helper.ConnectionRead(ctx, dataConnection, opts.Timeout) if err != nil { return fmt.Errorf("error on reading after sending data: %w", err) } diff --git a/internal/cmd/udpscanner.go b/internal/cmd/udpscanner.go index 607b251b..289c06b1 100644 --- a/internal/cmd/udpscanner.go +++ b/internal/cmd/udpscanner.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "errors" "fmt" "math/rand" @@ -56,7 +57,7 @@ func (opts UDPScannerOpts) Validate() error { return nil } -func UDPScanner(opts UDPScannerOpts) error { +func UDPScanner(ctx context.Context, opts UDPScannerOpts) error { if err := opts.Validate(); err != nil { return err } @@ -74,10 +75,10 @@ func UDPScanner(opts UDPScannerOpts) error { continue } opts.Log.Debugf("Scanning %s", ip.IP.String()) - if err := snmpScan(opts, ip.IP, 161, opts.CommunityString); err != nil { + if err := snmpScan(ctx, opts, ip.IP, 161, opts.CommunityString); err != nil { opts.Log.Errorf("error on running SNMP Scan for ip %s: %v", ip.IP.String(), err) } - if err := dnsScan(opts, ip.IP, 53, opts.DomainName); err != nil { + if err := dnsScan(ctx, opts, ip.IP, 53, opts.DomainName); err != nil { opts.Log.Errorf("error on running DNS Scan for ip %s: %v", ip.IP.String(), err) } } @@ -85,8 +86,8 @@ func UDPScanner(opts UDPScannerOpts) error { return nil } -func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) error { - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func snmpScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint16, community string) error { + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -105,7 +106,7 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) return fmt.Errorf("error on generating ChannelBindRequest: %w", err) } - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBindRequest: %w", err) } @@ -147,12 +148,12 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) buf = append(buf, helper.PutUint16(uint16(snmpLen))...) buf = append(buf, snmp...) - err = helper.ConnectionWrite(remote, buf, opts.Timeout) + err = helper.ConnectionWrite(ctx, remote, buf, opts.Timeout) if err != nil { return fmt.Errorf("error on sending SNMP request: %w", err) } - resp, err := helper.ConnectionRead(remote, opts.Timeout) + resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -172,8 +173,8 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string) return nil } -func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) error { - remote, realm, nonce, err := internal.SetupTurnConnection(opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) +func dnsScan(ctx context.Context, opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) error { + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, opts.Log, opts.Protocol, opts.TurnServer, opts.UseTLS, opts.Timeout, ip, port, opts.Username, opts.Password) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { @@ -192,7 +193,7 @@ func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) er return fmt.Errorf("error on generating ChannelBindRequest: %w", err) } - channelBindResponse, err := channelBindRequest.SendAndReceive(opts.Log, remote, opts.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, opts.Log, remote, opts.Timeout) if err != nil { return fmt.Errorf("error on sending ChannelBindRequest: %w", err) } @@ -239,12 +240,12 @@ func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) er buf = append(buf, helper.PutUint16(uint16(dnsLen))...) buf = append(buf, dns...) - err = helper.ConnectionWrite(remote, buf, opts.Timeout) + err = helper.ConnectionWrite(ctx, remote, buf, opts.Timeout) if err != nil { return fmt.Errorf("error on sending DNS request: %w", err) } - resp, err := helper.ConnectionRead(remote, opts.Timeout) + resp, err := helper.ConnectionRead(ctx, remote, opts.Timeout) if err != nil { // ignore timeouts if errors.Is(err, helper.ErrTimeout) { diff --git a/internal/connection.go b/internal/connection.go index 7da902e6..01d4575f 100644 --- a/internal/connection.go +++ b/internal/connection.go @@ -11,7 +11,7 @@ import ( "github.com/pion/dtls/v2" ) -func Connect(protocol string, turnServer string, useTLS bool, timeout time.Duration) (net.Conn, error) { +func Connect(ctx context.Context, protocol string, turnServer string, useTLS bool, timeout time.Duration) (net.Conn, error) { if !useTLS { // non TLS connection conn, err := net.DialTimeout(protocol, turnServer, timeout) @@ -39,7 +39,7 @@ func Connect(protocol string, turnServer string, useTLS bool, timeout time.Durat if err != nil { return nil, fmt.Errorf("error on establishing a connection to the server: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() dtlsConn, err := dtls.ClientWithContext(ctx, conn, &dtls.Config{ InsecureSkipVerify: true, @@ -54,12 +54,12 @@ func Connect(protocol string, turnServer string, useTLS bool, timeout time.Durat } // send serializes a STUN object and sends it on the provided connection -func (s *Stun) send(conn net.Conn, timeout time.Duration) error { +func (s *Stun) send(ctx context.Context, conn net.Conn, timeout time.Duration) error { data, err := s.Serialize() if err != nil { return fmt.Errorf("Serialize: %w", err) } - if err := helper.ConnectionWrite(conn, data, timeout); err != nil { + if err := helper.ConnectionWrite(ctx, conn, data, timeout); err != nil { return fmt.Errorf("ConnectionWrite: %w", err) } @@ -67,13 +67,13 @@ func (s *Stun) send(conn net.Conn, timeout time.Duration) error { } // SendAndReceive sends a TURN request on a connection and gets a response -func (s *Stun) SendAndReceive(logger DebugLogger, conn net.Conn, timeout time.Duration) (*Stun, error) { +func (s *Stun) SendAndReceive(ctx context.Context, logger DebugLogger, conn net.Conn, timeout time.Duration) (*Stun, error) { logger.Debugf("Sending\n%s", s.String()) - err := s.send(conn, timeout) + err := s.send(ctx, conn, timeout) if err != nil { return nil, fmt.Errorf("Send: %w", err) } - buffer, err := helper.ConnectionRead(conn, timeout) + buffer, err := helper.ConnectionRead(ctx, conn, timeout) if err != nil { return nil, fmt.Errorf("ConnectionRead: %w", err) } diff --git a/internal/helper/connection.go b/internal/helper/connection.go index bf138aad..9e9bc799 100644 --- a/internal/helper/connection.go +++ b/internal/helper/connection.go @@ -1,8 +1,8 @@ package helper import ( + "context" "errors" - "fmt" "io" "net" "time" @@ -11,56 +11,65 @@ import ( var ErrTimeout = errors.New("timeout occurred. you can try to increase the timeout if the server responds too slowly") // ConnectionRead reads all data from a connection -func ConnectionRead(conn net.Conn, timeout time.Duration) ([]byte, error) { +func ConnectionRead(ctx context.Context, conn net.Conn, timeout time.Duration) ([]byte, error) { var ret []byte - if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil { - return nil, fmt.Errorf("could not set read deadline: %w", err) - } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() bufLen := 1024 for { - buf := make([]byte, bufLen) - i, err := conn.Read(buf) - if err != nil { - if err != io.EOF { - // also return read data on timeout so caller can use it - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return ret, ErrTimeout + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + buf := make([]byte, bufLen) + i, err := conn.Read(buf) + if err != nil { + if err != io.EOF { + // also return read data on timeout so caller can use it + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return ret, ErrTimeout + } + return nil, err } - return nil, err + return ret, nil + } + ret = append(ret, buf[:i]...) + // we've read all data, bail out + if i < bufLen { + return ret, nil } - return ret, nil - } - ret = append(ret, buf[:i]...) - // we've read all data, bail out - if i < bufLen { - return ret, nil } } } // ConnectionWrite makes sure to write all data to a connection -func ConnectionWrite(conn net.Conn, data []byte, timeout time.Duration) error { +func ConnectionWrite(ctx context.Context, conn net.Conn, data []byte, timeout time.Duration) error { toWriteLeft := len(data) written := 0 - err := conn.SetWriteDeadline(time.Now().Add(timeout)) - if err != nil { - return fmt.Errorf("could not set write deadline: %w", err) - } + var err error + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() for { - written, err = conn.Write(data[written:toWriteLeft]) - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return ErrTimeout - } else { - return err + select { + case <-ctx.Done(): + return ctx.Err() + default: + written, err = conn.Write(data[written:toWriteLeft]) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return ErrTimeout + } else { + return err + } } + if written == toWriteLeft { + return nil + } + toWriteLeft -= written } - if written == toWriteLeft { - return nil - } - toWriteLeft -= written } } diff --git a/internal/helpers_turn.go b/internal/helpers_turn.go index b58bb3d4..922ffc82 100644 --- a/internal/helpers_turn.go +++ b/internal/helpers_turn.go @@ -2,6 +2,7 @@ package internal import ( "bytes" + "context" "encoding/binary" "fmt" "net" @@ -127,8 +128,8 @@ func ConvertXORAddr(input []byte, transactionID string) (string, uint16, error) // CreatePermission // // it returns the connection, the realm, the nonce and an error -func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (net.Conn, string, string, error) { - remote, err := Connect(connectProtocol, turnServer, useTLS, timeout) +func SetupTurnConnection(ctx context.Context, logger DebugLogger, connectProtocol string, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (net.Conn, string, string, error) { + remote, err := Connect(ctx, connectProtocol, turnServer, useTLS, timeout) if err != nil { return nil, "", "", err } @@ -139,7 +140,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer } allocateRequest := AllocateRequest(RequestedTransportUDP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(logger, remote, timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending AllocateRequest: %w", err) } @@ -151,7 +152,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer nonce := string(allocateResponse.GetAttribute(AttrNonce).Value) allocateRequest = AllocateRequestAuth(username, password, nonce, realm, RequestedTransportUDP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(logger, remote, timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending AllocateRequest Auth: %w", err) } @@ -162,7 +163,7 @@ func SetupTurnConnection(logger DebugLogger, connectProtocol string, turnServer if err != nil { return nil, "", "", fmt.Errorf("error on generating CreatePermissionRequest: %w", err) } - permissionResponse, err := permissionRequest.SendAndReceive(logger, remote, timeout) + permissionResponse, err := permissionRequest.SendAndReceive(ctx, logger, remote, timeout) if err != nil { return nil, "", "", fmt.Errorf("error on sending CreatePermissionRequest: %w", err) } diff --git a/internal/helpers_turntcp.go b/internal/helpers_turntcp.go index e3b403f6..c4225320 100644 --- a/internal/helpers_turntcp.go +++ b/internal/helpers_turntcp.go @@ -1,13 +1,17 @@ package internal import ( - "crypto/tls" + "context" "fmt" "net" "net/netip" "time" ) +type keepAlive interface { + SetKeepAlive(bool) +} + // SetupTurnTCPConnection executes the following: // // Allocate Unauth (to get realm and nonce) @@ -17,24 +21,16 @@ import ( // ConnectionBind // // it returns the controlConnection, the dataConnection and an error -func SetupTurnTCPConnection(logger DebugLogger, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (net.Conn, net.Conn, error) { +func SetupTurnTCPConnection(ctx context.Context, logger DebugLogger, turnServer string, useTLS bool, timeout time.Duration, targetHost netip.Addr, targetPort uint16, username, password string) (string, string, net.Conn, net.Conn, error) { // protocol needs to be tcp - controlConnectionRaw, err := Connect("tcp", turnServer, useTLS, timeout) + controlConnection, err := Connect(ctx, "tcp", turnServer, useTLS, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on establishing control connection: %w", err) + return "", "", nil, nil, fmt.Errorf("error on establishing control connection: %w", err) } - var controlConnection net.Conn - switch t := controlConnectionRaw.(type) { - case *net.TCPConn: - if err := t.SetKeepAlive(true); err != nil { - return nil, nil, fmt.Errorf("could not set KeepAlive on control connection: %w", err) - } - controlConnection = t - case *tls.Conn: - controlConnection = t - default: - return nil, nil, fmt.Errorf("could not determine control connection type (%T)", t) + if x, ok := controlConnection.(keepAlive); ok { + logger.Debug("controlconnection: set keepalive to true") + x.SetKeepAlive(true) } logger.Debugf("opened turn tcp control connection from %s to %s", controlConnection.LocalAddr().String(), controlConnection.RemoteAddr().String()) @@ -45,68 +41,60 @@ func SetupTurnTCPConnection(logger DebugLogger, turnServer string, useTLS bool, } allocateRequest := AllocateRequest(RequestedTransportTCP, addressFamily) - allocateResponse, err := allocateRequest.SendAndReceive(logger, controlConnection, timeout) + allocateResponse, err := allocateRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending allocate request 1: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending allocate request 1: %w", err) } if allocateResponse.Header.MessageType.Class != MsgTypeClassError { - return nil, nil, fmt.Errorf("MessageClass is not Error (should be not authenticated)") + return "", "", nil, nil, fmt.Errorf("MessageClass is not Error (should be not authenticated)") } realm := string(allocateResponse.GetAttribute(AttrRealm).Value) nonce := string(allocateResponse.GetAttribute(AttrNonce).Value) allocateRequest = AllocateRequestAuth(username, password, nonce, realm, RequestedTransportTCP, addressFamily) - allocateResponse, err = allocateRequest.SendAndReceive(logger, controlConnection, timeout) + allocateResponse, err = allocateRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending allocate request 2: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending allocate request 2: %w", err) } if allocateResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on allocate response: %s", allocateResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on allocate response: %s", allocateResponse.GetErrorString()) } connectRequest, err := ConnectRequestAuth(username, password, nonce, realm, targetHost, targetPort) if err != nil { - return nil, nil, fmt.Errorf("error on generating Connect request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on generating Connect request: %w", err) } - connectResponse, err := connectRequest.SendAndReceive(logger, controlConnection, timeout) + connectResponse, err := connectRequest.SendAndReceive(ctx, logger, controlConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending Connect request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending Connect request: %w", err) } if connectResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on Connect response: %s", connectResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on Connect response: %s", connectResponse.GetErrorString()) } connectionID := connectResponse.GetAttribute(AttrConnectionID).Value - dataConnectionRaw, err := Connect("tcp", turnServer, useTLS, timeout) + dataConnection, err := Connect(ctx, "tcp", turnServer, useTLS, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on establishing data connection: %w", err) + return "", "", nil, nil, fmt.Errorf("error on establishing data connection: %w", err) } - var dataConnection net.Conn - switch t := dataConnectionRaw.(type) { - case *net.TCPConn: - if err := t.SetKeepAlive(true); err != nil { - return nil, nil, fmt.Errorf("could not set KeepAlive on data connection: %w", err) - } - dataConnection = t - case *tls.Conn: - dataConnection = t - default: - return nil, nil, fmt.Errorf("could not determine data connection type (%T)", t) + if x, ok := dataConnection.(keepAlive); ok { + logger.Debug("dataconnection: set keepalive to true") + x.SetKeepAlive(true) } logger.Debugf("opened turn tcp data connection from %s to %s", dataConnection.LocalAddr().String(), dataConnection.RemoteAddr().String()) connectionBindRequest := ConnectionBindRequest(connectionID, username, password, nonce, realm) - connectionBindResponse, err := connectionBindRequest.SendAndReceive(logger, dataConnection, timeout) + connectionBindResponse, err := connectionBindRequest.SendAndReceive(ctx, logger, dataConnection, timeout) if err != nil { - return nil, nil, fmt.Errorf("error on sending ConnectionBind request: %w", err) + return "", "", nil, nil, fmt.Errorf("error on sending ConnectionBind request: %w", err) } if connectionBindResponse.Header.MessageType.Class == MsgTypeClassError { - return nil, nil, fmt.Errorf("error on ConnectionBind reposnse: %s", connectionBindResponse.GetErrorString()) + return "", "", nil, nil, fmt.Errorf("error on ConnectionBind reposnse: %s", connectionBindResponse.GetErrorString()) } - return controlConnection, dataConnection, nil + return realm, nonce, controlConnection, dataConnection, nil } diff --git a/internal/logger.go b/internal/logger.go index a010651d..843b298f 100644 --- a/internal/logger.go +++ b/internal/logger.go @@ -1,5 +1,6 @@ package internal type DebugLogger interface { + Debug(...interface{}) Debugf(format string, args ...interface{}) } diff --git a/internal/socksimplementations/socksturntcphandler.go b/internal/socksimplementations/socksturntcphandler.go index 68c2a691..782b8501 100644 --- a/internal/socksimplementations/socksturntcphandler.go +++ b/internal/socksimplementations/socksturntcphandler.go @@ -2,6 +2,7 @@ package socksimplementations import ( "context" + "errors" "fmt" "io" "net" @@ -17,7 +18,6 @@ import ( // SocksTurnTCPHandler is the implementation of a TCP TURN server type SocksTurnTCPHandler struct { - Ctx context.Context ControlConnection net.Conn TURNUsername string TURNPassword string @@ -26,10 +26,12 @@ type SocksTurnTCPHandler struct { UseTLS bool DropNonPrivateRequests bool Log *logrus.Logger + realm string + nonce string } // PreHandler connects to the STUN server, sets the connection up and returns the data connections -func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, *socks.Error) { +func (s *SocksTurnTCPHandler) Init(ctx context.Context, request socks.Request) (io.ReadWriteCloser, *socks.Error) { var target netip.Addr var err error switch request.AddressType { @@ -45,7 +47,7 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * target = ip } else { // input is a hostname - names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress)) + names, err := helper.ResolveName(ctx, string(request.DestinationAddress)) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -63,10 +65,12 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)) } - controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(s.Log, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) + realm, nonce, controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(ctx, s.Log, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } + s.realm = realm + s.nonce = nonce // we need to keep this connection open s.ControlConnection = controlConnection @@ -75,60 +79,102 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, * // Refresh is used to refresh an active connection every 2 minutes func (s *SocksTurnTCPHandler) Refresh(ctx context.Context) { - nonce := "" - realm := "" - tick := time.NewTicker(2 * time.Minute) - select { - case <-ctx.Done(): - return - case <-tick.C: - s.Log.Debug("[socks] refreshing connection") - refresh := internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) - response, err := refresh.SendAndReceive(s.Log, s.ControlConnection, s.Timeout) - if err != nil { - s.Log.Error(err) + nonce := s.nonce + realm := s.realm + tick := time.NewTicker(5 * time.Minute) // default timeout on coturn is 600 seconds (10 minutes) + for { + select { + case <-ctx.Done(): return - } - // should happen on a stale nonce - if response.Header.MessageType.Class == internal.MsgTypeClassError { - realm := string(response.GetAttribute(internal.AttrRealm).Value) - nonce := string(response.GetAttribute(internal.AttrNonce).Value) - refresh = internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) - response, err = refresh.SendAndReceive(s.Log, s.ControlConnection, s.Timeout) + case <-tick.C: + s.Log.Debug("[socks] refreshing connection") + refresh := internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) + response, err := refresh.SendAndReceive(ctx, s.Log, s.ControlConnection, s.Timeout) if err != nil { s.Log.Error(err) return } + // should happen on a stale nonce if response.Header.MessageType.Class == internal.MsgTypeClassError { - s.Log.Error(response.GetErrorString()) - return + realm := string(response.GetAttribute(internal.AttrRealm).Value) + nonce := string(response.GetAttribute(internal.AttrNonce).Value) + s.nonce = nonce + s.realm = realm + refresh = internal.RefreshRequest(s.TURNUsername, s.TURNPassword, nonce, realm) + response, err = refresh.SendAndReceive(ctx, s.Log, s.ControlConnection, s.Timeout) + if err != nil { + s.Log.Error(err) + return + } + if response.Header.MessageType.Class == internal.MsgTypeClassError { + s.Log.Error(response.GetErrorString()) + return + } } } } } +const bufferLength = 1024 * 100 + // ReadFromClient is used to copy data func (s *SocksTurnTCPHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { - i, err := io.Copy(remote, client) - if err != nil { - return fmt.Errorf("CopyFromRemoteToClient: %w", err) + for { + // anonymous func for defer + // this might not be the fastest, but it does the trick + err := func() error { + ctx, cancel := context.WithTimeout(ctx, s.Timeout) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + default: + i, err := io.CopyN(remote, client, bufferLength) + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return fmt.Errorf("ReadFromClient: %w", err) + } + s.Log.Debugf("[socks] wrote %d bytes to client", i) + } + return nil + }() + if err != nil { + return err + } } - s.Log.Debugf("[socks] wrote %d bytes to client", i) - return nil } // ReadFromRemote is used to copy data func (s *SocksTurnTCPHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { - i, err := io.Copy(client, remote) - if err != nil { - return fmt.Errorf("CopyFromClientToRemote: %w", err) + for { + // anonymous func for defer + // this might not be the fastest, but it does the trick + err := func() error { + ctx, cancel := context.WithTimeout(ctx, s.Timeout) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + default: + i, err := io.CopyN(client, remote, bufferLength) + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return fmt.Errorf("ReadFromRemote: %w", err) + } + s.Log.Debugf("[socks] wrote %d bytes to remote", i) + } + return nil + }() + if err != nil { + return err + } } - s.Log.Debugf("[socks] wrote %d bytes to remote", i) - return nil } // Cleanup closes the stored control connection -func (s *SocksTurnTCPHandler) Close() error { +func (s *SocksTurnTCPHandler) Close(ctx context.Context) error { if s.ControlConnection != nil { return s.ControlConnection.Close() } diff --git a/internal/socksimplementations/socksturnudphandler.go b/internal/socksimplementations/socksturnudphandler.go index 5cc039cf..c411471d 100644 --- a/internal/socksimplementations/socksturnudphandler.go +++ b/internal/socksimplementations/socksturnudphandler.go @@ -17,7 +17,6 @@ import ( // SocksTurnUDPHandler is the implementation of a UDP TURN server type SocksTurnUDPHandler struct { - Ctx context.Context TURNUsername string TURNPassword string Server string @@ -30,7 +29,7 @@ type SocksTurnUDPHandler struct { } // PreHandler creates a connection to the target server and returns a connection to send data -func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteCloser, *socks.Error) { +func (s *SocksTurnUDPHandler) Init(ctx context.Context, request socks.Request) (io.ReadWriteCloser, *socks.Error) { var target netip.Addr var err error switch request.AddressType { @@ -41,7 +40,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo } target = tmp case socks.RequestAddressTypeDomainname: - names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress)) + names, err := helper.ResolveName(ctx, string(request.DestinationAddress)) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -58,7 +57,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)) } - remote, realm, nonce, err := internal.SetupTurnConnection(s.Log, s.ConnectProtocol, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) + remote, realm, nonce, err := internal.SetupTurnConnection(ctx, s.Log, s.ConnectProtocol, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, err) } @@ -73,7 +72,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on generating ChannelBindRequest: %w", err)) } s.Log.Debugf("ChannelBind Request:\n%s", channelBindRequest.String()) - channelBindResponse, err := channelBindRequest.SendAndReceive(s.Log, remote, s.Timeout) + channelBindResponse, err := channelBindRequest.SendAndReceive(ctx, s.Log, remote, s.Timeout) if err != nil { return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on sending ChannelBindRequest: %w", err)) } @@ -85,7 +84,7 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo } // CopyFromRemoteToClient is used to send data and remove the extra channel data header -func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { +func (s *SocksTurnUDPHandler) ReadFromRemote(ctx context.Context, remote io.ReadCloser, client io.WriteCloser) error { clientConn, ok := client.(net.Conn) if !ok { return fmt.Errorf("could not cast client to net.Conn") @@ -95,7 +94,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote return fmt.Errorf("could not cast remote to net.Conn") } - recv, err := helper.ConnectionRead(remoteConn, s.Timeout) + recv, err := helper.ConnectionRead(ctx, remoteConn, s.Timeout) if err != nil { return err } @@ -106,7 +105,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote } s.Log.Debugf("received %d bytes on channel %02x", len(data), channel) - err = helper.ConnectionWrite(clientConn, data, s.Timeout) + err = helper.ConnectionWrite(ctx, clientConn, data, s.Timeout) if err != nil { return err } @@ -114,7 +113,7 @@ func (s *SocksTurnUDPHandler) CopyFromRemoteToClient(ctx context.Context, remote } // CopyFromClientToRemote is used to send data and add the extra channel data header -func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { +func (s *SocksTurnUDPHandler) ReadFromClient(ctx context.Context, client io.ReadCloser, remote io.WriteCloser) error { clientConn, ok := client.(net.Conn) if !ok { return fmt.Errorf("could not cast client to net.Conn") @@ -124,7 +123,7 @@ func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client return fmt.Errorf("could not cast remote to net.Conn") } - toSend, err := helper.ConnectionRead(clientConn, s.Timeout) + toSend, err := helper.ConnectionRead(ctx, clientConn, s.Timeout) if err != nil { return err } @@ -136,7 +135,7 @@ func (s *SocksTurnUDPHandler) CopyFromClientToRemote(ctx context.Context, client buf = append(buf, helper.PutUint16(uint16(toSendLen))...) buf = append(buf, toSend...) - err = helper.ConnectionWrite(remoteConn, buf, s.Timeout) + err = helper.ConnectionWrite(ctx, remoteConn, buf, s.Timeout) if err != nil { return err } @@ -148,6 +147,6 @@ func (s *SocksTurnUDPHandler) Refresh(_ context.Context) { } // Cleanup is not used in this implementation -func (s *SocksTurnUDPHandler) Cleanup() error { +func (s *SocksTurnUDPHandler) Close(ctx context.Context) error { return nil } diff --git a/main.go b/main.go index 75f9bb55..8655dffd 100644 --- a/main.go +++ b/main.go @@ -52,7 +52,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, }, Before: func(ctx *cli.Context) error { if ctx.Bool("debug") { @@ -65,7 +65,7 @@ func main() { useTLS := c.Bool("tls") protocol := c.String("protocol") timeout := c.Duration("timeout") - return cmd.Info(cmd.InfoOpts{ + return cmd.Info(c.Context, cmd.InfoOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -86,7 +86,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, }, @@ -103,7 +103,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") password := c.String("password") - return cmd.BruteTransports(cmd.BruteTransportOpts{ + return cmd.BruteTransports(c.Context, cmd.BruteTransportOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -125,7 +125,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "passfile", Aliases: []string{"p"}, Required: true, Usage: "passwordfile to use for bruteforce"}, }, @@ -142,7 +142,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") passwordFile := c.String("passfile") - return cmd.BruteForce(cmd.BruteforceOpts{ + return cmd.BruteForce(c.Context, cmd.BruteforceOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -168,7 +168,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "target", Aliases: []string{"t"}, Required: true, Usage: "Target to leak memory to in the form host:port. Should be a public server under your control"}, @@ -206,7 +206,7 @@ func main() { } size := c.Uint("size") - return cmd.MemoryLeak(cmd.MemoryleakOpts{ + return cmd.MemoryLeak(c.Context, cmd.MemoryleakOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -231,7 +231,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, }, @@ -248,7 +248,7 @@ func main() { timeout := c.Duration("timeout") username := c.String("username") password := c.String("password") - return cmd.RangeScan(cmd.RangeScanOpts{ + return cmd.RangeScan(c.Context, cmd.RangeScanOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -269,7 +269,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "listen", Aliases: []string{"l"}, Value: "127.0.0.1:1080", Usage: "Address and port to listen on"}, @@ -290,7 +290,7 @@ func main() { password := c.String("password") listen := c.String("listen") dropPublic := c.Bool("drop-public") - return cmd.Socks(cmd.SocksOpts{ + return cmd.Socks(c.Context, cmd.SocksOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -312,7 +312,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "ports", Value: "80,443,8080,8081", Usage: "Ports to check"}, @@ -337,7 +337,7 @@ func main() { ips := c.StringSlice("ip") - return cmd.TCPScanner(cmd.TCPScannerOpts{ + return cmd.TCPScanner(c.Context, cmd.TCPScannerOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol, @@ -360,7 +360,7 @@ func main() { &cli.StringFlag{Name: "turnserver", Aliases: []string{"s"}, Required: true, Usage: "turn server to connect to in the format host:port"}, &cli.BoolFlag{Name: "tls", Value: false, Usage: "Use TLS/DTLS on connecting to the STUN or TURN server"}, &cli.StringFlag{Name: "protocol", Value: "udp", Usage: "protocol to use when connecting to the TURN server. Supported values: tcp and udp"}, - &cli.DurationFlag{Name: "timeout", Value: 1 * time.Second, Usage: "connect timeout to turn server"}, + &cli.DurationFlag{Name: "timeout", Value: 5 * time.Second, Usage: "connect timeout to turn server"}, &cli.StringFlag{Name: "username", Aliases: []string{"u"}, Required: true, Usage: "username for the turn server"}, &cli.StringFlag{Name: "password", Aliases: []string{"p"}, Required: true, Usage: "password for the turn server"}, &cli.StringFlag{Name: "community-string", Value: "public", Usage: "SNMP community string to use for scanning"}, @@ -383,7 +383,7 @@ func main() { communityString := c.String("community-string") domain := c.String("domain") ips := c.StringSlice("ip") - return cmd.UDPScanner(cmd.UDPScannerOpts{ + return cmd.UDPScanner(c.Context, cmd.UDPScannerOpts{ TurnServer: turnServer, UseTLS: useTLS, Protocol: protocol,