diff --git a/Makefile b/Makefile index 3176730..02a5d51 100644 --- a/Makefile +++ b/Makefile @@ -54,7 +54,7 @@ build-test: docker build --load -t $(TEST_IMAGE) -f e2e/Dockerfile . build-ebpf: - docker build --load -t $(EBPF_IMAGE) -f socket/Dockerfile . + docker build --load -t $(EBPF_IMAGE) -f activator/Dockerfile . push-dev: build-installer build-manager docker push $(INSTALLER_IMAGE) @@ -80,8 +80,7 @@ docker-bench: build-test docker run --rm --privileged --network=host --rm -v $(DOCKER_SOCK):$(DOCKER_SOCK) -v $(PWD):/app $(TEST_IMAGE) make bench # has to have SYS_ADMIN because the test tries to set netns and mount bpffs -# we use --pid=host to make the ebpf tracker work without a pid resolver -docker-test: +docker-test: build-test docker run --rm --cap-add=SYS_ADMIN --cap-add=NET_ADMIN --pid=host --userns=host -v $(PWD):/app $(TEST_IMAGE) go test -v -short ./... $(testargs) CLANG ?= clang @@ -113,7 +112,7 @@ ttrpc: update-vmlinux: ebpf-built-or-build-ebpf docker run --rm -v $(PWD):/app:Z --entrypoint /bin/sh --user $(shell id -u):$(shell id -g) $(EBPF_IMAGE) \ - -c "bpftool btf dump file /sys/kernel/btf/vmlinux format c" | gzip > socket/vmlinux.h.gz + -c "bpftool btf dump file /sys/kernel/btf/vmlinux format c" | gzip > activator/vmlinux.h.gz ebpf-built-or-build-ebpf: docker image inspect $(EBPF_IMAGE) || $(MAKE) build-ebpf diff --git a/socket/Dockerfile b/activator/Dockerfile similarity index 94% rename from socket/Dockerfile rename to activator/Dockerfile index 07add51..06c9296 100644 --- a/socket/Dockerfile +++ b/activator/Dockerfile @@ -12,7 +12,7 @@ RUN dnf install -y llvm clang bpftool libbpf-devel golang RUN mkdir /headers RUN cp /usr/include/bpf/bpf_* /headers -COPY socket/vmlinux.h.gz /headers +COPY activator/vmlinux.h.gz /headers RUN gunzip /headers/vmlinux.h.gz COPY --from=gomod /go /tmp diff --git a/activator/activator.go b/activator/activator.go index 3733ef2..cc6fde8 100644 --- a/activator/activator.go +++ b/activator/activator.go @@ -22,6 +22,7 @@ 
import ( "github.com/cilium/ebpf" "github.com/containerd/log" "github.com/containernetworking/plugins/pkg/ns" + "golang.org/x/sys/unix" ) type Server struct { @@ -52,8 +53,10 @@ func NewServer(ctx context.Context, nn ns.NetNS) (*Server, error) { ns: nn, sandboxPid: parsePidFromNetNS(nn), } - - return s, os.MkdirAll(PinPath(s.sandboxPid), os.ModePerm) + if err := os.MkdirAll(PinPath(s.sandboxPid), os.ModePerm); err != nil { + return nil, err + } + return s, nil } func parsePidFromNetNS(nn ns.NetNS) int { @@ -72,15 +75,21 @@ func parsePidFromNetNS(nn ns.NetNS) int { var ErrMapNotFound = errors.New("bpf map could not be found") -func (s *Server) Start(ctx context.Context, ports []uint16, connHook ConnHook, restoreHook RestoreHook) error { - s.ports = ports +func (s *Server) Start(ctx context.Context, connHook ConnHook, restoreHook RestoreHook, ports ...uint16) error { s.connHook = connHook s.restoreHook = restoreHook + s.ports = ports if err := s.loadPinnedMaps(); err != nil { return err } - + if err := s.initActivityTracker(); err != nil { + return err + } + // start with disabled redirects + if err := s.DisableRedirects(); err != nil { + return err + } for _, port := range s.ports { proxyPort, err := s.listen(ctx, port) if err != nil { @@ -102,9 +111,6 @@ func (s *Server) Started() bool { } func (s *Server) Reset() error { - if !s.Started() { - return nil - } for _, port := range s.ports { if err := s.enableRedirect(port); err != nil { return err @@ -114,9 +120,6 @@ func (s *Server) Reset() error { } func (s *Server) DisableRedirects() error { - if !s.Started() { - return nil - } for _, port := range s.ports { if err := s.disableRedirect(port); err != nil { return err @@ -161,7 +164,7 @@ func (s *Server) listen(ctx context.Context, port uint16) (int, error) { } func (s *Server) Stop(ctx context.Context) { - log.G(ctx).Debugf("stopping activator") + log.G(ctx).Debug("stopping activator") if s.proxyCancel != nil { s.proxyCancel() @@ -176,7 +179,7 @@ func (s *Server) 
Stop(ctx context.Context) { _ = os.RemoveAll(PinPath(s.sandboxPid)) s.wg.Wait() - log.G(ctx).Debugf("activator stopped") + log.G(ctx).Debug("activator stopped") } func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16) { @@ -238,7 +241,7 @@ func (s *Server) handleConnection(ctx context.Context, netConn net.Conn, port ui return } - backendConn, err := s.connect(ctx, port) + backendConn, err := s.connect(ctx, port, tcpAddr) if err != nil { log.G(ctx).Errorf("error establishing connection: %s", err) return @@ -261,13 +264,23 @@ func (s *Server) handleConnection(ctx context.Context, netConn net.Conn, port ui log.G(ctx).Println("connection closed", conn.RemoteAddr().String()) } -func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) { +func (s *Server) connect(ctx context.Context, port uint16, remoteAddr *net.TCPAddr) (net.Conn, error) { var backendConn net.Conn + // use v4/v6 local and backend addr depending on remoteAddr type + addr := loopbackV4(0) + backendAddr := loopbackV4(port) + if remoteAddr.IP.To4() == nil { + addr = loopbackV6(0) + backendAddr = loopbackV6(port) + } + dialer := net.Dialer{ + LocalAddr: addr, + Timeout: s.connectTimeout, + } ticker := time.NewTicker(time.Millisecond) defer ticker.Stop() start := time.Now() - for { select { case <-ctx.Done(): @@ -276,31 +289,9 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) { if time.Since(start) > s.connectTimeout { return nil, fmt.Errorf("timeout dialing process") } - if err := s.ns.Do(func(_ ns.NetNS) error { - // to ensure we don't create a redirect loop we need to know - // the local port of our connection to the activated process. - // We reserve a free port, store it in the disable bpf map and - // then use it to make the connection. 
- backendConnPort, err := freePort() - if err != nil { - return fmt.Errorf("unable to get free port: %w", err) - } - - log.G(ctx).Debugf("registering backend connection port %d in bpf map", backendConnPort) - if err := s.disableRedirect(uint16(backendConnPort)); err != nil { - return err - } - - addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("localhost:%d", backendConnPort)) - if err != nil { - return err - } - d := net.Dialer{ - LocalAddr: addr, - Timeout: s.connectTimeout, - } - backendConn, err = d.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + var err error + backendConn, err = dialer.Dial("tcp", backendAddr.String()) return err }); err != nil { var serr syscall.Errno @@ -316,11 +307,26 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) { } } +func loopbackV4(port uint16) *net.TCPAddr { + return &net.TCPAddr{ + IP: net.IP{127, 0, 0, 1}, + Port: int(port), + } +} + +func loopbackV6(port uint16) *net.TCPAddr { + return &net.TCPAddr{ + IP: net.IPv6loopback, + Port: int(port), + } +} + const ( activeConnectionsMap = "active_connections" disableRedirectMap = "disable_redirect" egressRedirectsMap = "egress_redirects" ingressRedirectsMap = "ingress_redirects" + socketTrackerMap = "socket_tracker" ) func (s *Server) loadPinnedMaps() error { @@ -360,6 +366,13 @@ func (s *Server) loadPinnedMaps() error { } } + if s.maps.SocketTracker == nil { + s.maps.SocketTracker, err = ebpf.LoadPinnedMap(s.mapPath(socketTrackerMap), opts) + if err != nil { + return err + } + } + return nil } @@ -378,6 +391,62 @@ func (s *Server) RedirectPort(from, to uint16) error { return nil } +type NoActivityRecordedErr struct{} + +func (err NoActivityRecordedErr) Error() string { + return "no activity recorded" +} + +func (s *Server) LastActivity(port uint16) (time.Time, error) { + if !s.started { + return time.Time{}, nil + } + var val uint64 + if err := s.maps.SocketTracker.Lookup(&port, &val); err != nil { + return time.Time{}, fmt.Errorf("looking up %d: %w", 
port, err) + } + + if val == 0 { + return time.Time{}, NoActivityRecordedErr{} + } + + return convertBPFTime(val) +} + +func (s *Server) initActivityTracker() error { + for _, port := range s.ports { + val := uint64(0) + if err := s.maps.SocketTracker.Put(&port, &val); err != nil { + return fmt.Errorf("unable to init activity tracker for port %d: %w", port, err) + } + } + return nil +} + +// convertBPFTime takes the value of bpf_ktime_get_ns and converts it to a +// time.Time. +func convertBPFTime(t uint64) (time.Time, error) { + b, err := getBootTimeNS() + if err != nil { + return time.Time{}, err + } + + return time.Now().Add(-time.Duration(b - int64(t))), nil +} + +// getBootTimeNS returns the time elapsed since system boot, in nanoseconds. Does +// not include time the system was suspended. Basically the equivalent of +// bpf_ktime_get_ns. +func getBootTimeNS() (int64, error) { + var ts unix.Timespec + err := unix.ClockGettime(unix.CLOCK_MONOTONIC, &ts) + if err != nil { + return 0, fmt.Errorf("could not get time: %s", err) + } + + return unix.TimespecToNsec(ts), nil +} + func (s *Server) registerConnection(port uint16) error { if err := s.maps.ActiveConnections.Put(&port, uint8(1)); err != nil { return fmt.Errorf("unable to put port %d into bpf map: %w", port, err) @@ -415,8 +484,8 @@ func proxy(ctx context.Context, conn1, conn2 net.Conn) error { errors := make(chan error, 2) done := make(chan struct{}, 2) - go copy(done, errors, conn2, conn1) - go copy(done, errors, conn1, conn2) + go cp(done, errors, conn2, conn1) + go cp(done, errors, conn1, conn2) select { case <-ctx.Done(): @@ -429,7 +498,7 @@ func proxy(ctx context.Context, conn1, conn2 net.Conn) error { } } -func copy(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) { +func cp(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) done <- struct{}{} if err != nil { diff --git a/activator/activator_test.go b/activator/activator_test.go index 
e3fa23f..eec8e69 100644 --- a/activator/activator_test.go +++ b/activator/activator_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "sync" "testing" "time" @@ -18,28 +19,21 @@ import ( "github.com/stretchr/testify/require" ) +type testCase struct { + parallelReqs int + connHook ConnHook + expectedBody string + expectedCode int + expectLastActivity bool + ipv6 bool + setBinaryName bool +} + func TestActivator(t *testing.T) { require.NoError(t, MountBPFFS(BPFFSPath)) - nn, err := ns.GetCurrentNS() require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - s, err := NewServer(ctx, nn) - require.NoError(t, err) - - bpf, err := InitBPF(os.Getpid(), slog.Default()) - require.NoError(t, err) - require.NoError(t, bpf.AttachRedirector("lo")) - - port, err := freePort() - require.NoError(t, err) - - t.Cleanup(func() { - s.Stop(ctx) - cancel() - }) - c := &http.Client{ Timeout: time.Second, Transport: &http.Transport{ @@ -47,21 +41,25 @@ func TestActivator(t *testing.T) { }, } - tests := map[string]struct { - parallelReqs int - connHook ConnHook - expectedBody string - expectedCode int - }{ - "no probe": { - parallelReqs: 1, - expectedBody: "ok", - expectedCode: http.StatusOK, + tests := map[string]testCase{ + "no hook": { + parallelReqs: 1, + expectedBody: "ok", + expectedCode: http.StatusOK, + expectLastActivity: true, + }, + "no hook ipv6": { + parallelReqs: 1, + expectedBody: "ok", + expectedCode: http.StatusOK, + ipv6: true, + expectLastActivity: true, }, "10 in parallel": { - parallelReqs: 10, - expectedBody: "ok", - expectedCode: http.StatusOK, + parallelReqs: 10, + expectedBody: "ok", + expectedCode: http.StatusOK, + expectLastActivity: true, }, "conn hook": { parallelReqs: 1, @@ -72,18 +70,78 @@ func TestActivator(t *testing.T) { } return conn, false, resp.Write(conn) }, - expectedCode: http.StatusForbidden, + expectedCode: http.StatusForbidden, + expectLastActivity: true, + }, + "conn hook ipv6": { + 
parallelReqs: 1, + expectedBody: "", + connHook: func(conn net.Conn) (net.Conn, bool, error) { + resp := http.Response{ + StatusCode: http.StatusForbidden, + } + return conn, false, resp.Write(conn) + }, + expectedCode: http.StatusForbidden, + expectLastActivity: true, + ipv6: true, + }, + "ignore activity with binary name set": { + parallelReqs: 1, + expectedBody: "ok", + expectedCode: http.StatusOK, + setBinaryName: true, + expectLastActivity: false, + }, + "ignore activity with binary name set ipv6": { + parallelReqs: 1, + expectedBody: "ok", + expectedCode: http.StatusOK, + ipv6: true, + setBinaryName: true, + expectLastActivity: false, }, } wg := sync.WaitGroup{} for name, tc := range tests { t.Run(name, func(t *testing.T) { - startServer(t, ctx, s, port, tc.connHook) + ctx, cancel := context.WithCancel(context.Background()) + s, err := NewServer(ctx, nn) + require.NoError(t, err) + + port, err := freePort() + require.NoError(t, err) + + t.Cleanup(func() { + s.Stop(ctx) + cancel() + }) + + exeName := "" + if tc.setBinaryName { + currentExe, err := os.Executable() + require.NoError(t, err) + exeName = filepath.Base(currentExe) + } + bpf, err := InitBPF(os.Getpid(), slog.Default(), exeName, OverrideMapSize( + // not completely sure why this happens but when testing in + // github actions, the default map size of 128 makes the test + // very flaky so we increase it here. 
+ map[string]uint32{SocketTrackerMap: 1024}, + )) + require.NoError(t, err) + require.NoError(t, bpf.AttachRedirector("lo")) + + startServer(t, ctx, s, uint16(port), &tc) for i := 0; i < tc.parallelReqs; i++ { wg.Add(1) go func() { defer wg.Done() - req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d", port), nil) + host := "127.0.0.1" + if tc.ipv6 { + host = "[::1]" + } + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s:%d", host, port), nil) if !assert.NoError(t, err) { return } @@ -104,31 +162,53 @@ func TestActivator(t *testing.T) { }() } wg.Wait() + var key uint16 + var val uint64 + count := 0 + iter := s.maps.SocketTracker.Iterate() + for iter.Next(&key, &val) { + t.Logf("found %d: %d", key, val) + count++ + } + assert.Equal(t, 1, count, "one element in socket tracker map") + last, err := s.LastActivity(uint16(port)) + if tc.expectLastActivity { + assert.NoError(t, err) + assert.Less(t, time.Since(last), time.Second) + } else { + assert.Error(t, err) + assert.ErrorIs(t, err, NoActivityRecordedErr{}) + } assert.NoError(t, s.Reset()) + s.Stop(t.Context()) }) } } -func startServer(t *testing.T, ctx context.Context, s *Server, port int, connHook ConnHook) { +func startServer(t *testing.T, ctx context.Context, s *Server, port uint16, tc *testCase) { response := "ok" ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, response) })) - if connHook == nil { - connHook = func(c net.Conn) (net.Conn, bool, error) { + if tc.connHook == nil { + tc.connHook = func(c net.Conn) (net.Conn, bool, error) { return c, true, nil } } once := sync.Once{} err := s.Start( - ctx, []uint16{uint16(port)}, - connHook, + ctx, + tc.connHook, func() error { once.Do(func() { // simulate a delay until our server is started time.Sleep(time.Millisecond * 200) - l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) + network := "tcp4" + if tc.ipv6 { + network = "tcp6" + } + l, err := 
net.Listen(network, fmt.Sprintf(":%d", port)) require.NoError(t, err) if err := s.DisableRedirects(); err != nil { @@ -139,7 +219,7 @@ func startServer(t *testing.T, ctx context.Context, s *Server, port int, connHoo ts.Listener.Close() ts.Listener = l ts.Start() - t.Logf("listening on :%d", port) + t.Logf("listening on %s", l.Addr().String()) t.Cleanup(func() { l.Close() @@ -148,6 +228,8 @@ func startServer(t *testing.T, ctx context.Context, s *Server, port int, connHoo }) return nil }, + port, ) require.NoError(t, err) + s.enableRedirect(port) } diff --git a/activator/bpf.go b/activator/bpf.go index 94cbe06..8f097d8 100644 --- a/activator/bpf.go +++ b/activator/bpf.go @@ -3,6 +3,7 @@ package activator import ( "fmt" "log/slog" + "maps" "net" "os" "path/filepath" @@ -17,7 +18,14 @@ import ( // $BPF_CLANG and $BPF_CFLAGS are set by the Makefile. //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf redirector.c -- -I/headers -const BPFFSPath = "/sys/fs/bpf" +const ( + BPFFSPath = "/sys/fs/bpf" + probeBinaryNameVariable = "probe_binary_name" + probeBinaryNameMaxLength = 16 + SocketTrackerMap = "socket_tracker" + PodKubeletAddrsMapv4 = "kubelet_addrs_v4" + PodKubeletAddrsMapv6 = "kubelet_addrs_v6" +) type BPF struct { pid int @@ -27,7 +35,27 @@ type BPF struct { log *slog.Logger } -func InitBPF(pid int, log *slog.Logger) (*BPF, error) { +type BPFConfig struct { + mapSizes map[string]uint32 +} + +type BPFOpts func(cfg *BPFConfig) + +func OverrideMapSize(mapSizes map[string]uint32) BPFOpts { + return func(cfg *BPFConfig) { + maps.Copy(cfg.mapSizes, mapSizes) + } +} + +func InitBPF(pid int, log *slog.Logger, probeBinaryName string, opts ...BPFOpts) (*BPF, error) { + cfg := &BPFConfig{ + mapSizes: map[string]uint32{ + SocketTrackerMap: 128, + }, + } + for _, opt := range opts { + opt(cfg) + } // Allow the current process to lock memory for eBPF resources. 
if err := rlimit.RemoveMemlock(); err != nil { return nil, err @@ -40,8 +68,28 @@ func InitBPF(pid int, log *slog.Logger) (*BPF, error) { return nil, fmt.Errorf("failed to create bpf fs subpath: %w", err) } + spec, err := loadBpf() + if err != nil { + return nil, fmt.Errorf("loading bpf objects: %w", err) + } + + if len([]byte(probeBinaryName)) > probeBinaryNameMaxLength { + return nil, fmt.Errorf( + "probe binary name %s is too long (%d bytes), max is %d bytes", + probeBinaryName, len([]byte(probeBinaryName)), probeBinaryNameMaxLength, + ) + } + binName := [probeBinaryNameMaxLength]byte{} + copy(binName[:], probeBinaryName[:]) + if err := spec.Variables[probeBinaryNameVariable].Set(binName); err != nil { + return nil, fmt.Errorf("setting probe binary variable: %w", err) + } + + for mapName, size := range cfg.mapSizes { + spec.Maps[mapName].MaxEntries = size + } objs := bpfObjects{} - if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{ + if err := spec.LoadAndAssign(&objs, &ebpf.CollectionOptions{ Maps: ebpf.MapOptions{ PinPath: path, }, diff --git a/activator/bpf_bpfeb.go b/activator/bpf_bpfeb.go index ef2cf0e..89ebb73 100644 --- a/activator/bpf_bpfeb.go +++ b/activator/bpf_bpfeb.go @@ -8,10 +8,16 @@ import ( _ "embed" "fmt" "io" + "structs" "github.com/cilium/ebpf" ) +type bpfIpv6Addr struct { + _ structs.HostLayout + U6Addr8 [16]uint8 +} + // loadBpf returns the embedded CollectionSpec for bpf. func loadBpf() (*ebpf.CollectionSpec, error) { reader := bytes.NewReader(_BpfBytes) @@ -66,12 +72,16 @@ type bpfMapSpecs struct { DisableRedirect *ebpf.MapSpec `ebpf:"disable_redirect"` EgressRedirects *ebpf.MapSpec `ebpf:"egress_redirects"` IngressRedirects *ebpf.MapSpec `ebpf:"ingress_redirects"` + KubeletAddrsV4 *ebpf.MapSpec `ebpf:"kubelet_addrs_v4"` + KubeletAddrsV6 *ebpf.MapSpec `ebpf:"kubelet_addrs_v6"` + SocketTracker *ebpf.MapSpec `ebpf:"socket_tracker"` } // bpfVariableSpecs contains global variables before they are loaded into the kernel. 
// // It can be passed ebpf.CollectionSpec.Assign. type bpfVariableSpecs struct { + ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` } // bpfObjects contains all objects after they have been loaded into the kernel. @@ -98,6 +108,9 @@ type bpfMaps struct { DisableRedirect *ebpf.Map `ebpf:"disable_redirect"` EgressRedirects *ebpf.Map `ebpf:"egress_redirects"` IngressRedirects *ebpf.Map `ebpf:"ingress_redirects"` + KubeletAddrsV4 *ebpf.Map `ebpf:"kubelet_addrs_v4"` + KubeletAddrsV6 *ebpf.Map `ebpf:"kubelet_addrs_v6"` + SocketTracker *ebpf.Map `ebpf:"socket_tracker"` } func (m *bpfMaps) Close() error { @@ -106,6 +119,9 @@ func (m *bpfMaps) Close() error { m.DisableRedirect, m.EgressRedirects, m.IngressRedirects, + m.KubeletAddrsV4, + m.KubeletAddrsV6, + m.SocketTracker, ) } @@ -113,6 +129,7 @@ func (m *bpfMaps) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfVariables struct { + ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` } // bpfPrograms contains all programs after they have been loaded into the kernel. diff --git a/activator/bpf_bpfeb.o b/activator/bpf_bpfeb.o index b2040c6..99a6f5a 100644 Binary files a/activator/bpf_bpfeb.o and b/activator/bpf_bpfeb.o differ diff --git a/activator/bpf_bpfel.go b/activator/bpf_bpfel.go index 677eddc..d10387a 100644 --- a/activator/bpf_bpfel.go +++ b/activator/bpf_bpfel.go @@ -8,10 +8,16 @@ import ( _ "embed" "fmt" "io" + "structs" "github.com/cilium/ebpf" ) +type bpfIpv6Addr struct { + _ structs.HostLayout + U6Addr8 [16]uint8 +} + // loadBpf returns the embedded CollectionSpec for bpf. 
func loadBpf() (*ebpf.CollectionSpec, error) { reader := bytes.NewReader(_BpfBytes) @@ -66,12 +72,16 @@ type bpfMapSpecs struct { DisableRedirect *ebpf.MapSpec `ebpf:"disable_redirect"` EgressRedirects *ebpf.MapSpec `ebpf:"egress_redirects"` IngressRedirects *ebpf.MapSpec `ebpf:"ingress_redirects"` + KubeletAddrsV4 *ebpf.MapSpec `ebpf:"kubelet_addrs_v4"` + KubeletAddrsV6 *ebpf.MapSpec `ebpf:"kubelet_addrs_v6"` + SocketTracker *ebpf.MapSpec `ebpf:"socket_tracker"` } // bpfVariableSpecs contains global variables before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfVariableSpecs struct { + ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` } // bpfObjects contains all objects after they have been loaded into the kernel. @@ -98,6 +108,9 @@ type bpfMaps struct { DisableRedirect *ebpf.Map `ebpf:"disable_redirect"` EgressRedirects *ebpf.Map `ebpf:"egress_redirects"` IngressRedirects *ebpf.Map `ebpf:"ingress_redirects"` + KubeletAddrsV4 *ebpf.Map `ebpf:"kubelet_addrs_v4"` + KubeletAddrsV6 *ebpf.Map `ebpf:"kubelet_addrs_v6"` + SocketTracker *ebpf.Map `ebpf:"socket_tracker"` } func (m *bpfMaps) Close() error { @@ -106,6 +119,9 @@ func (m *bpfMaps) Close() error { m.DisableRedirect, m.EgressRedirects, m.IngressRedirects, + m.KubeletAddrsV4, + m.KubeletAddrsV6, + m.SocketTracker, ) } @@ -113,6 +129,7 @@ func (m *bpfMaps) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfVariables struct { + ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` } // bpfPrograms contains all programs after they have been loaded into the kernel. 
diff --git a/activator/bpf_bpfel.o b/activator/bpf_bpfel.o index 14f01b9..1232e42 100644 Binary files a/activator/bpf_bpfel.o and b/activator/bpf_bpfel.o differ diff --git a/activator/redirector.c b/activator/redirector.c index 8fdb8cc..39bd5bd 100644 --- a/activator/redirector.c +++ b/activator/redirector.c @@ -3,6 +3,7 @@ #include "vmlinux.h" #include "bpf_helpers.h" #include "bpf_endian.h" +#include "bpf_core_read.h" char __license[] SEC("license") = "Dual MIT/GPL"; @@ -13,7 +14,7 @@ char __license[] SEC("license") = "Dual MIT/GPL"; struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 128); + __uint(max_entries, 128); // allows for 128 different ports in a single pod __type(key, __be16); // sport __type(value, __be16); // dport __uint(pinning, LIBBPF_PIN_BY_NAME); @@ -21,7 +22,7 @@ struct { struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 128); + __uint(max_entries, 128); // allows for 128 different ports in a single pod __type(key, __be16); // sport __type(value, __be16); // dport __uint(pinning, LIBBPF_PIN_BY_NAME); @@ -29,7 +30,7 @@ struct { struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 512); + __uint(max_entries, 128); // allows for 128 different ports in a single pod __type(key, __be16); // proxy port __type(value, u8); // unused __uint(pinning, LIBBPF_PIN_BY_NAME); @@ -37,23 +38,52 @@ struct { struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 512); // TBD but should probably be enough + __uint(max_entries, 512); // 512 max connections while application is being restored __type(key, __be16); // remote_port __type(value, u8); // unused __uint(pinning, LIBBPF_PIN_BY_NAME); } active_connections SEC(".maps"); -static __always_inline int disabled(__be16 sport_h, __be16 dport_h) { - void *disable_redirect_map = &disable_redirect; +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 128); // allows for 128 different ports in a single pod + __type(key, __be16); // dport + 
__type(value, __u64); // ktime ns of the last tracked event + __uint(pinning, LIBBPF_PIN_BY_NAME); +} socket_tracker SEC(".maps"); - void *disabled_s = bpf_map_lookup_elem(disable_redirect_map, &sport_h); +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1); + __type(key, u8); // fixed identifier (0) + __type(value, __be32); // kubelet addr + __uint(pinning, LIBBPF_PIN_BY_NAME); +} kubelet_addrs_v4 SEC(".maps"); - if (disabled_s) { - return 1; - } +struct ipv6_addr { + __u8 u6_addr8[16]; +}; - void *disabled_d = bpf_map_lookup_elem(disable_redirect_map, &dport_h); +struct { + __uint(type, BPF_MAP_TYPE_LRU_HASH); + __uint(max_entries, 1); + __type(key, u8); // fixed identifier (0) + __type(value, struct ipv6_addr); // kubelet addr + __uint(pinning, LIBBPF_PIN_BY_NAME); +} kubelet_addrs_v6 SEC(".maps"); + +const volatile char probe_binary_name[TASK_COMM_LEN] = ""; + +static __always_inline int track_activity(__be16 dport) { + __u64 time = bpf_ktime_get_ns(); + void *tracker = &socket_tracker; + // bpf_printk("tracking activity to %d", dport); + return bpf_map_update_elem(tracker, &dport, &time, BPF_EXIST); +} +static __always_inline int disabled(__be16 dport_h) { + void *disable_redirect_map = &disable_redirect; + void *disabled_d = bpf_map_lookup_elem(disable_redirect_map, &dport_h); if (disabled_d) { return 1; } @@ -72,7 +102,7 @@ static __always_inline int ingress_redirect(struct tcphdr *tcp) { if (new_dest) { // check ports which should not be redirected - if (disabled(sport_h, dport_h)) { + if (disabled(dport_h)) { // if we can find an active connection on the source port, we need // to redirect regardless until the connection is closed. 
void *conn_sport = bpf_map_lookup_elem(active_connections_map, &sport_h); @@ -136,6 +166,61 @@ static __always_inline struct iphdr* ipv4_header(void *data, void *data_end) { return ip4; } +static __always_inline __be32 lookup_kubelet_ip_v4(struct iphdr *ip) { + u8 key = 0; + char comm[TASK_COMM_LEN]; + void *kubelet_addrs = &kubelet_addrs_v4; + __be32 *existing_addr = bpf_map_lookup_elem(kubelet_addrs, &key); + if (existing_addr) { + // bpf_printk("returning existing kubelet addr: %pI4", existing_addr); + return *existing_addr; + } + // bpf_get_current_comm is not available in a tc program on arm64, so we use + // bpf_get_current_task to get the comm. + struct task_struct *task = (void *)bpf_get_current_task(); + BPF_CORE_READ_STR_INTO(&comm, task, comm); + if (bpf_strncmp(comm, TASK_COMM_LEN, (char *)probe_binary_name) == 0) { + // bpf_printk("found kubelet addr v4: %pI4", &ip->saddr); + bpf_map_update_elem(kubelet_addrs, &key, &ip->saddr, BPF_ANY); + return ip->saddr; + } + return 0; +} + +static __always_inline bool ipv6_addr_equal(const struct in6_addr *a1, const struct in6_addr *a2) { + if (a1 == NULL || a2 == NULL) { + return false; + } + #pragma unroll + for (int i = 0; i < 15; i++) { + if (a1->in6_u.u6_addr8[i] != a2->in6_u.u6_addr8[i]) { + return false; + } + } + return true; +} + +static __always_inline struct in6_addr* lookup_kubelet_ip_v6(struct ipv6hdr *ip) { + char comm[TASK_COMM_LEN]; + u8 key = 0; + void *kubelet_addrs = &kubelet_addrs_v6; + struct in6_addr *existing_addr = bpf_map_lookup_elem(kubelet_addrs, &key); + if (existing_addr) { + // bpf_printk("returning existing kubelet v6 addr: %pI6", existing_addr); + return existing_addr; + } + // bpf_get_current_comm is not available in a tc program on arm64, so we use + // bpf_get_current_task to get the comm. 
+ struct task_struct *task = (void *)bpf_get_current_task(); + BPF_CORE_READ_STR_INTO(&comm, task, comm); + if (bpf_strncmp(comm, TASK_COMM_LEN, (char *)probe_binary_name) == 0) { + // bpf_printk("found kubelet addr v6: %pI6", &ip->saddr); + bpf_map_update_elem(kubelet_addrs, &key, &ip->saddr, BPF_ANY); + return &ip->saddr; + } + return NULL; +} + static __always_inline int parse_and_redirect(struct __sk_buff *ctx, bool ingress) { void *data_end = (void *)(long)ctx->data_end; void *data = (void *)(long)ctx->data; @@ -146,6 +231,13 @@ static __always_inline int parse_and_redirect(struct __sk_buff *ctx, bool ingres if ((void*)ip4 + sizeof(*ip4) <= data_end) { if (ip4->protocol == IPPROTO_TCP) { tcp = (void*)ip4 + sizeof(*ip4); + if ((tcp != NULL) && ((void*)tcp + sizeof(*tcp) <= data_end)) { + if (ingress) { + if (ip4->saddr != lookup_kubelet_ip_v4(ip4)) { + track_activity(bpf_ntohs(tcp->dest)); + } + } + } } } } else { @@ -154,6 +246,13 @@ static __always_inline int parse_and_redirect(struct __sk_buff *ctx, bool ingres if ((void*)ip6 + sizeof(*ip6) <= data_end) { if (ip6->nexthdr == NEXTHDR_TCP) { tcp = (void*)ip6 + sizeof(*ip6); + if ((tcp != NULL) && ((void*)tcp + sizeof(*tcp) <= data_end)) { + if (ingress) { + if (!ipv6_addr_equal(&ip6->saddr, lookup_kubelet_ip_v6(ip6))) { + track_activity(bpf_ntohs(tcp->dest)); + } + } + } } } } diff --git a/socket/vmlinux.h.gz b/activator/vmlinux.h.gz similarity index 100% rename from socket/vmlinux.h.gz rename to activator/vmlinux.h.gz diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 4ee521a..9941fd3 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -18,7 +18,6 @@ import ( v1 "github.com/ctrox/zeropod/api/runtime/v1" "github.com/ctrox/zeropod/manager" "github.com/ctrox/zeropod/manager/node" - "github.com/ctrox/zeropod/socket" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" corev1 "k8s.io/api/core/v1" @@ -89,13 +88,7 @@ func main() { ctx, stop := 
signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - tracker, cleanSocketTracker, err := socket.LoadEBPFTracker(*probeBinaryName) - if err != nil { - log.Warn("loading socket tracker failed, scaling down with static duration", "err", err) - cleanSocketTracker = func() error { return nil } - } - - if err := manager.AttachRedirectors(ctx, log, tracker); err != nil { + if err := manager.AttachRedirectors(ctx, log, *probeBinaryName); err != nil { log.Warn("attaching redirectors failed: restoring containers on traffic is disabled", "err", err) } @@ -167,7 +160,6 @@ func main() { <-ctx.Done() log.Info("stopping manager") - cleanSocketTracker() if err := server.Shutdown(ctx); err != nil { log.Error("shutting down server", "err", err) } diff --git a/e2e/kind.yaml b/e2e/kind.yaml index 4d258ae..3d6333f 100644 --- a/e2e/kind.yaml +++ b/e2e/kind.yaml @@ -9,15 +9,9 @@ nodes: image: kindest/node:v1.34.0@sha256:7416a61b42b1662ca6ca89f02028ac133a309a2a30ba309614e8ec94d976dc5a - role: worker image: kindest/node:v1.34.0@sha256:7416a61b42b1662ca6ca89f02028ac133a309a2a30ba309614e8ec94d976dc5a - extraMounts: - - hostPath: /proc - containerPath: /host/proc labels: zeropod.ctrox.dev/node: "true" - role: worker image: kindest/node:v1.34.0@sha256:7416a61b42b1662ca6ca89f02028ac133a309a2a30ba309614e8ec94d976dc5a - extraMounts: - - hostPath: /proc - containerPath: /host/proc labels: zeropod.ctrox.dev/node: "true" diff --git a/e2e/migration_test.go b/e2e/migration_test.go index e94d855..e0c09f0 100644 --- a/e2e/migration_test.go +++ b/e2e/migration_test.go @@ -217,15 +217,12 @@ func readPodDataEventually(t testing.TB, pod *corev1.Pod) (string, error) { } func defaultBeforeMigration(t *testing.T) { - assert.Eventually(t, func() bool { - if err := freezerWrite(t.Name(), e2e.port); err != nil { - return false - } + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.NoError(c, freezerWrite(t.Name(), e2e.port)) f, err := freezerRead(e2e.port) - if err 
!= nil { - return false + if assert.NoError(c, err) { + assert.Equal(c, t.Name(), f.Data) } - return t.Name() == f.Data }, time.Second*10, time.Second) } diff --git a/e2e/setup_test.go b/e2e/setup_test.go index e9347ef..a0cdce5 100644 --- a/e2e/setup_test.go +++ b/e2e/setup_test.go @@ -165,10 +165,6 @@ func startKind(t testing.TB, name, kubeconfig string, port int) (c *rest.Config, // setup mounts for ebpf tcp tracking (it needs to map host pids to // container pids) extraMounts := []v1alpha4.Mount{ - { - HostPath: "/proc", - ContainerPath: "/host/proc", - }, { HostPath: activator.BPFFSPath, ContainerPath: activator.BPFFSPath, diff --git a/manager/redirector_attacher.go b/manager/redirector_attacher.go index f4a1c9b..e59eacc 100644 --- a/manager/redirector_attacher.go +++ b/manager/redirector_attacher.go @@ -14,15 +14,14 @@ import ( "github.com/containernetworking/plugins/pkg/ns" "github.com/ctrox/zeropod/activator" - "github.com/ctrox/zeropod/socket" "github.com/fsnotify/fsnotify" ) type Redirector struct { sync.Mutex - sandboxes map[int]sandbox - log *slog.Logger - tracker socket.Tracker + sandboxes map[int]sandbox + log *slog.Logger + probeBinaryName string } type sandbox struct { @@ -42,11 +41,11 @@ const ( // can be found it attaches the redirector BPF programs to the network // interfaces of the sandbox. The directories are expected to be created by // the zeropod shim on startup. 
-func AttachRedirectors(ctx context.Context, log *slog.Logger, tracker socket.Tracker) error { +func AttachRedirectors(ctx context.Context, log *slog.Logger, probeBinaryName string) error { r := &Redirector{ - sandboxes: make(map[int]sandbox), - log: log, - tracker: tracker, + sandboxes: make(map[int]sandbox), + log: log, + probeBinaryName: probeBinaryName, } if _, err := os.Stat(activator.MapsPath()); os.IsNotExist(err) { @@ -120,7 +119,7 @@ func (r *Redirector) watchForSandboxPids(ctx context.Context) error { r.Lock() if sb, ok := r.sandboxes[pid]; ok { r.log.Info("cleaning up redirector", "pid", pid) - if err := sb.Remove(r.tracker); err != nil { + if err := sb.Remove(); err != nil { r.log.Error("error cleaning up redirector", "err", err) } } @@ -135,7 +134,7 @@ func (r *Redirector) watchForSandboxPids(ctx context.Context) error { } func (r *Redirector) attachRedirector(pid int) error { - bpf, err := activator.InitBPF(pid, r.log) + bpf, err := activator.InitBPF(pid, r.log, r.probeBinaryName) if err != nil { return fmt.Errorf("unable to initialize BPF: %w", err) } @@ -165,11 +164,6 @@ func (r *Redirector) attachRedirector(pid int) error { r.sandboxes[pid] = sandbox{activator: bpf, ips: sandboxIPs} r.Unlock() - for _, ip := range sandboxIPs { - if err := r.trackSandboxIP(ip); err != nil { - return fmt.Errorf("tracking sandbox IP: %w", err) - } - } return nil } @@ -249,29 +243,13 @@ func getSandboxIPs(ifaceName string) ([]netip.Addr, error) { return ips, nil } -// trackSandboxIP passes the pod/sandbox IP to the tracker so it can ignore -// kubelet probes going to this pod. 
-func (r *Redirector) trackSandboxIP(ip netip.Addr) error { - if r.tracker == nil { - return nil - } - r.log.Info("tracking sandbox IP", "addr", ip.String()) - if err := r.tracker.PutPodIP(ip); err != nil { - return fmt.Errorf("putting pod IP in tracker map: %w", err) - } - return nil -} - func ignoredDir(dir string) bool { - return dir == socket.TCPEventsMap || dir == socket.PodKubeletAddrsMapv4 || dir == socket.PodKubeletAddrsMapv6 + return dir == activator.SocketTrackerMap || + dir == activator.PodKubeletAddrsMapv4 || + dir == activator.PodKubeletAddrsMapv6 } -func (sb sandbox) Remove(tracker socket.Tracker) error { +func (sb sandbox) Remove() error { errs := []error{sb.activator.Cleanup()} - if tracker != nil { - for _, ip := range sb.ips { - errs = append(errs, tracker.RemovePodIP(ip)) - } - } return errors.Join(errs...) } diff --git a/shim/checkpoint.go b/shim/checkpoint.go index 50164e8..ba6c15d 100644 --- a/shim/checkpoint.go +++ b/shim/checkpoint.go @@ -2,7 +2,6 @@ package shim import ( "context" - "errors" "fmt" "os" "path" @@ -13,41 +12,19 @@ import ( "github.com/containerd/containerd/v2/cmd/containerd-shim-runc-v2/process" runcC "github.com/containerd/go-runc" "github.com/containerd/log" - "github.com/ctrox/zeropod/activator" nodev1 "github.com/ctrox/zeropod/api/node/v1" v1 "github.com/ctrox/zeropod/api/shim/v1" ) func (c *Container) scaleDown(ctx context.Context) error { - if err := c.startActivator(ctx); err != nil { - if errors.Is(err, errNoPortsDetected) { - retryIn := c.scaleDownRetry() - log.G(ctx).Infof("no ports detected, rescheduling scale down in %s", retryIn) - return c.scheduleScaleDownIn(retryIn) - } - - if errors.Is(err, activator.ErrMapNotFound) { - retryIn := c.scaleDownRetry() - log.G(ctx).Infof("activator is not ready, rescheduling scale down in %s", retryIn) - return c.scheduleScaleDownIn(retryIn) - } - - return err + if c.ScaledDown() { + return nil } if err := c.activator.Reset(); err != nil { return err } - if err := 
c.tracker.RemovePid(uint32(c.process.Pid())); err != nil { - // key could not exist, just log the error for now - log.G(ctx).Warnf("unable to remove pid %d: %s", c.process.Pid(), err) - } - - if c.ScaledDown() { - return nil - } - if c.cfg.DisableCheckpointing { if err := c.kill(ctx); err != nil { return err @@ -62,22 +39,6 @@ func (c *Container) scaleDown(ctx context.Context) error { return nil } -// scaleDownRetry returns the duration in which the next scaledown should be -// retried. It backs off exponentially with an initial wait of 1 second. -func (c *Container) scaleDownRetry() time.Duration { - const initial, max = time.Second, time.Minute * 5 - c.scaleDownBackoff = c.scaleDownBackoff * 2 - if c.scaleDownBackoff >= max { - c.scaleDownBackoff = max - } - - if c.scaleDownBackoff == 0 { - c.scaleDownBackoff = initial - } - - return c.scaleDownBackoff -} - func (c *Container) kill(ctx context.Context) error { c.checkpointRestore.Lock() defer c.checkpointRestore.Unlock() diff --git a/shim/container.go b/shim/container.go index a7cc034..95e0d54 100644 --- a/shim/container.go +++ b/shim/container.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "slices" "sync" "time" @@ -18,7 +19,6 @@ import ( "github.com/containernetworking/plugins/pkg/ns" "github.com/ctrox/zeropod/activator" v1 "github.com/ctrox/zeropod/api/shim/v1" - "github.com/ctrox/zeropod/socket" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/durationpb" @@ -43,9 +43,9 @@ type Container struct { skipStart bool netNS ns.NetNS scaleDownTimer *time.Timer - scaleDownBackoff time.Duration + initTimer *time.Timer + initBackoff time.Duration platform stdio.Platform - tracker socket.Tracker preRestore func() HandleStartedFunc postRestore func(*runc.Container, HandleStartedFunc) events chan *v1.ContainerStatus @@ -111,17 +111,7 @@ func (c *Container) Register(ctx context.Context, container *runc.Container) err c.process = p c.initialProcess = 
p - tracker, err := socket.NewEBPFTracker() - if err != nil { - log.G(ctx).Warnf("creating ebpf tracker failed, falling back to noop tracker: %s", err) - tracker = socket.NewNoopTracker(c.cfg.ScaleDownDuration) - } - c.tracker = tracker - - if err := tracker.TrackPid(uint32(p.Pid())); err != nil { - log.G(ctx).Warnf("tracking pid failed: %s", err) - } - if err := c.initActivator(ctx); err != nil { + if err := c.initActivator(ctx, c.SkipStart()); err != nil { log.G(ctx).Warnf("activator init failed, disabling scale down: %s", err) c.cfg.ScaleDownDuration = 0 } @@ -155,8 +145,13 @@ func (c *Container) scheduleScaleDownIn(in time.Duration) error { log.G(c.context).Infof("scheduling scale down in %s", in) timer := time.AfterFunc(in, func() { - last, err := c.tracker.LastActivity(uint32(c.process.Pid())) - if errors.Is(err, socket.NoActivityRecordedErr{}) { + if !c.activator.Started() { + log.G(c.context).Infof("activator not ready, delaying scale down by %s", c.initBackoff) + c.scaleDownTimer.Reset(c.initBackoff) + return + } + last, err := c.lastActivity() + if errors.Is(err, activator.NoActivityRecordedErr{}) { log.G(c.context).Info(err) } else if err != nil { log.G(c.context).Warnf("unable to get last TCP activity from tracker: %s", err) @@ -300,15 +295,8 @@ func (c *Container) DeleteCheckpointedPID(pid int) { } func (c *Container) Stop(ctx context.Context) { + c.cancelInit() c.CancelScaleDown() - if c.tracker != nil { - if err := c.tracker.RemovePid(uint32(c.process.Pid())); err != nil { - log.G(ctx).Warnf("unable to remove pid from tracker: %s", err) - } - if err := c.tracker.Close(); err != nil { - log.G(ctx).Warnf("unable to close tracker: %s", err) - } - } status := c.Status() status.Phase = v1.ContainerPhase_STOPPING c.sendEvent(status) @@ -327,27 +315,15 @@ func (c *Container) RegisterPostRestore(f func(*runc.Container, HandleStartedFun c.postRestore = f } -var errNoPortsDetected = errors.New("no listening ports detected") - -func (c *Container) 
initActivator(ctx context.Context) error { - // we already have an activator - if c.activator != nil { - return nil - } - - srv, err := activator.NewServer(ctx, c.netNS) - if err != nil { - return err - } - c.activator = srv - - return nil -} +func (c *Container) initActivator(ctx context.Context, enableRedirects bool) error { + c.cancelInit() -// startActivator starts the activator -func (c *Container) startActivator(ctx context.Context) error { - if c.activator.Started() { - return nil + if c.activator == nil { + act, err := activator.NewServer(ctx, c.netNS) + if err != nil { + return err + } + c.activator = act } if len(c.cfg.Ports) == 0 { @@ -356,22 +332,67 @@ func (c *Container) startActivator(ctx context.Context) error { ports, err := listeningPortsDeep(c.initialProcess.Pid()) if err != nil || len(ports) == 0 { // our initialProcess might not even be running yet, so finding the listening - // ports might fail in various ways. We return errNoPortsDetected so the - // caller can retry later. - return errNoPortsDetected + // ports might fail in various ways. We schedule a retry. + retryIn := c.initRetry() + log.G(ctx).Infof("no ports detected, retrying init in %s", retryIn) + c.retryInitIn(retryIn, enableRedirects) + return nil } c.cfg.Ports = ports } log.G(ctx).Infof("starting activator with ports: %v", c.cfg.Ports) + if err := c.startActivator(ctx, c.cfg.Ports...); err != nil { + if errors.Is(err, activator.ErrMapNotFound) { + c.retryInitIn(c.initRetry(), enableRedirects) + return nil + } + return err + } - // create a new context in order to not run into deadline of parent context - ctx = log.WithLogger(context.Background(), log.G(ctx).WithField("runtime", RuntimeName)) + if enableRedirects { + return c.activator.Reset() + } + return nil +} + +// initRetry returns the duration in which the next init should be retried. It +// backs off exponentially with an initial wait of 100 milliseconds. 
+func (c *Container) initRetry() time.Duration { + const initial, max = time.Millisecond * 100, time.Minute * 5 + c.initBackoff = min(max, c.initBackoff*2) + + if c.initBackoff == 0 { + c.initBackoff = initial + } + + return c.initBackoff +} - log.G(ctx).Infof("starting activator with config: %v", c.cfg) +func (c *Container) retryInitIn(in time.Duration, enableRedirects bool) { + log.G(c.context).Infof("scheduling init in %s", in) + timer := time.AfterFunc(in, func() { + if err := c.initActivator(c.context, enableRedirects); err != nil { + log.G(c.context).Warnf("error initializing activator: %s", err) + } + }) + c.initTimer = timer +} + +func (c *Container) cancelInit() { + if c.initTimer == nil { + return + } + c.initTimer.Stop() +} - if err := c.activator.Start(ctx, c.cfg.Ports, c.detectProbe(ctx), c.restoreHandler(ctx)); err != nil { +// startActivator starts the activator +func (c *Container) startActivator(ctx context.Context, ports ...uint16) error { + if c.activator.Started() { + return nil + } + if err := c.activator.Start(c.context, c.detectProbe(c.context), c.restoreHandler(c.context), ports...); err != nil { if errors.Is(err, activator.ErrMapNotFound) { return err } @@ -379,7 +400,6 @@ func (c *Container) startActivator(ctx context.Context) error { log.G(ctx).Errorf("failed to start activator: %s", err) return err } - log.G(ctx).Printf("activator started") return nil } @@ -388,7 +408,7 @@ func (c *Container) restoreHandler(ctx context.Context) activator.RestoreHook { return func() error { log.G(ctx).Printf("got a request") - restoredContainer, p, err := c.Restore(ctx) + restoredContainer, _, err := c.Restore(ctx) if err != nil { if errors.Is(err, ErrAlreadyRestored) { log.G(ctx).Info("container is already restored, ignoring request") @@ -401,14 +421,31 @@ func (c *Container) restoreHandler(ctx context.Context) activator.RestoreHook { } c.Container = restoredContainer - if err := c.tracker.TrackPid(uint32(p.Pid())); err != nil { - 
log.G(ctx).Warnf("unable to track pid %d: %s", p.Pid(), err) - } - return c.ScheduleScaleDown() } } +// lastActivity returns a [time.Time] of the last recorded network activity on +// any port of the container. +func (c *Container) lastActivity() (time.Time, error) { + if c.activator == nil { + return time.Time{}, activator.NoActivityRecordedErr{} + } + act := []time.Time{} + for _, port := range c.cfg.Ports { + last, err := c.activator.LastActivity(port) + if err != nil { + return time.Time{}, err + } + act = append(act, last) + } + if len(act) == 0 { + return time.Time{}, activator.NoActivityRecordedErr{} + } + slices.SortFunc(act, func(a, b time.Time) int { return a.Compare(b) }) + return act[len(act)-1], nil +} + func (c *Container) GetMetrics() *v1.ContainerMetrics { m := proto.Clone(c.metrics) c.clearMetrics() diff --git a/socket/bpf_arm64_bpfel.go b/socket/bpf_arm64_bpfel.go deleted file mode 100644 index 4e1a282..0000000 --- a/socket/bpf_arm64_bpfel.go +++ /dev/null @@ -1,150 +0,0 @@ -// Code generated by bpf2go; DO NOT EDIT. -//go:build arm64 - -package socket - -import ( - "bytes" - _ "embed" - "fmt" - "io" - "structs" - - "github.com/cilium/ebpf" -) - -type bpfIpv6Addr struct { - _ structs.HostLayout - U6Addr8 [16]uint8 -} - -// loadBpf returns the embedded CollectionSpec for bpf. -func loadBpf() (*ebpf.CollectionSpec, error) { - reader := bytes.NewReader(_BpfBytes) - spec, err := ebpf.LoadCollectionSpecFromReader(reader) - if err != nil { - return nil, fmt.Errorf("can't load bpf: %w", err) - } - - return spec, err -} - -// loadBpfObjects loads bpf and converts it into a struct. -// -// The following types are suitable as obj argument: -// -// *bpfObjects -// *bpfPrograms -// *bpfMaps -// -// See ebpf.CollectionSpec.LoadAndAssign documentation for details. 
-func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error { - spec, err := loadBpf() - if err != nil { - return err - } - - return spec.LoadAndAssign(obj, opts) -} - -// bpfSpecs contains maps and programs before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfSpecs struct { - bpfProgramSpecs - bpfMapSpecs - bpfVariableSpecs -} - -// bpfProgramSpecs contains programs before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfProgramSpecs struct { - KretprobeInetCskAccept *ebpf.ProgramSpec `ebpf:"kretprobe__inet_csk_accept"` - TcpRcvStateProcess *ebpf.ProgramSpec `ebpf:"tcp_rcv_state_process"` -} - -// bpfMapSpecs contains maps before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfMapSpecs struct { - PodKubeletAddrsV4 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v4"` - PodKubeletAddrsV6 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v6"` - TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` -} - -// bpfVariableSpecs contains global variables before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfVariableSpecs struct { - ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` -} - -// bpfObjects contains all objects after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfObjects struct { - bpfPrograms - bpfMaps - bpfVariables -} - -func (o *bpfObjects) Close() error { - return _BpfClose( - &o.bpfPrograms, - &o.bpfMaps, - ) -} - -// bpfMaps contains all maps after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. 
-type bpfMaps struct { - PodKubeletAddrsV4 *ebpf.Map `ebpf:"pod_kubelet_addrs_v4"` - PodKubeletAddrsV6 *ebpf.Map `ebpf:"pod_kubelet_addrs_v6"` - TcpEvents *ebpf.Map `ebpf:"tcp_events"` -} - -func (m *bpfMaps) Close() error { - return _BpfClose( - m.PodKubeletAddrsV4, - m.PodKubeletAddrsV6, - m.TcpEvents, - ) -} - -// bpfVariables contains all global variables after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfVariables struct { - ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` -} - -// bpfPrograms contains all programs after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfPrograms struct { - KretprobeInetCskAccept *ebpf.Program `ebpf:"kretprobe__inet_csk_accept"` - TcpRcvStateProcess *ebpf.Program `ebpf:"tcp_rcv_state_process"` -} - -func (p *bpfPrograms) Close() error { - return _BpfClose( - p.KretprobeInetCskAccept, - p.TcpRcvStateProcess, - ) -} - -func _BpfClose(closers ...io.Closer) error { - for _, closer := range closers { - if err := closer.Close(); err != nil { - return err - } - } - return nil -} - -// Do not access this directly. -// -//go:embed bpf_arm64_bpfel.o -var _BpfBytes []byte diff --git a/socket/bpf_arm64_bpfel.o b/socket/bpf_arm64_bpfel.o deleted file mode 100644 index 363e07a..0000000 Binary files a/socket/bpf_arm64_bpfel.o and /dev/null differ diff --git a/socket/bpf_x86_bpfel.go b/socket/bpf_x86_bpfel.go deleted file mode 100644 index 84171b0..0000000 --- a/socket/bpf_x86_bpfel.go +++ /dev/null @@ -1,150 +0,0 @@ -// Code generated by bpf2go; DO NOT EDIT. -//go:build 386 || amd64 - -package socket - -import ( - "bytes" - _ "embed" - "fmt" - "io" - "structs" - - "github.com/cilium/ebpf" -) - -type bpfIpv6Addr struct { - _ structs.HostLayout - U6Addr8 [16]uint8 -} - -// loadBpf returns the embedded CollectionSpec for bpf. 
-func loadBpf() (*ebpf.CollectionSpec, error) { - reader := bytes.NewReader(_BpfBytes) - spec, err := ebpf.LoadCollectionSpecFromReader(reader) - if err != nil { - return nil, fmt.Errorf("can't load bpf: %w", err) - } - - return spec, err -} - -// loadBpfObjects loads bpf and converts it into a struct. -// -// The following types are suitable as obj argument: -// -// *bpfObjects -// *bpfPrograms -// *bpfMaps -// -// See ebpf.CollectionSpec.LoadAndAssign documentation for details. -func loadBpfObjects(obj interface{}, opts *ebpf.CollectionOptions) error { - spec, err := loadBpf() - if err != nil { - return err - } - - return spec.LoadAndAssign(obj, opts) -} - -// bpfSpecs contains maps and programs before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfSpecs struct { - bpfProgramSpecs - bpfMapSpecs - bpfVariableSpecs -} - -// bpfProgramSpecs contains programs before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfProgramSpecs struct { - KretprobeInetCskAccept *ebpf.ProgramSpec `ebpf:"kretprobe__inet_csk_accept"` - TcpRcvStateProcess *ebpf.ProgramSpec `ebpf:"tcp_rcv_state_process"` -} - -// bpfMapSpecs contains maps before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfMapSpecs struct { - PodKubeletAddrsV4 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v4"` - PodKubeletAddrsV6 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v6"` - TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` -} - -// bpfVariableSpecs contains global variables before they are loaded into the kernel. -// -// It can be passed ebpf.CollectionSpec.Assign. -type bpfVariableSpecs struct { - ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` -} - -// bpfObjects contains all objects after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. 
-type bpfObjects struct { - bpfPrograms - bpfMaps - bpfVariables -} - -func (o *bpfObjects) Close() error { - return _BpfClose( - &o.bpfPrograms, - &o.bpfMaps, - ) -} - -// bpfMaps contains all maps after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfMaps struct { - PodKubeletAddrsV4 *ebpf.Map `ebpf:"pod_kubelet_addrs_v4"` - PodKubeletAddrsV6 *ebpf.Map `ebpf:"pod_kubelet_addrs_v6"` - TcpEvents *ebpf.Map `ebpf:"tcp_events"` -} - -func (m *bpfMaps) Close() error { - return _BpfClose( - m.PodKubeletAddrsV4, - m.PodKubeletAddrsV6, - m.TcpEvents, - ) -} - -// bpfVariables contains all global variables after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfVariables struct { - ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` -} - -// bpfPrograms contains all programs after they have been loaded into the kernel. -// -// It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. -type bpfPrograms struct { - KretprobeInetCskAccept *ebpf.Program `ebpf:"kretprobe__inet_csk_accept"` - TcpRcvStateProcess *ebpf.Program `ebpf:"tcp_rcv_state_process"` -} - -func (p *bpfPrograms) Close() error { - return _BpfClose( - p.KretprobeInetCskAccept, - p.TcpRcvStateProcess, - ) -} - -func _BpfClose(closers ...io.Closer) error { - for _, closer := range closers { - if err := closer.Close(); err != nil { - return err - } - } - return nil -} - -// Do not access this directly. 
-// -//go:embed bpf_x86_bpfel.o -var _BpfBytes []byte diff --git a/socket/bpf_x86_bpfel.o b/socket/bpf_x86_bpfel.o deleted file mode 100644 index 01e6ed6..0000000 Binary files a/socket/bpf_x86_bpfel.o and /dev/null differ diff --git a/socket/ebpf.go b/socket/ebpf.go deleted file mode 100644 index 4ed3747..0000000 --- a/socket/ebpf.go +++ /dev/null @@ -1,281 +0,0 @@ -package socket - -import ( - "encoding/binary" - "errors" - "fmt" - "net/netip" - "os" - "path/filepath" - "time" - - "github.com/cilium/ebpf" - "github.com/cilium/ebpf/link" - "github.com/cilium/ebpf/rlimit" - "github.com/ctrox/zeropod/activator" - "golang.org/x/sys/unix" -) - -// $BPF_CLANG and $BPF_CFLAGS are set by the Makefile. -//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -target amd64,arm64 -cflags $BPF_CFLAGS bpf kprobe.c -- -I/headers - -const ( - TCPEventsMap = "tcp_events" - PodKubeletAddrsMapv4 = "pod_kubelet_addrs_v4" - PodKubeletAddrsMapv6 = "pod_kubelet_addrs_v6" -) - -var mapNames = []string{ - TCPEventsMap, - PodKubeletAddrsMapv4, - PodKubeletAddrsMapv6, -} - -const ( - probeBinaryNameVariable = "probe_binary_name" - probeBinaryNameMaxLength = 16 -) - -// LoadEBPFTracker loads the eBPF program and attaches the kretprobe to track -// connections system-wide. -func LoadEBPFTracker(probeBinaryName string) (Tracker, func() error, error) { - // Allow the current process to lock memory for eBPF resources. - if err := rlimit.RemoveMemlock(); err != nil { - return nil, nil, err - } - - pinPath := activator.MapsPath() - if err := os.MkdirAll(pinPath, os.ModePerm); err != nil { - return nil, nil, fmt.Errorf("failed to create bpf fs subpath: %w", err) - } - - // Load pre-compiled programs and maps into the kernel. 
- objs := bpfObjects{} - collectionOpts := &ebpf.CollectionOptions{ - Maps: ebpf.MapOptions{ - // Pin the map to the BPF filesystem and configure the - // library to automatically re-write it in the BPF - // program so it can be re-used if it already exists or - // create it if not. - PinPath: pinPath, - }, - } - - spec, err := loadBpf() - if err != nil { - return nil, nil, fmt.Errorf("loading bpf objects: %w", err) - } - - if len([]byte(probeBinaryName)) > probeBinaryNameMaxLength { - return nil, nil, fmt.Errorf( - "probe binary name %s is too long (%d bytes), max is %d bytes", - probeBinaryName, len([]byte(probeBinaryName)), probeBinaryNameMaxLength, - ) - } - binName := [probeBinaryNameMaxLength]byte{} - copy(binName[:], probeBinaryName[:]) - if err := spec.Variables[probeBinaryNameVariable].Set(binName); err != nil { - return nil, nil, fmt.Errorf("setting probe binary variable: %w", err) - } - - if err := spec.LoadAndAssign(&objs, collectionOpts); err != nil { - if !errors.Is(err, ebpf.ErrMapIncompatible) { - return nil, nil, fmt.Errorf("loading objects: %w", err) - } - // try to unpin the maps and load again - for _, mapName := range mapNames { - if err := os.Remove(filepath.Join(pinPath, mapName)); err != nil && !os.IsNotExist(err) { - return nil, nil, fmt.Errorf("removing map after incompatibility: %w", err) - } - } - if err := loadBpfObjects(&objs, collectionOpts); err != nil { - return nil, nil, fmt.Errorf("loading objects: %w", err) - } - } - - // in the past we used inet_sock_set_state here but we now use a - // kretprobe with inet_csk_accept as inet_sock_set_state is not giving us - // reliable PIDs. 
https://github.com/iovisor/bcc/issues/2304 - tracker, err := link.Kretprobe("inet_csk_accept", objs.KretprobeInetCskAccept, &link.KprobeOptions{}) - if err != nil { - return nil, nil, fmt.Errorf("linking kprobe: %w", err) - } - kubeletDetector, err := link.Kprobe("tcp_rcv_state_process", objs.TcpRcvStateProcess, &link.KprobeOptions{}) - if err != nil { - return nil, nil, fmt.Errorf("linking tcp rcv kprobe: %w", err) - } - - t, err := NewEBPFTracker() - return t, func() error { - errs := []error{} - if err := objs.Close(); err != nil { - errs = append(errs, err) - } - if err := kubeletDetector.Close(); err != nil { - errs = append(errs, err) - } - return errors.Join(append(errs, tracker.Close())...) - }, err -} - -// NewEBPFTracker returns a TCP connection tracker that will keep track of the -// last TCP accept of specific processes. It writes the results to an ebpf map -// keyed with the PID and the value contains the timestamp of the last -// observed accept. -func NewEBPFTracker() (Tracker, error) { - var resolver PIDResolver - resolver = noopResolver{} - // if hostProcPath exists, we're probably running in a test container. We - // will use the hostResolver instead of using the actual pids. 
- if _, err := os.Stat(hostProcPath); err == nil { - resolver = hostResolver{} - } - - podKubeletAddrsv4, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), PodKubeletAddrsMapv4), &ebpf.LoadPinOptions{}) - if err != nil { - return nil, err - } - podKubeletAddrsv6, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), PodKubeletAddrsMapv6), &ebpf.LoadPinOptions{}) - if err != nil { - return nil, err - } - tcpEvents, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), TCPEventsMap), &ebpf.LoadPinOptions{}) - return &EBPFTracker{ - PIDResolver: resolver, - tcpEvents: tcpEvents, - podKubeletAddrsv4: podKubeletAddrsv4, - podKubeletAddrsv6: podKubeletAddrsv6, - }, err -} - -// PIDResolver allows to customize how the PIDs of the connection tracker are -// resolved. This can be useful if the shim is already running in a container -// (e.g. when using Kind), so it can resolve the PID of the container to the -// ones of the host that ebpf sees. -type PIDResolver interface { - Resolve(pid uint32) uint32 -} - -// noopResolver does not resolve anything and just returns the actual pid. -type noopResolver struct{} - -func (p noopResolver) Resolve(pid uint32) uint32 { - return pid -} - -type NoActivityRecordedErr struct{} - -func (err NoActivityRecordedErr) Error() string { - return "no activity recorded" -} - -type EBPFTracker struct { - PIDResolver - tcpEvents *ebpf.Map - podKubeletAddrsv4 *ebpf.Map - podKubeletAddrsv6 *ebpf.Map -} - -// TrackPid puts the pid into the TcpEvents map meaning tcp events of the -// process belonging to that pid will be tracked. -func (c *EBPFTracker) TrackPid(pid uint32) error { - val := uint64(0) - pid = c.PIDResolver.Resolve(pid) - if err := c.tcpEvents.Put(&pid, &val); err != nil { - return fmt.Errorf("unable to put pid %d into bpf map: %w", pid, err) - } - - return nil -} - -// RemovePid removes the pid from the TcpEvents map. 
-func (c *EBPFTracker) RemovePid(pid uint32) error { - pid = c.PIDResolver.Resolve(pid) - return c.tcpEvents.Delete(&pid) -} - -// LastActivity returns a time.Time of the last tcp activity recorded of the -// process belonging to the pid (or a child-process of the pid). -func (c *EBPFTracker) LastActivity(pid uint32) (time.Time, error) { - var val uint64 - - pid = c.PIDResolver.Resolve(pid) - if err := c.tcpEvents.Lookup(&pid, &val); err != nil { - return time.Time{}, fmt.Errorf("looking up %d: %w", pid, err) - } - - if val == 0 { - return time.Time{}, NoActivityRecordedErr{} - } - - return convertBPFTime(val) -} - -func (c *EBPFTracker) Close() error { - return c.tcpEvents.Close() -} - -// RemovePodIPv4 adds the pod IP to the tracker unless it already exists. -func (c *EBPFTracker) PutPodIP(ip netip.Addr) error { - if ip.Is4() { - val := uint32(0) - ipv4 := ip.As4() - uIP := binary.NativeEndian.Uint32(ipv4[:]) - if err := c.podKubeletAddrsv4.Update(&uIP, &val, ebpf.UpdateNoExist); err != nil && - !errors.Is(err, ebpf.ErrKeyExist) { - return fmt.Errorf("unable to put ipv4 %s into bpf map: %w", ip, err) - } - return nil - } - - val := bpfIpv6Addr{} - bpfIP := bpfIpv6Addr{U6Addr8: ip.As16()} - if err := c.podKubeletAddrsv6.Update(&bpfIP, &val, ebpf.UpdateNoExist); err != nil && - !errors.Is(err, ebpf.ErrKeyExist) { - return fmt.Errorf("unable to put ipv6 %s into bpf map: %w", ip, err) - } - - return nil -} - -// RemovePodIPv4 removes the pod IP from the tracker. 
-func (c *EBPFTracker) RemovePodIP(ip netip.Addr) error { - if ip.Is4() { - ipv4 := ip.As4() - uIP := binary.NativeEndian.Uint32(ipv4[:]) - if err := c.podKubeletAddrsv4.Delete(&uIP); err != nil && !errors.Is(err, ebpf.ErrKeyNotExist) { - return fmt.Errorf("unable to delete ipv4 %s from bpf map: %w", ip, err) - } - return nil - } - - bpfIP := bpfIpv6Addr{U6Addr8: ip.As16()} - if err := c.podKubeletAddrsv6.Delete(&bpfIP); err != nil && !errors.Is(err, ebpf.ErrKeyNotExist) { - return fmt.Errorf("unable to delete ipv6 %s from bpf map: %w", ip, err) - } - return nil -} - -// convertBPFTime takes the value of bpf_ktime_get_ns and converts it to a -// time.Time. -func convertBPFTime(t uint64) (time.Time, error) { - b, err := getBootTimeNS() - if err != nil { - return time.Time{}, err - } - - return time.Now().Add(-time.Duration(b - int64(t))), nil -} - -// getKtimeNS returns the time elapsed since system boot, in nanoseconds. Does -// not include time the system was suspended. Basically the equivalent of -// bpf_ktime_get_ns. -func getBootTimeNS() (int64, error) { - var ts unix.Timespec - err := unix.ClockGettime(unix.CLOCK_MONOTONIC, &ts) - if err != nil { - return 0, fmt.Errorf("could not get time: %s", err) - } - - return unix.TimespecToNsec(ts), nil -} diff --git a/socket/host_resolver.go b/socket/host_resolver.go deleted file mode 100644 index 6c687e2..0000000 --- a/socket/host_resolver.go +++ /dev/null @@ -1,40 +0,0 @@ -package socket - -import ( - "fmt" - "os/exec" - "strconv" - "strings" -) - -const hostProcPath = "/host/proc/" - -// hostResolver uses the procfs of the host to resolve PIDs. With this the -// connection tracker can work when running in a container. As the ebpf -// program is not aware of the PID namespace that the processes running in, we -// need to find the PIDs of the host processes from the ones in the container. 
-type hostResolver struct{} - -func (h hostResolver) Resolve(pid uint32) uint32 { - p, err := findHostPid(hostProcPath, pid) - if err != nil { - return pid - } - - return p -} - -// findHostPid greps through the procfs to find the host pid of the supplied -// namespaced pid. It's very ugly but it works well enough for testing with -// Kind. It would be better to use the procfs package here but NSpid is always -// empty. -func findHostPid(procPath string, nsPid uint32) (uint32, error) { - out, err := exec.Command("bash", "-c", fmt.Sprintf(`grep -P 'NSpid:.*\t%d\t' -ril %s*/status | head -n 1`, nsPid, procPath)).Output() - if err != nil { - return 0, err - } - - strPid := strings.TrimSuffix(strings.TrimPrefix(string(out), procPath), "/status\n") - pid, err := strconv.ParseUint(strPid, 10, 32) - return uint32(pid), err -} diff --git a/socket/kprobe.c b/socket/kprobe.c deleted file mode 100644 index 52c1231..0000000 --- a/socket/kprobe.c +++ /dev/null @@ -1,167 +0,0 @@ -//go:build ignore - -#include "vmlinux.h" -#include "ptregs.h" -#include "bpf_helpers.h" -#include "bpf_tracing.h" -#include "bpf_core_read.h" -#include "bpf_endian.h" - -char __license[] SEC("license") = "Dual MIT/GPL"; - -#define AF_INET 2 - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1024); // should be enough pids? 
- __type(key, __u32); // pid - __type(value, __u64); // ktime ns of the last tracked event - __uint(pinning, LIBBPF_PIN_BY_NAME); -} tcp_events SEC(".maps"); - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1024); - __type(key, __be32); // pod addr - __type(value, __be32); // kubelet addr - __uint(pinning, LIBBPF_PIN_BY_NAME); -} pod_kubelet_addrs_v4 SEC(".maps"); - -struct ipv6_addr { - __u8 u6_addr8[16]; -}; - -struct { - __uint(type, BPF_MAP_TYPE_LRU_HASH); - __uint(max_entries, 1024); - __type(key, struct ipv6_addr); // pod addr - __type(value, struct ipv6_addr); // kubelet addr - __uint(pinning, LIBBPF_PIN_BY_NAME); -} pod_kubelet_addrs_v6 SEC(".maps"); - -const volatile char probe_binary_name[TASK_COMM_LEN] = ""; - -static __always_inline bool ipv6_addr_equal(const struct ipv6_addr *a1, const struct ipv6_addr *a2) { - #pragma unroll - for (int i = 0; i < 15; i++) { - if (a1->u6_addr8[i] != a2->u6_addr8[i]) { - return false; - } - } - return true; -} - -static __always_inline bool is_v4_mapped_v6(__u8 addr[16]) { - return (addr[0] == 0 && addr[1] == 0 && - addr[2] == 0 && addr[3] == 0 && - addr[4] == 0 && addr[5] == 0 && - addr[6] == 0 && addr[7] == 0 && - addr[8] == 0 && addr[9] == 0 && - addr[10] == 0xff && addr[11] == 0xff); -} - -static __always_inline __be32 extract_v4_from_v6(__u8 addr[16]) { - __be32 ipv4; - ipv4 = addr[15]; - ipv4 = (ipv4 << 8) + addr[14]; - ipv4 = (ipv4 << 8) + addr[13]; - ipv4 = (ipv4 << 8) + addr[12]; - return ipv4; -} - -SEC("kretprobe/inet_csk_accept") -int kretprobe__inet_csk_accept(struct pt_regs *ctx) { - struct task_struct* task = (struct task_struct*)bpf_get_current_task_btf(); - // we use the tgid as our pid as it represents the pid from userspace - __u32 pid = task->tgid; - - void *tcp_event = &tcp_events; - void *found_pid = bpf_map_lookup_elem(tcp_event, &pid); - - if (!found_pid) { - // try ppid, our process might have forks - pid = task->real_parent->tgid; - - void *found_ppid = 
bpf_map_lookup_elem(tcp_event, &pid); - if (!found_ppid) { - return 0; - } - } - - struct sock *sk = (struct sock *)PT_REGS_RC(ctx); - if (sk == NULL) { - return 0; - } - - if (BPF_CORE_READ(sk, __sk_common.skc_family) == AF_INET) { - __be32 pod_addr = 0; - __be32 daddr = 0; - BPF_CORE_READ_INTO(&pod_addr, sk, __sk_common.skc_rcv_saddr); - BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_daddr); - void *addrs = &pod_kubelet_addrs_v4; - __be32 *kubelet_addr = bpf_map_lookup_elem(addrs, &pod_addr); - if (kubelet_addr && *kubelet_addr == daddr) { - return 0; - } - } else { - struct ipv6_addr pod_addr; - struct ipv6_addr daddr; - BPF_CORE_READ_INTO(&pod_addr, sk, __sk_common.skc_v6_rcv_saddr.in6_u); - BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_v6_daddr.in6_u); - if (is_v4_mapped_v6(pod_addr.u6_addr8)) { - void *addrs = &pod_kubelet_addrs_v4; - __be32 pod_addr_v4 = extract_v4_from_v6(pod_addr.u6_addr8); - __be32 daddr_v4 = extract_v4_from_v6(daddr.u6_addr8); - __be32 *kubelet_addr = bpf_map_lookup_elem(addrs, &pod_addr_v4); - if (kubelet_addr && *kubelet_addr == daddr_v4) { - return 0; - } - } - void *addrs = &pod_kubelet_addrs_v6; - struct ipv6_addr *kubelet_addr = bpf_map_lookup_elem(addrs, &pod_addr); - if (kubelet_addr && ipv6_addr_equal(kubelet_addr, &daddr)) { - return 0; - } - } - - __u64 time = bpf_ktime_get_ns(); - return bpf_map_update_elem(tcp_event, &pid, &time, BPF_ANY); -}; - -static int find_potential_kubelet_ip(struct sock *sk) { - char comm[TASK_COMM_LEN]; - bpf_get_current_comm(&comm, sizeof(comm)); - if (bpf_strncmp(comm, TASK_COMM_LEN, (char *)probe_binary_name) == 0) { - if (BPF_CORE_READ(sk, __sk_common.skc_family) == AF_INET) { - __be32 saddr = 0; - __be32 daddr = 0; - BPF_CORE_READ_INTO(&saddr, sk, __sk_common.skc_rcv_saddr); - BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_daddr); - void *addrs = &pod_kubelet_addrs_v4; - __be32 *kubelet_addr = bpf_map_lookup_elem(addrs, &daddr); - if (kubelet_addr) { - bpf_map_update_elem(addrs, &daddr, 
&saddr, 0); - } - } else { - struct ipv6_addr saddr; - struct ipv6_addr daddr; - BPF_CORE_READ_INTO(&saddr, sk, __sk_common.skc_v6_rcv_saddr.in6_u); - BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_v6_daddr.in6_u); - void *addrs = &pod_kubelet_addrs_v6; - struct ipv6_addr *kubelet_addr = bpf_map_lookup_elem(addrs, &daddr); - if (kubelet_addr) { - bpf_map_update_elem(addrs, &daddr, &saddr, 0); - } - } - } - return 0; -} - -SEC("kprobe/tcp_rcv_state_process") -int BPF_KPROBE(tcp_rcv_state_process, struct sock *sk) { - if (BPF_CORE_READ(sk, __sk_common.skc_state) != TCP_SYN_SENT) { - return 0; - } - - return find_potential_kubelet_ip(sk); -} diff --git a/socket/noop.go b/socket/noop.go deleted file mode 100644 index de62a4c..0000000 --- a/socket/noop.go +++ /dev/null @@ -1,42 +0,0 @@ -package socket - -import ( - "net/netip" - "time" -) - -func NewNoopTracker(scaleDownDuration time.Duration) NoopTracker { - return NoopTracker{ - PIDResolver: noopResolver{}, - scaleDownDuration: scaleDownDuration, - } -} - -type NoopTracker struct { - PIDResolver - scaleDownDuration time.Duration -} - -func (n NoopTracker) TrackPid(pid uint32) error { - return nil -} - -func (n NoopTracker) RemovePid(pid uint32) error { - return nil -} - -func (n NoopTracker) LastActivity(pid uint32) (time.Time, error) { - return time.Now().Add(-n.scaleDownDuration), nil -} - -func (n NoopTracker) Close() error { - return nil -} - -func (n NoopTracker) PutPodIP(ip netip.Addr) error { - return nil -} - -func (n NoopTracker) RemovePodIP(ip netip.Addr) error { - return nil -} diff --git a/socket/ptregs.h b/socket/ptregs.h deleted file mode 100644 index 69ac655..0000000 --- a/socket/ptregs.h +++ /dev/null @@ -1,8 +0,0 @@ -#if defined(__TARGET_ARCH_arm64) -struct user_pt_regs { - __u64 regs[31]; - __u64 sp; - __u64 pc; - __u64 pstate; -}; -#endif diff --git a/socket/tracker.go b/socket/tracker.go deleted file mode 100644 index a4b2351..0000000 --- a/socket/tracker.go +++ /dev/null @@ -1,24 +0,0 @@ -package 
socket - -import ( - "net/netip" - "time" -) - -type Tracker interface { - PIDResolver - - // TrackPid starts connection tracking of the specified process. - TrackPid(pid uint32) error - // TrackPid stops connection tracking of the specified process. - RemovePid(pid uint32) error - // LastActivity returns the time of the last TCP activity of the specified process. - LastActivity(pid uint32) (time.Time, error) - // Close the activity tracker. - Close() error - // PutPodIP inserts a pod IP into the pod-to-kubelet map, helping with - // ignoring probes coming from kubelet within the tracker. - PutPodIP(ip netip.Addr) error - // RemovePodIP removes a pod IP from the pod-to-kubelet map. - RemovePodIP(ip netip.Addr) error -} diff --git a/socket/tracker_test.go b/socket/tracker_test.go deleted file mode 100644 index 181abec..0000000 --- a/socket/tracker_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package socket - -import ( - "fmt" - "net/http" - "net/http/httptest" - "net/netip" - "os" - "path/filepath" - "testing" - "time" - - "github.com/ctrox/zeropod/activator" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestEBPFTracker tests the ebpf tcp tracker by getting our own pid, starting -// an HTTP server and doing a request against it. This test requires elevated -// privileges to run. 
-func TestEBPFTracker(t *testing.T) { - require.NoError(t, activator.MountBPFFS(activator.BPFFSPath)) - - name, err := os.Executable() - require.NoError(t, err) - tracker, clean, err := LoadEBPFTracker(filepath.Base(name)) - require.NoError(t, err) - defer func() { require.NoError(t, clean()) }() - - pid := uint32(os.Getpid()) - require.NoError(t, tracker.TrackPid(pid)) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "ok") - })) - - for name, tc := range map[string]struct { - ip netip.Addr - expectLastActivity bool - }{ - "activity tracked": { - ip: netip.MustParseAddr("10.0.0.1"), - expectLastActivity: true, - }, - "activity ignored": { - // use 127.0.0.1 as that's where our test program connects from - ip: netip.MustParseAddr("127.0.0.1"), - expectLastActivity: false, - }, - } { - t.Run(name, func(t *testing.T) { - require.NoError(t, tracker.PutPodIP(tc.ip)) - defer func() { assert.NoError(t, tracker.RemovePodIP(tc.ip)) }() - - require.Eventually(t, func() bool { - _, err = http.Get(ts.URL) - return err == nil - }, time.Millisecond*100, time.Millisecond, "waiting for http server to reply") - - require.Eventually(t, func() bool { - activity, err := tracker.LastActivity(pid) - if err != nil { - return !tc.expectLastActivity - } - - if time.Since(activity) > time.Millisecond*100 { - if tc.expectLastActivity { - t.Fatalf("last activity was %s ago, expected it to be within the last 100ms", time.Since(activity)) - } - } - - return true - }, time.Millisecond*100, time.Millisecond, "waiting for last tcp activity") - time.Sleep(time.Millisecond * 200) - }) - } -}