Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ build-test:
docker build --load -t $(TEST_IMAGE) -f e2e/Dockerfile .

build-ebpf:
docker build --load -t $(EBPF_IMAGE) -f socket/Dockerfile .
docker build --load -t $(EBPF_IMAGE) -f activator/Dockerfile .

push-dev: build-installer build-manager
docker push $(INSTALLER_IMAGE)
Expand All @@ -80,8 +80,7 @@ docker-bench: build-test
docker run --rm --privileged --network=host --rm -v $(DOCKER_SOCK):$(DOCKER_SOCK) -v $(PWD):/app $(TEST_IMAGE) make bench

# has to have SYS_ADMIN because the test tries to set netns and mount bpffs
# we use --pid=host to make the ebpf tracker work without a pid resolver
docker-test:
docker-test: build-test
docker run --rm --cap-add=SYS_ADMIN --cap-add=NET_ADMIN --pid=host --userns=host -v $(PWD):/app $(TEST_IMAGE) go test -v -short ./... $(testargs)

CLANG ?= clang
Expand Down Expand Up @@ -113,7 +112,7 @@ ttrpc:

update-vmlinux: ebpf-built-or-build-ebpf
docker run --rm -v $(PWD):/app:Z --entrypoint /bin/sh --user $(shell id -u):$(shell id -g) $(EBPF_IMAGE) \
-c "bpftool btf dump file /sys/kernel/btf/vmlinux format c" | gzip > socket/vmlinux.h.gz
-c "bpftool btf dump file /sys/kernel/btf/vmlinux format c" | gzip > activator/vmlinux.h.gz

ebpf-built-or-build-ebpf:
docker image inspect $(EBPF_IMAGE) || $(MAKE) build-ebpf
2 changes: 1 addition & 1 deletion socket/Dockerfile → activator/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RUN dnf install -y llvm clang bpftool libbpf-devel golang

RUN mkdir /headers
RUN cp /usr/include/bpf/bpf_* /headers
COPY socket/vmlinux.h.gz /headers
COPY activator/vmlinux.h.gz /headers
RUN gunzip /headers/vmlinux.h.gz

COPY --from=gomod /go /tmp
Expand Down
155 changes: 112 additions & 43 deletions activator/activator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/cilium/ebpf"
"github.com/containerd/log"
"github.com/containernetworking/plugins/pkg/ns"
"golang.org/x/sys/unix"
)

type Server struct {
Expand Down Expand Up @@ -52,8 +53,10 @@ func NewServer(ctx context.Context, nn ns.NetNS) (*Server, error) {
ns: nn,
sandboxPid: parsePidFromNetNS(nn),
}

return s, os.MkdirAll(PinPath(s.sandboxPid), os.ModePerm)
if err := os.MkdirAll(PinPath(s.sandboxPid), os.ModePerm); err != nil {
return nil, err
}
return s, nil
}

func parsePidFromNetNS(nn ns.NetNS) int {
Expand All @@ -72,15 +75,21 @@ func parsePidFromNetNS(nn ns.NetNS) int {

var ErrMapNotFound = errors.New("bpf map could not be found")

func (s *Server) Start(ctx context.Context, ports []uint16, connHook ConnHook, restoreHook RestoreHook) error {
s.ports = ports
func (s *Server) Start(ctx context.Context, connHook ConnHook, restoreHook RestoreHook, ports ...uint16) error {
s.connHook = connHook
s.restoreHook = restoreHook
s.ports = ports

if err := s.loadPinnedMaps(); err != nil {
return err
}

if err := s.initActivityTracker(); err != nil {
return err
}
// start with disabled redirects
if err := s.DisableRedirects(); err != nil {
return err
}
for _, port := range s.ports {
proxyPort, err := s.listen(ctx, port)
if err != nil {
Expand All @@ -102,9 +111,6 @@ func (s *Server) Started() bool {
}

func (s *Server) Reset() error {
if !s.Started() {
return nil
}
for _, port := range s.ports {
if err := s.enableRedirect(port); err != nil {
return err
Expand All @@ -114,9 +120,6 @@ func (s *Server) Reset() error {
}

func (s *Server) DisableRedirects() error {
if !s.Started() {
return nil
}
for _, port := range s.ports {
if err := s.disableRedirect(port); err != nil {
return err
Expand Down Expand Up @@ -161,7 +164,7 @@ func (s *Server) listen(ctx context.Context, port uint16) (int, error) {
}

func (s *Server) Stop(ctx context.Context) {
log.G(ctx).Debugf("stopping activator")
log.G(ctx).Debug("stopping activator")

if s.proxyCancel != nil {
s.proxyCancel()
Expand All @@ -176,7 +179,7 @@ func (s *Server) Stop(ctx context.Context) {
_ = os.RemoveAll(PinPath(s.sandboxPid))

s.wg.Wait()
log.G(ctx).Debugf("activator stopped")
log.G(ctx).Debug("activator stopped")
}

func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16) {
Expand Down Expand Up @@ -238,7 +241,7 @@ func (s *Server) handleConnection(ctx context.Context, netConn net.Conn, port ui
return
}

backendConn, err := s.connect(ctx, port)
backendConn, err := s.connect(ctx, port, tcpAddr)
if err != nil {
log.G(ctx).Errorf("error establishing connection: %s", err)
return
Expand All @@ -261,13 +264,23 @@ func (s *Server) handleConnection(ctx context.Context, netConn net.Conn, port ui
log.G(ctx).Println("connection closed", conn.RemoteAddr().String())
}

func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
func (s *Server) connect(ctx context.Context, port uint16, remoteAddr *net.TCPAddr) (net.Conn, error) {
var backendConn net.Conn
// use v4/v6 local and backend addr depending on remoteAddr type
addr := loopbackV4(0)
backendAddr := loopbackV4(port)
if remoteAddr.IP.To4() == nil {
addr = loopbackV6(0)
backendAddr = loopbackV6(port)
}
dialer := net.Dialer{
LocalAddr: addr,
Timeout: s.connectTimeout,
}

ticker := time.NewTicker(time.Millisecond)
defer ticker.Stop()
start := time.Now()

for {
select {
case <-ctx.Done():
Expand All @@ -276,31 +289,9 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
if time.Since(start) > s.connectTimeout {
return nil, fmt.Errorf("timeout dialing process")
}

if err := s.ns.Do(func(_ ns.NetNS) error {
// to ensure we don't create a redirect loop we need to know
// the local port of our connection to the activated process.
// We reserve a free port, store it in the disable bpf map and
// then use it to make the connection.
backendConnPort, err := freePort()
if err != nil {
return fmt.Errorf("unable to get free port: %w", err)
}

log.G(ctx).Debugf("registering backend connection port %d in bpf map", backendConnPort)
if err := s.disableRedirect(uint16(backendConnPort)); err != nil {
return err
}

addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("localhost:%d", backendConnPort))
if err != nil {
return err
}
d := net.Dialer{
LocalAddr: addr,
Timeout: s.connectTimeout,
}
backendConn, err = d.Dial("tcp", fmt.Sprintf("localhost:%d", port))
var err error
backendConn, err = dialer.Dial("tcp", backendAddr.String())
return err
}); err != nil {
var serr syscall.Errno
Expand All @@ -316,11 +307,26 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) {
}
}

// loopbackV4 builds a TCP address on the IPv4 loopback interface
// (127.0.0.1) with the given port. A port of 0 yields an address with an
// unspecified port.
func loopbackV4(port uint16) *net.TCPAddr {
	addr := new(net.TCPAddr)
	// To4 keeps the IP in its 4-byte form.
	addr.IP = net.IPv4(127, 0, 0, 1).To4()
	addr.Port = int(port)
	return addr
}

// loopbackV6 builds a TCP address on the IPv6 loopback interface (::1) with
// the given port. A port of 0 yields an address with an unspecified port.
func loopbackV6(port uint16) *net.TCPAddr {
	addr := net.TCPAddr{Port: int(port)}
	addr.IP = net.IPv6loopback
	return &addr
}

// Names of the BPF maps used by the activator.
// NOTE(review): only socketTrackerMap's use is visible in this view, where it
// is resolved with s.mapPath and opened via ebpf.LoadPinnedMap; the other
// names are presumably handled the same way — confirm against loadPinnedMaps.
const (
activeConnectionsMap = "active_connections"
disableRedirectMap = "disable_redirect"
egressRedirectsMap = "egress_redirects"
ingressRedirectsMap = "ingress_redirects"
socketTrackerMap = "socket_tracker"
)

func (s *Server) loadPinnedMaps() error {
Expand Down Expand Up @@ -360,6 +366,13 @@ func (s *Server) loadPinnedMaps() error {
}
}

if s.maps.SocketTracker == nil {
s.maps.SocketTracker, err = ebpf.LoadPinnedMap(s.mapPath(socketTrackerMap), opts)
if err != nil {
return err
}
}

return nil
}

Expand All @@ -378,6 +391,62 @@ func (s *Server) RedirectPort(from, to uint16) error {
return nil
}

// NoActivityRecordedErr is the error returned by LastActivity when the
// tracked value for a port is still zero, i.e. no activity timestamp has
// been stored for it.
type NoActivityRecordedErr struct{}

// Error implements the error interface.
func (NoActivityRecordedErr) Error() string {
	const msg = "no activity recorded"
	return msg
}

// LastActivity looks up the activity value stored for port in the socket
// tracker map and converts it to a wall-clock time.
//
// If the server has not been started it returns the zero time and a nil
// error. A failed map lookup is returned wrapped, and a stored value of zero
// is reported as NoActivityRecordedErr.
func (s *Server) LastActivity(port uint16) (time.Time, error) {
	var none time.Time
	if !s.started {
		return none, nil
	}

	var stamp uint64
	err := s.maps.SocketTracker.Lookup(&port, &stamp)
	switch {
	case err != nil:
		return none, fmt.Errorf("looking up %d: %w", port, err)
	case stamp == 0:
		// zero is the initial value, so nothing has been recorded yet
		return none, NoActivityRecordedErr{}
	}

	return convertBPFTime(stamp)
}

// initActivityTracker writes a zero value into the socket tracker map for
// every configured port. It stops at the first failing write and returns
// that error wrapped with the offending port.
func (s *Server) initActivityTracker() error {
	var zero uint64
	for i := range s.ports {
		port := s.ports[i]
		err := s.maps.SocketTracker.Put(&port, &zero)
		if err != nil {
			return fmt.Errorf("unable to init activity tracker for port %d: %w", port, err)
		}
	}
	return nil
}

// convertBPFTime turns t, a value obtained from bpf_ktime_get_ns, into a
// time.Time. It measures how long ago t was relative to the current reading
// of getBootTimeNS and subtracts that span from the current wall-clock time.
// An error from getBootTimeNS is returned unchanged along with the zero time.
func convertBPFTime(t uint64) (time.Time, error) {
	nowNS, err := getBootTimeNS()
	if err != nil {
		return time.Time{}, err
	}

	// how far in the past the recorded timestamp lies
	age := time.Duration(nowNS - int64(t))
	return time.Now().Add(-age), nil
}

// getBootTimeNS returns the current CLOCK_MONOTONIC reading in nanoseconds:
// the time elapsed since system boot, not including time the system was
// suspended. Basically the equivalent of bpf_ktime_get_ns, which makes the
// result directly comparable to timestamps written by the BPF programs.
//
// The underlying clock_gettime error is wrapped so callers can inspect it
// with errors.Is / errors.As.
func getBootTimeNS() (int64, error) {
	var ts unix.Timespec
	if err := unix.ClockGettime(unix.CLOCK_MONOTONIC, &ts); err != nil {
		return 0, fmt.Errorf("could not get time: %w", err)
	}

	return unix.TimespecToNsec(ts), nil
}

func (s *Server) registerConnection(port uint16) error {
if err := s.maps.ActiveConnections.Put(&port, uint8(1)); err != nil {
return fmt.Errorf("unable to put port %d into bpf map: %w", port, err)
Expand Down Expand Up @@ -415,8 +484,8 @@ func proxy(ctx context.Context, conn1, conn2 net.Conn) error {

errors := make(chan error, 2)
done := make(chan struct{}, 2)
go copy(done, errors, conn2, conn1)
go copy(done, errors, conn1, conn2)
go cp(done, errors, conn2, conn1)
go cp(done, errors, conn1, conn2)

select {
case <-ctx.Done():
Expand All @@ -429,7 +498,7 @@ func proxy(ctx context.Context, conn1, conn2 net.Conn) error {
}
}

func copy(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) {
func cp(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
done <- struct{}{}
if err != nil {
Expand Down
Loading