diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2de50df..da9e044 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: [push] jobs: staticcheck: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 @@ -19,7 +19,10 @@ jobs: version: "2024.1" test: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-24.04, ubuntu-24.04-arm] steps: - uses: actions/checkout@v4 @@ -66,7 +69,10 @@ jobs: run: git diff --exit-code e2e: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-24.04, ubuntu-24.04-arm] steps: - uses: actions/checkout@v4 diff --git a/Makefile b/Makefile index e8c81be..b1f3a3f 100644 --- a/Makefile +++ b/Makefile @@ -60,13 +60,13 @@ push-dev: build-installer build-manager docker push $(MANAGER_IMAGE) test-e2e: - go test -v ./e2e/ + go test -timeout=30m -v ./e2e/ $(testargs) bench: go test -bench=. -benchtime=10x -v -run=Bench ./e2e/ test: - go test -v -short ./... + go test -v -short ./... $(testargs) # docker-e2e runs the e2e test in a docker container. However, as running the # e2e test requires a docker socket (for kind), this mounts the docker socket @@ -81,7 +81,7 @@ docker-bench: build-test # has to have SYS_ADMIN because the test tries to set netns and mount bpffs # we use --pid=host to make the ebpf tracker work without a pid resolver docker-test: - docker run --rm --cap-add=SYS_ADMIN --cap-add=NET_ADMIN --pid=host --userns=host -v $(PWD):/app $(TEST_IMAGE) make test + docker run --rm --cap-add=SYS_ADMIN --cap-add=NET_ADMIN --pid=host --userns=host -v $(PWD):/app $(TEST_IMAGE) go test -v -short ./... $(testargs) CLANG ?= clang CFLAGS := -O2 -g -Wall -Werror @@ -91,9 +91,11 @@ CFLAGS := -O2 -g -Wall -Werror # dependencies installed. generate: export BPF_CLANG := $(CLANG) generate: export BPF_CFLAGS := $(CFLAGS) -generate: ttrpc +generate: ttrpc ebpf go generate ./api/... - docker run --rm -v $(PWD):/app:Z --user $(shell id -u):$(shell id -g) --env=BPF_CLANG="$(CLANG)" --env=BPF_CFLAGS="$(CFLAGS)" $(EBPF_IMAGE) + +ebpf: + docker run --rm -v $(PWD):/app:Z --user $(shell id -u):$(shell id -g) --userns=host --env=BPF_CLANG="$(CLANG)" --env=BPF_CFLAGS="$(CFLAGS)" $(EBPF_IMAGE) ttrpc: go mod download diff --git a/README.md b/README.md index 6b5f3aa..39357a8 100644 --- a/README.md +++ b/README.md @@ -204,8 +204,7 @@ kubectl delete -k https://github.com/ctrox/zeropod/config/production ## Configuration A pod can make use of zeropod only if the `runtimeClassName` is set to -`zeropod`. Apart from that there are two annotations that are currently -required. See this minimal example of a pod: +`zeropod`. See this minimal example of a pod: ```yaml apiVersion: v1 @@ -223,8 +222,41 @@ spec: - containerPort: 80 ``` -Then there are also a few optional annotations that can be set on the pod to -tweak the behaviour of zeropod. +### Probes + +Zeropod is able to intercept liveness probes while the container process is +scaled down to ensure the application is not restored for probes. This just +works for HTTP and TCP probes, GRPC and exec probes will wake the container up. 
+ +```yaml +apiVersion: v1 +kind: Pod +metadata: + name: nginx + annotations: + zeropod.ctrox.dev/scaledown-duration: 10s +spec: + runtimeClassName: zeropod + containers: + - name: nginx + image: nginx + ports: + - containerPort: 80 + livenessProbe: + httpGet: + port: 80 +``` + +In this example, the container will be scaled down 10 seconds after starting +even though we have defined a probe. Zeropod will take care of replying to the +probe when the container is scaled down. Whenever the container is running, the +probe traffic will be forwarded to the app just like normal traffic. You can +also customize the path and the headers of the probe, just be mindful of the +size of those. To reduce memory usage, by default, zeropod will only read the +first `1024` bytes of each request to detect an HTTP probe. If the probe is +larger than that, traffic will just be passed through and the app will be +restored on each probe request. In that case, it can be increased with the +[probe buffer size](#zeropodctroxdevprobe-buffer-size) annotation. ### `zeropod.ctrox.dev/container-names` @@ -281,6 +313,20 @@ the application is stateless and super fast to startup. zeropod.ctrox.dev/disable-checkpointing: "true" ``` +### `zeropod.ctrox.dev/disable-probe-detection` + +Disables the probe detection mechanism. If there are probes defined on a +container, they will be forwarded to the container just like any traffic and +will wake it up. + +### `zeropod.ctrox.dev/probe-buffer-size` + +Configure the buffer size of the probe detector. To be able to detect an HTTP +liveness/readiness probe, zeropod needs to read a certain amount of bytes from +the TCP stream of incoming connections. This normally does not need to be +adjusted as the default should fit most probes and only needs to be increased in +case the probe contains lots of header data. Defaults to `1024` if unset. + ## Experimental features ### `zeropod.ctrox.dev/migrate` diff --git a/activator/activator.go b/activator/activator.go index b49129f..a310a4f 100644 --- a/activator/activator.go +++ b/activator/activator.go @@ -1,3 +1,8 @@ +// Package activator contains a userspace TCP proxy that listens on a random +// port and loads an eBPF program to intercept and redirect packets destined to +// the configured ports. The activator accepts the connection, calls onAccept, +// signals to disable the eBPF redirect and then proxies the initial data to the +// defined ports as soon as something is listening. 
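As a rough orientation for the hook split introduced below, here is a minimal sketch of how a caller might wire the two hooks into `Start`; `srv`, `ctx` and port `80` are assumptions for illustration, and the pass-through connection hook mirrors the default used in the tests:

```go
// Illustrative wiring of the activator hooks (names and port are assumed).
connHook := func(c net.Conn) (net.Conn, bool, error) {
	// cont=true tells the activator to keep going and proxy the connection.
	return c, true, nil
}
restoreHook := func() error {
	// Restore the scaled-down container; the activator proxies the initial
	// data once the workload is listening again.
	return nil
}
// srv is assumed to be a *Server previously created with NewServer(ctx, netNS).
if err := srv.Start(ctx, []uint16{80}, connHook, restoreHook); err != nil {
	log.G(ctx).Errorf("starting activator: %s", err)
}
```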
package activator import ( @@ -19,14 +24,13 @@ import ( "github.com/containernetworking/plugins/pkg/ns" ) -const () - type Server struct { listeners []net.Listener ports []uint16 - quit chan interface{} + quit chan any wg sync.WaitGroup - onAccept OnAccept + connHook ConnHook + restoreHook RestoreHook connectTimeout time.Duration proxyTimeout time.Duration proxyCancel context.CancelFunc @@ -34,13 +38,15 @@ type Server struct { maps bpfMaps sandboxPid int started bool + peekBufferSize int } -type OnAccept func() error +type ConnHook func(net.Conn) (conn net.Conn, cont bool, err error) +type RestoreHook func() error func NewServer(ctx context.Context, nn ns.NetNS) (*Server, error) { s := &Server{ - quit: make(chan interface{}), + quit: make(chan any), connectTimeout: time.Second * 5, proxyTimeout: time.Second * 5, ns: nn, @@ -66,15 +72,17 @@ func parsePidFromNetNS(nn ns.NetNS) int { var ErrMapNotFound = errors.New("bpf map could not be found") -func (s *Server) Start(ctx context.Context, ports []uint16, onAccept OnAccept) error { +func (s *Server) Start(ctx context.Context, ports []uint16, connHook ConnHook, restoreHook RestoreHook) error { s.ports = ports + s.connHook = connHook + s.restoreHook = restoreHook if err := s.loadPinnedMaps(); err != nil { return err } for _, port := range s.ports { - proxyPort, err := s.listen(ctx, port, onAccept) + proxyPort, err := s.listen(ctx, port) if err != nil { return err } @@ -111,14 +119,17 @@ func (s *Server) DisableRedirects() error { return nil } -func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (int, error) { +func (s *Server) SetPeekBufferSize(size int) { + s.peekBufferSize = size +} + +func (s *Server) listen(ctx context.Context, port uint16) (int, error) { // use a random free port for our proxy - addr := "0.0.0.0:0" cfg := net.ListenConfig{} var listener net.Listener if err := s.ns.Do(func(_ ns.NetNS) error { - l, err := cfg.Listen(ctx, "tcp4", addr) + l, err := cfg.Listen(ctx, "tcp", ":0") if err != nil { return fmt.Errorf("unable to listen: %w", err) } @@ -132,8 +143,6 @@ func (s *Server) listen(ctx context.Context, port uint16, onAccept OnAccept) (in log.G(ctx).Debugf("listening on %s in ns %s", listener.Addr(), s.ns.Path()) - s.onAccept = onAccept - s.wg.Add(1) go s.serve(ctx, listener, port) @@ -188,15 +197,23 @@ func (s *Server) serve(ctx context.Context, listener net.Listener, port uint16) wg.Add(1) go func() { log.G(ctx).Debug("accepting connection") - s.handleConection(ctx, conn, port) + s.handleConnection(ctx, conn, port) wg.Done() }() } } } -func (s *Server) handleConection(ctx context.Context, conn net.Conn, port uint16) { +func (s *Server) handleConnection(ctx context.Context, netConn net.Conn, port uint16) { + conn, cont, err := s.connHook(netConn) + if err != nil { + log.G(ctx).Errorf("connHook: %s", err) + return + } defer conn.Close() + if !cont { + return + } tcpAddr, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { @@ -204,14 +221,14 @@ func (s *Server) handleConection(ctx context.Context, conn net.Conn, port uint16 return } - log.G(ctx).Debugf("registering connection on remote port %d", tcpAddr.Port) + log.G(ctx).Debugf("registering connection on remote port %d from %s", tcpAddr.Port, tcpAddr.IP.String()) if err := s.registerConnection(uint16(tcpAddr.Port)); err != nil { log.G(ctx).Errorf("error registering connection: %s", err) return } - if err := s.onAccept(); err != nil { - log.G(ctx).Errorf("accept function: %s", err) + if err := s.restoreHook(); err != nil { + log.G(ctx).Errorf("restoreHook: 
%s", err) return } @@ -269,7 +286,7 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) { return err } - addr, err := net.ResolveTCPAddr("tcp4", fmt.Sprintf("127.0.0.1:%d", backendConnPort)) + addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("localhost:%d", backendConnPort)) if err != nil { return err } @@ -277,7 +294,7 @@ func (s *Server) connect(ctx context.Context, port uint16) (net.Conn, error) { LocalAddr: addr, Timeout: s.connectTimeout, } - backendConn, err = d.Dial("tcp4", fmt.Sprintf("localhost:%d", port)) + backendConn, err = d.Dial("tcp", fmt.Sprintf("localhost:%d", port)) return err }); err != nil { var serr syscall.Errno @@ -300,38 +317,38 @@ const ( ingressRedirectsMap = "ingress_redirects" ) -func (a *Server) loadPinnedMaps() error { +func (s *Server) loadPinnedMaps() error { // either all or none of the maps are pinned, so we want to return // ErrMapNotFound so it can be handled. - if _, err := os.Stat(filepath.Join(PinPath(a.sandboxPid), activeConnectionsMap)); os.IsNotExist(err) { + if _, err := os.Stat(filepath.Join(PinPath(s.sandboxPid), activeConnectionsMap)); os.IsNotExist(err) { return ErrMapNotFound } var err error opts := &ebpf.LoadPinOptions{} - if a.maps.ActiveConnections == nil { - a.maps.ActiveConnections, err = ebpf.LoadPinnedMap(a.mapPath(activeConnectionsMap), opts) + if s.maps.ActiveConnections == nil { + s.maps.ActiveConnections, err = ebpf.LoadPinnedMap(s.mapPath(activeConnectionsMap), opts) if err != nil { return err } } - if a.maps.DisableRedirect == nil { - a.maps.DisableRedirect, err = ebpf.LoadPinnedMap(a.mapPath(disableRedirectMap), opts) + if s.maps.DisableRedirect == nil { + s.maps.DisableRedirect, err = ebpf.LoadPinnedMap(s.mapPath(disableRedirectMap), opts) if err != nil { return err } } - if a.maps.EgressRedirects == nil { - a.maps.EgressRedirects, err = ebpf.LoadPinnedMap(a.mapPath(egressRedirectsMap), opts) + if s.maps.EgressRedirects == nil { + s.maps.EgressRedirects, err = ebpf.LoadPinnedMap(s.mapPath(egressRedirectsMap), opts) if err != nil { return err } } - if a.maps.IngressRedirects == nil { - a.maps.IngressRedirects, err = ebpf.LoadPinnedMap(a.mapPath(ingressRedirectsMap), opts) + if s.maps.IngressRedirects == nil { + s.maps.IngressRedirects, err = ebpf.LoadPinnedMap(s.mapPath(ingressRedirectsMap), opts) if err != nil { return err } @@ -340,44 +357,44 @@ func (a *Server) loadPinnedMaps() error { return nil } -func (a *Server) mapPath(name string) string { - return filepath.Join(PinPath(a.sandboxPid), name) +func (s *Server) mapPath(name string) string { + return filepath.Join(PinPath(s.sandboxPid), name) } // RedirectPort redirects the port from to on ingress and to from on egress. 
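To make the `from`/`to` direction concrete, a hypothetical call (port numbers are illustrative) and the map entries the two `Put` calls below produce:

```go
// RedirectPort(80, 43210) yields:
//   ingress_redirects: 80    -> 43210  (traffic for the app port is steered to the proxy)
//   egress_redirects:  43210 -> 80     (traffic leaving the proxy port is rewritten back)
```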
-func (a *Server) RedirectPort(from, to uint16) error { - if err := a.maps.IngressRedirects.Put(&from, &to); err != nil { +func (s *Server) RedirectPort(from, to uint16) error { + if err := s.maps.IngressRedirects.Put(&from, &to); err != nil { return fmt.Errorf("unable to put ports %d -> %d into bpf map: %w", from, to, err) } - if err := a.maps.EgressRedirects.Put(&to, &from); err != nil { + if err := s.maps.EgressRedirects.Put(&to, &from); err != nil { return fmt.Errorf("unable to put ports %d -> %d into bpf map: %w", to, from, err) } return nil } -func (a *Server) registerConnection(port uint16) error { - if err := a.maps.ActiveConnections.Put(&port, uint8(1)); err != nil { +func (s *Server) registerConnection(port uint16) error { + if err := s.maps.ActiveConnections.Put(&port, uint8(1)); err != nil { return fmt.Errorf("unable to put port %d into bpf map: %w", port, err) } return nil } -func (a *Server) removeConnection(port uint16) error { - if err := a.maps.ActiveConnections.Delete(&port); err != nil { +func (s *Server) removeConnection(port uint16) error { + if err := s.maps.ActiveConnections.Delete(&port); err != nil { return fmt.Errorf("unable to delete port %d in bpf map: %w", port, err) } return nil } -func (a *Server) disableRedirect(port uint16) error { - if err := a.maps.DisableRedirect.Put(&port, uint8(1)); err != nil { +func (s *Server) disableRedirect(port uint16) error { + if err := s.maps.DisableRedirect.Put(&port, uint8(1)); err != nil { return fmt.Errorf("unable to put %d into bpf map: %w", port, err) } return nil } -func (a *Server) enableRedirect(port uint16) error { - if err := a.maps.DisableRedirect.Delete(&port); err != nil { +func (s *Server) enableRedirect(port uint16) error { + if err := s.maps.DisableRedirect.Delete(&port); err != nil { if !errors.Is(err, ebpf.ErrKeyNotExist) { return err } @@ -415,7 +432,7 @@ func copy(done chan struct{}, errors chan error, dst io.Writer, src io.Reader) { } func freePort() (int, error) { - listener, err := net.Listen("tcp4", "127.0.0.1:0") + listener, err := net.Listen("tcp", ":0") if err != nil { return 0, err } diff --git a/activator/activator_test.go b/activator/activator_test.go index c4e7c9b..e3fa23f 100644 --- a/activator/activator_test.go +++ b/activator/activator_test.go @@ -25,10 +25,6 @@ func TestActivator(t *testing.T) { require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) - - port, err := freePort() - require.NoError(t, err) - s, err := NewServer(ctx, nn) require.NoError(t, err) @@ -36,61 +32,122 @@ func TestActivator(t *testing.T) { require.NoError(t, err) require.NoError(t, bpf.AttachRedirector("lo")) - response := "ok" - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, response) - })) - - once := sync.Once{} - err = s.Start(ctx, []uint16{uint16(port)}, func() error { - once.Do(func() { - // simulate a delay until our server is started - time.Sleep(time.Millisecond * 200) - l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) - require.NoError(t, err) - - if err := s.DisableRedirects(); err != nil { - t.Errorf("could not disable redirects: %s", err) - } - - // replace listener of server - ts.Listener.Close() - ts.Listener = l - ts.Start() - t.Logf("listening on :%d", port) - - t.Cleanup(func() { - ts.Close() - }) - }) - return nil - }) + port, err := freePort() require.NoError(t, err) + t.Cleanup(func() { s.Stop(ctx) cancel() }) - c := &http.Client{Timeout: time.Second} + c := &http.Client{ + Timeout: time.Second, + 
Transport: &http.Transport{ + DisableKeepAlives: true, + }, + } - parallelReqs := 10 + tests := map[string]struct { + parallelReqs int + connHook ConnHook + expectedBody string + expectedCode int + }{ + "no probe": { + parallelReqs: 1, + expectedBody: "ok", + expectedCode: http.StatusOK, + }, + "10 in parallel": { + parallelReqs: 10, + expectedBody: "ok", + expectedCode: http.StatusOK, + }, + "conn hook": { + parallelReqs: 1, + expectedBody: "", + connHook: func(conn net.Conn) (net.Conn, bool, error) { + resp := http.Response{ + StatusCode: http.StatusForbidden, + } + return conn, false, resp.Write(conn) + }, + expectedCode: http.StatusForbidden, + }, + } wg := sync.WaitGroup{} - for _, port := range []int{port} { - port := port - for i := 0; i < parallelReqs; i++ { - wg.Add(1) - go func() { - defer wg.Done() - resp, err := c.Get(fmt.Sprintf("http://localhost:%d", port)) - require.NoError(t, err) - b, err := io.ReadAll(resp.Body) - require.NoError(t, err) + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + startServer(t, ctx, s, port, tc.connHook) + for i := 0; i < tc.parallelReqs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost:%d", port), nil) + if !assert.NoError(t, err) { + return + } + + resp, err := c.Do(req) + if !assert.NoError(t, err) { + return + } + + b, err := io.ReadAll(resp.Body) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, tc.expectedCode, resp.StatusCode) + assert.Equal(t, tc.expectedBody, string(b)) + t.Log(string(b)) + }() + } + wg.Wait() + assert.NoError(t, s.Reset()) + }) + } +} - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, response, string(b)) - t.Log(string(b)) - }() +func startServer(t *testing.T, ctx context.Context, s *Server, port int, connHook ConnHook) { + response := "ok" + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, response) + })) + if connHook == nil { + connHook = func(c net.Conn) (net.Conn, bool, error) { + return c, true, nil } } - wg.Wait() + + once := sync.Once{} + err := s.Start( + ctx, []uint16{uint16(port)}, + connHook, + func() error { + once.Do(func() { + // simulate a delay until our server is started + time.Sleep(time.Millisecond * 200) + l, err := net.Listen("tcp4", fmt.Sprintf(":%d", port)) + require.NoError(t, err) + + if err := s.DisableRedirects(); err != nil { + t.Errorf("could not disable redirects: %s", err) + } + + // replace listener of server + ts.Listener.Close() + ts.Listener = l + ts.Start() + t.Logf("listening on :%d", port) + + t.Cleanup(func() { + l.Close() + ts.Close() + }) + }) + return nil + }, + ) + require.NoError(t, err) } diff --git a/activator/bpf_bpfeb.o b/activator/bpf_bpfeb.o index 9be16af..b2040c6 100644 Binary files a/activator/bpf_bpfeb.o and b/activator/bpf_bpfeb.o differ diff --git a/activator/bpf_bpfel.go b/activator/bpf_bpfel.go index f84f0e5..677eddc 100644 --- a/activator/bpf_bpfel.go +++ b/activator/bpf_bpfel.go @@ -1,5 +1,5 @@ // Code generated by bpf2go; DO NOT EDIT. 
-//go:build 386 || amd64 || arm || arm64 || loong64 || mips64le || mipsle || ppc64le || riscv64 +//go:build 386 || amd64 || arm || arm64 || loong64 || mips64le || mipsle || ppc64le || riscv64 || wasm package activator diff --git a/activator/bpf_bpfel.o b/activator/bpf_bpfel.o index cfe4f8c..14f01b9 100644 Binary files a/activator/bpf_bpfel.o and b/activator/bpf_bpfel.o differ diff --git a/activator/redirector.c b/activator/redirector.c index 75ff274..8fdb8cc 100644 --- a/activator/redirector.c +++ b/activator/redirector.c @@ -6,7 +6,10 @@ char __license[] SEC("license") = "Dual MIT/GPL"; -#define TC_ACT_OK 0 +#define TC_ACT_OK 0 +#define ETH_P_IP 0x0800 +#define ETH_P_IPV6 0x86DD +#define NEXTHDR_TCP 6 struct { __uint(type, BPF_MAP_TYPE_LRU_HASH); @@ -70,7 +73,7 @@ static __always_inline int ingress_redirect(struct tcphdr *tcp) { if (new_dest) { // check ports which should not be redirected if (disabled(sport_h, dport_h)) { - // if we can find an acive connection on the source port, we need + // if we can find an active connection on the source port, we need // to redirect regardless until the connection is closed. void *conn_sport = bpf_map_lookup_elem(active_connections_map, &sport_h); if (!conn_sport) { @@ -101,29 +104,69 @@ static __always_inline int egress_redirect(struct tcphdr *tcp) { return TC_ACT_OK; } -static __always_inline int parse_and_redirect(struct __sk_buff *ctx, bool ingress) { - void *data = (void *)(long)ctx->data; - void *data_end = (void *)(long)ctx->data_end; +static __always_inline struct ipv6hdr* ipv6_header(void *data, void *data_end) { struct ethhdr *eth = data; + struct ipv6hdr *ip6; + + if (data + sizeof(*eth) + sizeof(*ip6) > data_end) { + return NULL; + } + + if (bpf_ntohs(eth->h_proto) != ETH_P_IPV6) { + return NULL; + } + + ip6 = data + sizeof(*eth); + return ip6; +} + +static __always_inline struct iphdr* ipv4_header(void *data, void *data_end) { + struct ethhdr *eth = data; + struct iphdr *ip4; + + if (data + sizeof(*eth) + sizeof(*ip4) > data_end) { + return NULL; + } - if ((void*)eth + sizeof(*eth) <= data_end) { - struct iphdr *ip = data + sizeof(*eth); + if (bpf_ntohs(eth->h_proto) != ETH_P_IP) { + return NULL; + } - if ((void*)ip + sizeof(*ip) <= data_end) { - if (ip->protocol == IPPROTO_TCP) { - struct tcphdr *tcp = (void*)ip + sizeof(*ip); - if ((void*)tcp + sizeof(*tcp) <= data_end) { - if (ingress) { - return ingress_redirect(tcp); - } + ip4 = data + sizeof(*eth); + return ip4; +} - return egress_redirect(tcp); +static __always_inline int parse_and_redirect(struct __sk_buff *ctx, bool ingress) { + void *data_end = (void *)(long)ctx->data_end; + void *data = (void *)(long)ctx->data; + struct tcphdr *tcp = NULL; + + struct iphdr *ip4 = ipv4_header(data, data_end); + if (ip4) { + if ((void*)ip4 + sizeof(*ip4) <= data_end) { + if (ip4->protocol == IPPROTO_TCP) { + tcp = (void*)ip4 + sizeof(*ip4); + } + } + } else { + struct ipv6hdr *ip6 = ipv6_header(data, data_end); + if (ip6) { + if ((void*)ip6 + sizeof(*ip6) <= data_end) { + if (ip6->nexthdr == NEXTHDR_TCP) { + tcp = (void*)ip6 + sizeof(*ip6); } } } } - return 0; + if ((tcp != NULL) && ((void*)tcp + sizeof(*tcp) <= data_end)) { + if (ingress) { + return ingress_redirect(tcp); + } + return egress_redirect(tcp); + } + + return TC_ACT_OK; } diff --git a/cmd/freezer/Dockerfile b/cmd/freezer/Dockerfile index 6f668a5..46b1453 100644 --- a/cmd/freezer/Dockerfile +++ b/cmd/freezer/Dockerfile @@ -7,6 +7,7 @@ RUN go mod download COPY cmd/freezer cmd/freezer +ARG TARGETARCH RUN CGO_ENABLED=0 GOOS=linux 
GOARCH=$TARGETARCH GO111MODULE=on go build -ldflags "-s -w" -a -o freezer cmd/freezer/main.go FROM gcr.io/distroless/static-debian12 diff --git a/cmd/installer/main.go b/cmd/installer/main.go index 7fe10b9..bba81de 100644 --- a/cmd/installer/main.go +++ b/cmd/installer/main.go @@ -84,6 +84,8 @@ network-lock skip "zeropod.ctrox.dev/pre-dump", "zeropod.ctrox.dev/migrate", "zeropod.ctrox.dev/live-migrate", + "zeropod.ctrox.dev/disable-probe-detection", + "zeropod.ctrox.dev/probe-buffer-size", "io.containerd.runc.v2.group" ] @@ -104,6 +106,8 @@ network-lock skip "zeropod.ctrox.dev/pre-dump", "zeropod.ctrox.dev/migrate", "zeropod.ctrox.dev/live-migrate", + "zeropod.ctrox.dev/disable-probe-detection", + "zeropod.ctrox.dev/probe-buffer-size", "io.containerd.runc.v2.group" ] diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 4e5135e..e570cb5 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -11,6 +11,7 @@ import ( "os/signal" "syscall" + nodev1 "github.com/ctrox/zeropod/api/node/v1" v1 "github.com/ctrox/zeropod/api/runtime/v1" "github.com/ctrox/zeropod/manager" "github.com/ctrox/zeropod/manager/node" @@ -18,7 +19,10 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/config" ctrlmanager "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/server" @@ -30,7 +34,8 @@ var ( debug = flag.Bool("debug", false, "enable debug logs") inPlaceScaling = flag.Bool("in-place-scaling", false, "enable in-place resource scaling, requires InPlacePodVerticalScaling feature flag") - statusLabels = flag.Bool("status-labels", false, "update pod labels to reflect container status") + statusLabels = flag.Bool("status-labels", false, "update pod labels to reflect container status") + probeBinaryName = flag.String("probe-binary-name", "kubelet", "set the probe binary name for probe detection") ) func main() { @@ -46,16 +51,16 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - if err := manager.AttachRedirectors(ctx, log); err != nil { - log.Warn("attaching redirectors failed: restoring containers on traffic is disabled", "err", err) - } - - cleanSocketTracker, err := socket.LoadEBPFTracker() + tracker, cleanSocketTracker, err := socket.LoadEBPFTracker(*probeBinaryName) if err != nil { log.Warn("loading socket tracker failed, scaling down with static duration", "err", err) cleanSocketTracker = func() error { return nil } } + if err := manager.AttachRedirectors(ctx, log, tracker); err != nil { + log.Warn("attaching redirectors failed: restoring containers on traffic is disabled", "err", err) + } + mgr, err := newControllerManager() if err != nil { log.Error("creating controller manager", "err", err) @@ -138,8 +143,24 @@ func newControllerManager() (ctrlmanager.Manager, error) { if err := v1.AddToScheme(scheme); err != nil { return nil, err } + nodeName, ok := os.LookupEnv(nodev1.NodeNameEnvKey) + if !ok { + return nil, fmt.Errorf("could not find node name, env %s is not set", nodev1.NodeNameEnvKey) + } mgr, err := ctrlmanager.New(cfg, ctrlmanager.Options{ Scheme: scheme, Metrics: server.Options{BindAddress: "0"}, + Cache: cache.Options{ + ByObject: map[client.Object]cache.ByObject{ + // for pods we're only 
interested in objects that are running on + // the same node as the manager. This will reduce memory usage + // as we only keep a subset of all pods in the cache. + &corev1.Pod{}: cache.ByObject{ + Field: fields.SelectorFromSet(fields.Set{ + "spec.nodeName": nodeName, + }), + }, + }, + }, }) if err != nil { return nil, err diff --git a/config/examples/nginx.yaml b/config/examples/nginx.yaml index 62e4804..21e223f 100644 --- a/config/examples/nginx.yaml +++ b/config/examples/nginx.yaml @@ -12,7 +12,6 @@ spec: labels: app: nginx annotations: - io.containerd.runc.v2.group: "zeropod" zeropod.ctrox.dev/scaledown-duration: 10s spec: runtimeClassName: zeropod @@ -21,6 +20,10 @@ spec: name: nginx ports: - containerPort: 80 + livenessProbe: + periodSeconds: 1 + httpGet: + port: 80 resources: requests: cpu: 100m diff --git a/config/k3s/kustomization.yaml b/config/k3s/kustomization.yaml index cc07187..4173476 100644 --- a/config/k3s/kustomization.yaml +++ b/config/k3s/kustomization.yaml @@ -8,3 +8,9 @@ patches: value: -runtime=k3s target: kind: DaemonSet + - patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: -probe-binary-name=k3s + target: + kind: DaemonSet diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index aa3e237..3bb5f87 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -12,12 +12,14 @@ import ( v1 "github.com/ctrox/zeropod/api/shim/v1" "github.com/ctrox/zeropod/manager" + "github.com/ctrox/zeropod/shim" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" ) @@ -33,17 +35,18 @@ func TestE2E(t *testing.T) { } cases := map[string]struct { - pod *corev1.Pod - svc *corev1.Service - parallelReqs int - sequentialReqs int - sequentialWait time.Duration - maxReqDuration time.Duration - ignoreFirstReq bool - keepAlive bool - preDump bool - waitScaledDown bool - expectRunning bool + pod *corev1.Pod + svc *corev1.Service + parallelReqs int + sequentialReqs int + sequentialWait time.Duration + maxReqDuration time.Duration + ignoreFirstReq bool + keepAlive bool + preDump bool + waitScaledDown bool + expectRunning bool + expectScaledDown bool }{ // note: some of these max request durations are really // system-dependent. 
It has been tested on a few systems so far and @@ -130,6 +133,69 @@ func TestE2E(t *testing.T) { maxReqDuration: time.Second, waitScaledDown: true, }, + "pod with HTTP probe": { + pod: testPod( + scaleDownAfter(time.Second), + addContainer("nginx", "nginx", nil, 80), + livenessProbe(&corev1.Probe{ + InitialDelaySeconds: 5, + PeriodSeconds: 1, + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Port: intstr.FromInt(80), + }, + }, + }), + ), + parallelReqs: 0, + sequentialReqs: 0, + waitScaledDown: true, + expectRunning: false, + expectScaledDown: true, + }, + "pod with TCP probe": { + pod: testPod( + scaleDownAfter(time.Second), + addContainer("nginx", "nginx", nil, 80), + livenessProbe(&corev1.Probe{ + InitialDelaySeconds: 5, + PeriodSeconds: 1, + ProbeHandler: corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(80), + }, + }, + }), + ), + parallelReqs: 0, + sequentialReqs: 0, + waitScaledDown: true, + expectRunning: false, + expectScaledDown: true, + }, + "pod with large HTTP probe and increased buffer": { + pod: testPod( + scaleDownAfter(time.Second), + annotations(map[string]string{shim.ProbeBufferSizeAnnotationKey: "2048"}), + addContainer("nginx", "nginx", nil, 80), + livenessProbe(&corev1.Probe{ + InitialDelaySeconds: 3, + PeriodSeconds: 1, + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Port: intstr.FromInt(80), + // ensures probe request is bigger than 1024 bytes + Path: "/" + strings.Repeat("a", 1025), + }, + }, + }), + ), + parallelReqs: 0, + sequentialReqs: 0, + waitScaledDown: true, + expectRunning: false, + expectScaledDown: true, + }, } for name, tc := range cases { @@ -156,6 +222,10 @@ func TestE2E(t *testing.T) { alwaysRunningFor(t, ctx, e2e.client, tc.pod, time.Second*10) } + if tc.expectScaledDown { + alwaysScaledDownFor(t, ctx, e2e.client, tc.pod, time.Second*10) + } + wg := sync.WaitGroup{} wg.Add(tc.parallelReqs) for i := 0; i < tc.parallelReqs; i++ { @@ -269,6 +339,33 @@ func TestE2E(t *testing.T) { }, time.Minute, time.Second) }) + t.Run("socket tracker ignores probe", func(t *testing.T) { + pod := testPod( + scaleDownAfter(time.Second*5), + addContainer("nginx", "nginx", nil, 80), + livenessProbe(&corev1.Probe{ + PeriodSeconds: 1, + ProbeHandler: corev1.ProbeHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Port: intstr.FromInt(80), + }, + }, + }), + ) + cleanupPod := createPodAndWait(t, ctx, e2e.client, pod) + cleanupService := createServiceAndWait(t, ctx, e2e.client, testService(defaultTargetPort), 1) + defer cleanupPod() + defer cleanupService() + // we expect it to scale down even though a constant livenessProbe is + // hitting it + waitUntilScaledDown(t, ctx, e2e.client, pod) + // make a real request and expect it to scale down again + resp, err := c.Get(fmt.Sprintf("http://localhost:%d", e2e.port)) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + waitUntilScaledDown(t, ctx, e2e.client, pod) + }) + t.Run("metrics", func(t *testing.T) { // create two pods to test metric merging runningPod := testPod(scaleDownAfter(time.Hour)) diff --git a/e2e/migration_test.go b/e2e/migration_test.go index e900003..b65efe7 100644 --- a/e2e/migration_test.go +++ b/e2e/migration_test.go @@ -108,7 +108,7 @@ func TestMigration(t *testing.T) { t.Logf("migration phase: %s", migration.Status.Containers[0].Condition.Phase) return pods[0].Status.Phase == corev1.PodRunning && migration.Status.Containers[0].Condition.Phase == v1.MigrationPhaseCompleted - }, time.Minute, time.Second) + }, 
time.Minute*2, time.Second) waitForService(t, ctx, e2e.client, tc.svc, 1) } @@ -154,7 +154,9 @@ func defaultBeforeMigration(t *testing.T) { return false } f, err := freezerRead(e2e.port) - require.NoError(t, err) + if err != nil { + return false + } return t.Name() == f.Data }, time.Second*10, time.Second) } diff --git a/e2e/setup_test.go b/e2e/setup_test.go index 75f20aa..50a13b9 100644 --- a/e2e/setup_test.go +++ b/e2e/setup_test.go @@ -430,6 +430,14 @@ func resources(res corev1.ResourceRequirements) podOption { } } +func livenessProbe(probe *corev1.Probe) podOption { + return func(p *pod) { + for i := range p.spec.Containers { + p.spec.Containers[i].LivenessProbe = probe + } + } +} + const agnHostImage = "registry.k8s.io/e2e-test-images/agnhost:2.39" func agnContainer(name string, port int) podOption { @@ -782,7 +790,17 @@ func alwaysRunningFor(t testing.TB, ctx context.Context, c client.Client, pod *c require.Never(t, func() bool { ok, err := isRunning(ctx, c, pod, container.Name) t.Logf("running: %v: %s", ok, pod.GetLabels()[path.Join(manager.StatusLabelKeyPrefix, container.Name)]) - return err != nil && !ok + return err != nil || !ok + }, dur, time.Second) + } +} + +func alwaysScaledDownFor(t testing.TB, ctx context.Context, c client.Client, pod *corev1.Pod, dur time.Duration) { + for _, container := range pod.Spec.Containers { + require.Never(t, func() bool { + ok, err := isScaledDown(ctx, c, pod, container.Name) + t.Logf("scaled down: %v: %s", ok, pod.GetLabels()[path.Join(manager.StatusLabelKeyPrefix, container.Name)]) + return err != nil || !ok }, dur, time.Second) } } diff --git a/go.mod b/go.mod index d4ddb3f..c0d3e11 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/checkpoint-restore/go-criu/v7 v7.2.0 - github.com/cilium/ebpf v0.17.3 + github.com/cilium/ebpf v0.19.0 github.com/containerd/cgroups/v3 v3.0.3 github.com/containerd/containerd/api v1.8.0 github.com/containerd/containerd/v2 v2.0.4 diff --git a/go.sum b/go.sum index eadab13..188dd53 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/checkpoint-restore/go-criu/v7 v7.2.0/go.mod h1:u0LCWLg0w4yqqu14aXhiB4 github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/cilium/ebpf v0.17.3 h1:FnP4r16PWYSE4ux6zN+//jMcW4nMVRvuTLVTvCjyyjg= -github.com/cilium/ebpf v0.17.3/go.mod h1:G5EDHij8yiLzaqn0WjyfJHvRa+3aDlReIaLVRMvOyJk= +github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao= +github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/containerd/cgroups/v3 v3.0.3 h1:S5ByHZ/h9PMe5IOQoN7E+nMc2UcLEM/V48DGDJ9kip0= @@ -138,8 +138,8 @@ github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= -github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI= 
-github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= +github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s= +github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/gobuffalo/flect v1.0.3 h1:xeWBM2nui+qnVvNM4S3foBhCAL2XgPU+a7FdpelbTq4= diff --git a/manager/pod_controller.go b/manager/pod_controller.go index b7c382f..3661b0e 100644 --- a/manager/pod_controller.go +++ b/manager/pod_controller.go @@ -21,6 +21,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/controller" "sigs.k8s.io/controller-runtime/pkg/handler" "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/controller-runtime/pkg/source" ) @@ -40,6 +41,9 @@ func NewPodController(ctx context.Context, mgr manager.Manager, log *slog.Logger } return c.Watch(source.Kind( mgr.GetCache(), &corev1.Pod{}, &handler.TypedEnqueueRequestForObject[*corev1.Pod]{}, + predicate.NewTypedPredicateFuncs[*corev1.Pod](func(pod *corev1.Pod) bool { + return isZeropod(pod) + }), )) } @@ -112,11 +116,13 @@ func (r *podReconciler) Reconcile(ctx context.Context, request reconcile.Request } func (r podReconciler) isMigratable(pod *corev1.Pod) bool { - if pod.Spec.RuntimeClassName != nil && *pod.Spec.RuntimeClassName != v1.RuntimeClassName { + // some of these are already handled by the cache/predicate but there's no + // harm in being sure. + if pod.Spec.NodeName != r.nodeName { return false } - if pod.Spec.NodeName != r.nodeName { + if !isZeropod(pod) { return false } @@ -137,6 +143,10 @@ func (r podReconciler) isMigratable(pod *corev1.Pod) bool { return true } +func isZeropod(pod *corev1.Pod) bool { + return pod.Spec.RuntimeClassName != nil && *pod.Spec.RuntimeClassName == v1.RuntimeClassName +} + func newMigration(pod *corev1.Pod) (*v1.Migration, error) { containers := []v1.MigrationContainer{} for _, container := range pod.Status.ContainerStatuses { diff --git a/manager/redirector_attacher.go b/manager/redirector_attacher.go index 004878a..af2e36d 100644 --- a/manager/redirector_attacher.go +++ b/manager/redirector_attacher.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" "log/slog" + "net" + "net/netip" "os" "path/filepath" "strconv" @@ -18,10 +20,21 @@ import ( type Redirector struct { sync.Mutex - activators map[int]*activator.BPF - log *slog.Logger + sandboxes map[int]sandbox + log *slog.Logger + tracker socket.Tracker } +type sandbox struct { + ip netip.Addr + activator *activator.BPF +} + +const ( + ifaceETH0 = "eth0" + ifaceLoopback = "lo" +) + // AttachRedirectors scans the zeropod maps path in the bpf file system for // directories named after the pid of the sandbox container. It does an // initial iteration over all directories and then starts a goroutine which @@ -29,10 +42,11 @@ type Redirector struct { // can be found it attaches the redirector BPF programs to the network // interfaces of the sandbox. The directories are expected to be created by // the zeropod shim on startup. 
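For orientation, the on-disk layout the scan and the fsnotify watcher operate on might look like this; the root path and pid are illustrative, the per-sandbox map names are the ones pinned by the activator, and the top-level maps are the ones skipped by `ignoredDir`:

```
<bpf maps path>/
  tcp_events              # socket-tracker map, ignored by the pid scan
  pod_kubelet_addrs_v4    # ignored
  pod_kubelet_addrs_v6    # ignored
  3451/                   # sandbox pid directory created by the shim
    active_connections
    disable_redirect
    ingress_redirects
    egress_redirects
```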
-func AttachRedirectors(ctx context.Context, log *slog.Logger) error { +func AttachRedirectors(ctx context.Context, log *slog.Logger, tracker socket.Tracker) error { r := &Redirector{ - activators: make(map[int]*activator.BPF), - log: log, + sandboxes: make(map[int]sandbox), + log: log, + tracker: tracker, } if _, err := os.Stat(activator.MapsPath()); os.IsNotExist(err) { @@ -51,6 +65,7 @@ func AttachRedirectors(ctx context.Context, log *slog.Logger) error { r.log.Info("no sandbox pids found") } + errs := []error{} for _, pid := range pids { if err := statNetNS(pid); os.IsNotExist(err) { r.log.Info("net ns not found, removing leftover pid", "path", netNSPath(pid)) @@ -58,14 +73,12 @@ func AttachRedirectors(ctx context.Context, log *slog.Logger) error { continue } - if err := r.attachRedirector(pid); err != nil { - return err - } + errs = append(errs, r.attachRedirector(pid)) } go r.watchForSandboxPids(ctx) - return nil + return errors.Join(errs...) } func (r *Redirector) watchForSandboxPids(ctx context.Context) error { @@ -83,7 +96,7 @@ func (r *Redirector) watchForSandboxPids(ctx context.Context) error { select { // watch for events case event := <-watcher.Events: - if filepath.Base(event.Name) == socket.TCPEventsMap { + if ignoredDir(filepath.Base(event.Name)) { continue } @@ -105,9 +118,9 @@ func (r *Redirector) watchForSandboxPids(ctx context.Context) error { } case fsnotify.Remove: r.Lock() - if act, ok := r.activators[pid]; ok { - r.log.Info("cleaning up activator", "pid", pid) - if err := act.Cleanup(); err != nil { + if sb, ok := r.sandboxes[pid]; ok { + r.log.Info("cleaning up redirector", "pid", pid) + if err := sb.Remove(r.tracker); err != nil { r.log.Error("error cleaning up redirector", "err", err) } } @@ -126,25 +139,35 @@ func (r *Redirector) attachRedirector(pid int) error { if err != nil { return fmt.Errorf("unable to initialize BPF: %w", err) } - r.Lock() - r.activators[pid] = bpf - r.Unlock() netNS, err := ns.GetNS(netNSPath(pid)) if err != nil { return err } + var sandboxIP netip.Addr if err := netNS.Do(func(nn ns.NetNS) error { - // TODO: is this really always eth0? + // TODO: is this really always eth0? // as for loopback, this is required for port-forwarding to work - ifaces := []string{"eth0", "lo"} + ifaces := []string{ifaceETH0, ifaceLoopback} r.log.Info("attaching redirector for sandbox", "pid", pid, "links", ifaces) - return bpf.AttachRedirector(ifaces...) 
- }); err != nil { + if err := bpf.AttachRedirector(ifaces...); err != nil { + return err + } + + sandboxIP, err = getSandboxIP(ifaceETH0) return err + }); err != nil { + return errors.Join(err, bpf.Cleanup()) } + r.Lock() + r.sandboxes[pid] = sandbox{activator: bpf, ip: sandboxIP} + r.Unlock() + + if err := r.trackSandboxIP(sandboxIP); err != nil { + return fmt.Errorf("tracking sandbox IP: %w", err) + } return nil } @@ -173,7 +196,7 @@ func (r *Redirector) getSandboxPids() ([]int, error) { intPids := make([]int, 0, len(dirs)) for _, dir := range dirs { - if dir == socket.TCPEventsMap { + if ignoredDir(dir) { continue } @@ -193,3 +216,55 @@ func (r *Redirector) getSandboxPids() ([]int, error) { return intPids, nil } + +func getSandboxIP(ifaceName string) (netip.Addr, error) { + ip := netip.Addr{} + iface, err := net.InterfaceByName(ifaceName) + if err != nil { + return ip, fmt.Errorf("could not get interface: %w", err) + } + addrs, err := iface.Addrs() + if err != nil { + return ip, fmt.Errorf("could not get interface addrs: %w", err) + } + for _, addr := range addrs { + if ipnet, ok := addr.(*net.IPNet); ok { + // no need to track link local addresses + if ipnet.IP.IsLinkLocalUnicast() { + continue + } + ip, ok = netip.AddrFromSlice(ipnet.IP) + if !ok { + return ip, fmt.Errorf("unable to convert net.IP to netip.Addr: %s", ipnet.IP) + } + // use Unmap as the ipv4 might be mapped in v6 + return ip.Unmap(), nil + } + } + return ip, fmt.Errorf("sandbox IP not found") +} + +// trackSandboxIP passes the pod/sandbox IP to the tracker so it can ignore +// kubelet probes going to this pod. +func (r *Redirector) trackSandboxIP(ip netip.Addr) error { + if r.tracker == nil { + return nil + } + r.log.Info("tracking sandbox IP", "addr", ip.String()) + if err := r.tracker.PutPodIP(ip); err != nil { + return fmt.Errorf("putting pod IP in tracker map: %w", err) + } + return nil +} + +func ignoredDir(dir string) bool { + return dir == socket.TCPEventsMap || dir == socket.PodKubeletAddrsMapv4 || dir == socket.PodKubeletAddrsMapv6 +} + +func (sb sandbox) Remove(tracker socket.Tracker) error { + errs := []error{sb.activator.Cleanup()} + if tracker != nil { + errs = append(errs, tracker.RemovePodIP(sb.ip)) + } + return errors.Join(errs...) 
+} diff --git a/shim/config.go b/shim/config.go index 743f693..f58609a 100644 --- a/shim/config.go +++ b/shim/config.go @@ -23,6 +23,8 @@ const ( PreDumpAnnotationKey = "zeropod.ctrox.dev/pre-dump" MigrateAnnotationKey = "zeropod.ctrox.dev/migrate" LiveMigrateAnnotationKey = "zeropod.ctrox.dev/live-migrate" + DisableProbeDetectAnnotationKey = "zeropod.ctrox.dev/disable-probe-detection" + ProbeBufferSizeAnnotationKey = "zeropod.ctrox.dev/probe-buffer-size" CRIContainerNameAnnotation = "io.kubernetes.cri.container-name" CRIContainerTypeAnnotation = "io.kubernetes.cri.container-type" CRIPodNameAnnotation = "io.kubernetes.cri.sandbox-name" @@ -51,6 +53,8 @@ type Config struct { PodNamespace string PodUID string ContainerdNamespace string + DisableProbeDetection bool + ProbeBufferSize int spec *specs.Spec } @@ -135,6 +139,24 @@ func NewConfig(ctx context.Context, spec *specs.Spec) (*Config, error) { ns = defaultContainerdNS } + disableProbeDetectionValue := spec.Annotations[DisableProbeDetectAnnotationKey] + disableProbeDetection := false + if disableProbeDetectionValue != "" { + disableProbeDetection, err = strconv.ParseBool(disableProbeDetectionValue) + if err != nil { + return nil, err + } + } + + probeBufferSize := defaultProbeBufferSize + probeBufferSizeValue := spec.Annotations[ProbeBufferSizeAnnotationKey] + if probeBufferSizeValue != "" { + probeBufferSize, err = strconv.Atoi(probeBufferSizeValue) + if err != nil { + return nil, err + } + } + return &Config{ Ports: containerPorts, ScaleDownDuration: dur, @@ -149,6 +171,8 @@ func NewConfig(ctx context.Context, spec *specs.Spec) (*Config, error) { PodNamespace: spec.Annotations[CRIPodNamespaceAnnotation], PodUID: spec.Annotations[CRIPodUIDAnnotation], ContainerdNamespace: ns, + DisableProbeDetection: disableProbeDetection, + ProbeBufferSize: probeBufferSize, spec: spec, }, nil } diff --git a/shim/container.go b/shim/container.go index 23face2..c12c44b 100644 --- a/shim/container.go +++ b/shim/container.go @@ -340,7 +340,7 @@ func (c *Container) startActivator(ctx context.Context) error { log.G(ctx).Infof("starting activator with config: %v", c.cfg) - if err := c.activator.Start(ctx, c.cfg.Ports, c.restoreHandler(ctx)); err != nil { + if err := c.activator.Start(ctx, c.cfg.Ports, c.detectProbe(ctx), c.restoreHandler(ctx)); err != nil { if errors.Is(err, activator.ErrMapNotFound) { return err } @@ -353,7 +353,7 @@ func (c *Container) startActivator(ctx context.Context) error { return nil } -func (c *Container) restoreHandler(ctx context.Context) activator.OnAccept { +func (c *Container) restoreHandler(ctx context.Context) activator.RestoreHook { return func() error { log.G(ctx).Printf("got a request") diff --git a/shim/probe.go b/shim/probe.go new file mode 100644 index 0000000..42a2f25 --- /dev/null +++ b/shim/probe.go @@ -0,0 +1,103 @@ +package shim + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "strings" + + "github.com/containerd/log" + "github.com/ctrox/zeropod/activator" +) + +// defaultProbeBufferSize should be able to fit kube-probe HTTP requests with +// reasonable path and header sizes but should still be small enough to not +// impact performance. 
+const defaultProbeBufferSize = 1024 + +func (c *Container) detectProbe(ctx context.Context) activator.ConnHook { + if c.cfg.DisableProbeDetection { + return func(conn net.Conn) (net.Conn, bool, error) { + return conn, true, nil + } + } + return func(netConn net.Conn) (net.Conn, bool, error) { + conn := newBufferedConn(netConn, c.cfg.ProbeBufferSize) + if isTCPProbe(ctx, conn) { + log.G(ctx).Debug("detected TCP kube-probe, ignoring connection") + return conn, false, nil + } + if isHTTPProbe(ctx, conn) { + log.G(ctx).Debug("detected HTTP kube-probe request, responding") + if err := probeResponse(conn); err != nil { + log.G(ctx).Errorf("responding to kube-probe: %s", err) + } + return conn, false, nil + } + return conn, true, nil + } +} + +// isTCPProbe detects a TCP probe. It peeks 1 byte into the connection and if it +// receives an immediate [io.EOF] we know the conn has already been closed +// without receiving a single byte. Even if it wasn't a kube-probe, it's +// probably fine to not restore the application in case a connection is closed +// without receiving a single byte. If kubernetes ever starts to send something, +// this would need to be reworked but should be caught by e2e tests. +// https://github.com/kubernetes/kubernetes/blob/7cc3faf39d89d11c910db9ad19adfd931250e01c/pkg/probe/tcp/tcp.go#L50 +func isTCPProbe(ctx context.Context, conn bufConn) bool { + _, err := conn.Peek(1) + return err == io.EOF +} + +// isHTTPProbe detects an HTTP probe by constructing a [http.Request] from the +// conn buffer. If it's a valid HTTP request, we simply read the user agent +// string and assume it's a probe if it has the prefix "kube-probe/". +func isHTTPProbe(ctx context.Context, conn bufConn) bool { + if _, err := conn.Peek(1); err != nil { + return false + } + b, err := conn.Peek(min(conn.r.Buffered(), conn.r.Size())) + if err != nil && err != io.EOF { + log.G(ctx).WithError(err).Error("peek") + return false + } + req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(b))) + if err != nil { + log.G(ctx).WithError(err).Error("req") + return false + } + return strings.HasPrefix(req.Header.Get("User-Agent"), "kube-probe/") +} + +// probeResponse writes an HTTP response to conn that satisfies the kubelet. It +// sets the status code to [http.StatusNoContent] as the response does not +// contain a body and it makes it simpler to detect a probe response in testing. 
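For reference, roughly what the detector matches and answers on the wire; the path, host and version are illustrative, while the `kube-probe/` User-Agent prefix and `Accept: */*` header mirror the kubelet probe emulated in `probe_test.go` below:

```
GET /healthz HTTP/1.1
Host: 10.0.0.12
User-Agent: kube-probe/1.32
Accept: */*

HTTP/1.1 204 No Content
```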
+func probeResponse(conn net.Conn) error { + resp := http.Response{ + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: http.StatusNoContent, + } + return resp.Write(conn) +} + +type bufConn struct { + net.Conn + r *bufio.Reader +} + +func newBufferedConn(c net.Conn, size int) bufConn { + return bufConn{c, bufio.NewReaderSize(c, size)} +} + +func (b bufConn) Peek(n int) ([]byte, error) { + return b.r.Peek(n) +} + +func (b bufConn) Read(p []byte) (int, error) { + return b.r.Read(p) +} diff --git a/shim/probe_test.go b/shim/probe_test.go new file mode 100644 index 0000000..4cc4808 --- /dev/null +++ b/shim/probe_test.go @@ -0,0 +1,148 @@ +package shim + +import ( + "context" + "crypto/rand" + "encoding/base64" + "net" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDetectProbe(t *testing.T) { + ctx := context.Background() + + for name, tc := range map[string]struct { + probeDetected bool + clientFunc func(t *testing.T, addr string) + }{ + "http kube-probe/1.32": { + probeDetected: true, + clientFunc: httpRequest("kube-probe/1.32", http.StatusNoContent, nil), + }, + "http kube-probe/any": { + clientFunc: httpRequest("kube-probe/any", http.StatusNoContent, nil), + probeDetected: true, + }, + "tcp probe": { + clientFunc: kubeTCPProbe, + probeDetected: true, + }, + "http but not a probe": { + clientFunc: httpRequest("kube-notprobe/1.32", http.StatusOK, nil), + probeDetected: false, + }, + "probe request header bigger than buffer": { + clientFunc: httpRequest("kube-probe/1.32", http.StatusOK, func(req *http.Request) { + rnd, err := randomData(defaultProbeBufferSize * 10) + assert.NoError(t, err) + req.Header.Set("random-stuff", base64.URLEncoding.EncodeToString(rnd)) + }), + probeDetected: false, + }, + "probe request path bigger than buffer": { + clientFunc: httpRequest("kube-probe/1.32", http.StatusOK, func(req *http.Request) { + rnd, err := randomData(defaultProbeBufferSize * 10) + assert.NoError(t, err) + req.URL.Path = "/" + base64.URLEncoding.EncodeToString(rnd) + }), + probeDetected: false, + }, + "random TCP data": { + clientFunc: writeRandomTCPData(10), + probeDetected: false, + }, + "random TCP data bigger than buffer": { + clientFunc: writeRandomTCPData(defaultProbeBufferSize * 1024), + probeDetected: false, + }, + } { + t.Run(name, func(t *testing.T) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + clientDone := make(chan bool) + go func() { + tc.clientFunc(t, l.Addr().String()) + clientDone <- true + }() + + conn, err := l.Accept() + require.NoError(t, err) + c := &Container{cfg: &Config{ProbeBufferSize: defaultProbeBufferSize}} + newConn, cont, err := c.detectProbe(ctx)(conn) + require.NoError(t, err) + if cont { + resp := http.Response{ + StatusCode: http.StatusOK, + } + resp.Write(newConn) + } + assert.Equal(t, !tc.probeDetected, cont) + + <-clientDone + newConn.Close() + }) + } +} + +func httpRequest(userAgent string, expectedStatus int, modifyReq func(*http.Request)) func(t *testing.T, addr string) { + return func(t *testing.T, addr string) { + // emulate kubernetes HTTP probe: + // https://github.com/kubernetes/kubernetes/blob/7cc3faf39d89d11c910db9ad19adfd931250e01c/pkg/probe/http/http.go#L48 + req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil) + if !assert.NoError(t, err) { + return + } + req.Header.Set("User-Agent", userAgent) + req.Header.Set("Accept", "*/*") + if modifyReq != nil { + modifyReq(req) + } + + c := http.Client{ + Transport: &http.Transport{ + 
DisableCompression: true, + DisableKeepAlives: true, + }, + } + resp, err := c.Do(req) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, resp.StatusCode) + } +} + +// kubeTCPProbe emulates a TCP probe: +// https://github.com/kubernetes/kubernetes/blob/7cc3faf39d89d11c910db9ad19adfd931250e01c/pkg/probe/tcp/tcp.go#L50 +func kubeTCPProbe(t *testing.T, addr string) { + conn, err := net.Dial("tcp", addr) + if !assert.NoError(t, err) { + return + } + assert.NoError(t, conn.Close()) +} + +func writeRandomTCPData(size int) func(t *testing.T, addr string) { + return func(t *testing.T, addr string) { + conn, err := net.Dial("tcp", addr) + if !assert.NoError(t, err) { + return + } + randomData, err := randomData(size) + if !assert.NoError(t, err) { + return + } + conn.Write(randomData) + } +} + +func randomData(size int) ([]byte, error) { + randomData := make([]byte, size) + if _, err := rand.Read(randomData); err != nil { + return nil, err + } + return randomData, nil +} diff --git a/socket/Dockerfile b/socket/Dockerfile index cf90e8e..4e49522 100644 --- a/socket/Dockerfile +++ b/socket/Dockerfile @@ -7,7 +7,7 @@ ADD go.* /app RUN go mod download # we use fedora since it has a recent version of bpftool -FROM fedora:41 +FROM fedora:42 RUN dnf install -y llvm clang bpftool libbpf-devel golang RUN mkdir /headers diff --git a/socket/bpf_bpfeb.go b/socket/bpf_arm64_bpfel.go similarity index 78% rename from socket/bpf_bpfeb.go rename to socket/bpf_arm64_bpfel.go index 65171b0..4e1a282 100644 --- a/socket/bpf_bpfeb.go +++ b/socket/bpf_arm64_bpfel.go @@ -1,5 +1,5 @@ // Code generated by bpf2go; DO NOT EDIT. -//go:build mips || mips64 || ppc64 || s390x +//go:build arm64 package socket @@ -8,10 +8,16 @@ import ( _ "embed" "fmt" "io" + "structs" "github.com/cilium/ebpf" ) +type bpfIpv6Addr struct { + _ structs.HostLayout + U6Addr8 [16]uint8 +} + // loadBpf returns the embedded CollectionSpec for bpf. func loadBpf() (*ebpf.CollectionSpec, error) { reader := bytes.NewReader(_BpfBytes) @@ -55,19 +61,23 @@ type bpfSpecs struct { // It can be passed ebpf.CollectionSpec.Assign. type bpfProgramSpecs struct { KretprobeInetCskAccept *ebpf.ProgramSpec `ebpf:"kretprobe__inet_csk_accept"` + TcpRcvStateProcess *ebpf.ProgramSpec `ebpf:"tcp_rcv_state_process"` } // bpfMapSpecs contains maps before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfMapSpecs struct { - TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` + PodKubeletAddrsV4 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v4"` + PodKubeletAddrsV6 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v6"` + TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` } // bpfVariableSpecs contains global variables before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfVariableSpecs struct { + ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` } // bpfObjects contains all objects after they have been loaded into the kernel. @@ -90,11 +100,15 @@ func (o *bpfObjects) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. 
type bpfMaps struct { - TcpEvents *ebpf.Map `ebpf:"tcp_events"` + PodKubeletAddrsV4 *ebpf.Map `ebpf:"pod_kubelet_addrs_v4"` + PodKubeletAddrsV6 *ebpf.Map `ebpf:"pod_kubelet_addrs_v6"` + TcpEvents *ebpf.Map `ebpf:"tcp_events"` } func (m *bpfMaps) Close() error { return _BpfClose( + m.PodKubeletAddrsV4, + m.PodKubeletAddrsV6, m.TcpEvents, ) } @@ -103,6 +117,7 @@ func (m *bpfMaps) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfVariables struct { + ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` } // bpfPrograms contains all programs after they have been loaded into the kernel. @@ -110,11 +125,13 @@ type bpfVariables struct { // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfPrograms struct { KretprobeInetCskAccept *ebpf.Program `ebpf:"kretprobe__inet_csk_accept"` + TcpRcvStateProcess *ebpf.Program `ebpf:"tcp_rcv_state_process"` } func (p *bpfPrograms) Close() error { return _BpfClose( p.KretprobeInetCskAccept, + p.TcpRcvStateProcess, ) } @@ -129,5 +146,5 @@ func _BpfClose(closers ...io.Closer) error { // Do not access this directly. // -//go:embed bpf_bpfeb.o +//go:embed bpf_arm64_bpfel.o var _BpfBytes []byte diff --git a/socket/bpf_arm64_bpfel.o b/socket/bpf_arm64_bpfel.o new file mode 100644 index 0000000..139ebdb Binary files /dev/null and b/socket/bpf_arm64_bpfel.o differ diff --git a/socket/bpf_bpfeb.o b/socket/bpf_bpfeb.o deleted file mode 100644 index eaefa0f..0000000 Binary files a/socket/bpf_bpfeb.o and /dev/null differ diff --git a/socket/bpf_bpfel.o b/socket/bpf_bpfel.o deleted file mode 100644 index a8c8c0f..0000000 Binary files a/socket/bpf_bpfel.o and /dev/null differ diff --git a/socket/bpf_bpfel.go b/socket/bpf_x86_bpfel.go similarity index 78% rename from socket/bpf_bpfel.go rename to socket/bpf_x86_bpfel.go index 796f843..84171b0 100644 --- a/socket/bpf_bpfel.go +++ b/socket/bpf_x86_bpfel.go @@ -1,5 +1,5 @@ // Code generated by bpf2go; DO NOT EDIT. -//go:build 386 || amd64 || arm || arm64 || loong64 || mips64le || mipsle || ppc64le || riscv64 +//go:build 386 || amd64 package socket @@ -8,10 +8,16 @@ import ( _ "embed" "fmt" "io" + "structs" "github.com/cilium/ebpf" ) +type bpfIpv6Addr struct { + _ structs.HostLayout + U6Addr8 [16]uint8 +} + // loadBpf returns the embedded CollectionSpec for bpf. func loadBpf() (*ebpf.CollectionSpec, error) { reader := bytes.NewReader(_BpfBytes) @@ -55,19 +61,23 @@ type bpfSpecs struct { // It can be passed ebpf.CollectionSpec.Assign. type bpfProgramSpecs struct { KretprobeInetCskAccept *ebpf.ProgramSpec `ebpf:"kretprobe__inet_csk_accept"` + TcpRcvStateProcess *ebpf.ProgramSpec `ebpf:"tcp_rcv_state_process"` } // bpfMapSpecs contains maps before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfMapSpecs struct { - TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` + PodKubeletAddrsV4 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v4"` + PodKubeletAddrsV6 *ebpf.MapSpec `ebpf:"pod_kubelet_addrs_v6"` + TcpEvents *ebpf.MapSpec `ebpf:"tcp_events"` } // bpfVariableSpecs contains global variables before they are loaded into the kernel. // // It can be passed ebpf.CollectionSpec.Assign. type bpfVariableSpecs struct { + ProbeBinaryName *ebpf.VariableSpec `ebpf:"probe_binary_name"` } // bpfObjects contains all objects after they have been loaded into the kernel. @@ -90,11 +100,15 @@ func (o *bpfObjects) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. 
type bpfMaps struct { - TcpEvents *ebpf.Map `ebpf:"tcp_events"` + PodKubeletAddrsV4 *ebpf.Map `ebpf:"pod_kubelet_addrs_v4"` + PodKubeletAddrsV6 *ebpf.Map `ebpf:"pod_kubelet_addrs_v6"` + TcpEvents *ebpf.Map `ebpf:"tcp_events"` } func (m *bpfMaps) Close() error { return _BpfClose( + m.PodKubeletAddrsV4, + m.PodKubeletAddrsV6, m.TcpEvents, ) } @@ -103,6 +117,7 @@ func (m *bpfMaps) Close() error { // // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfVariables struct { + ProbeBinaryName *ebpf.Variable `ebpf:"probe_binary_name"` } // bpfPrograms contains all programs after they have been loaded into the kernel. @@ -110,11 +125,13 @@ type bpfVariables struct { // It can be passed to loadBpfObjects or ebpf.CollectionSpec.LoadAndAssign. type bpfPrograms struct { KretprobeInetCskAccept *ebpf.Program `ebpf:"kretprobe__inet_csk_accept"` + TcpRcvStateProcess *ebpf.Program `ebpf:"tcp_rcv_state_process"` } func (p *bpfPrograms) Close() error { return _BpfClose( p.KretprobeInetCskAccept, + p.TcpRcvStateProcess, ) } @@ -129,5 +146,5 @@ func _BpfClose(closers ...io.Closer) error { // Do not access this directly. // -//go:embed bpf_bpfel.o +//go:embed bpf_x86_bpfel.o var _BpfBytes []byte diff --git a/socket/bpf_x86_bpfel.o b/socket/bpf_x86_bpfel.o new file mode 100644 index 0000000..38e924e Binary files /dev/null and b/socket/bpf_x86_bpfel.o differ diff --git a/socket/ebpf.go b/socket/ebpf.go index b5f6c0a..4ed3747 100644 --- a/socket/ebpf.go +++ b/socket/ebpf.go @@ -1,7 +1,10 @@ package socket import ( + "encoding/binary" + "errors" "fmt" + "net/netip" "os" "path/filepath" "time" @@ -14,26 +17,41 @@ import ( ) // $BPF_CLANG and $BPF_CFLAGS are set by the Makefile. -//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -cflags $BPF_CFLAGS bpf kprobe.c -- -I/headers +//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc $BPF_CLANG -target amd64,arm64 -cflags $BPF_CFLAGS bpf kprobe.c -- -I/headers -const TCPEventsMap = "tcp_events" +const ( + TCPEventsMap = "tcp_events" + PodKubeletAddrsMapv4 = "pod_kubelet_addrs_v4" + PodKubeletAddrsMapv6 = "pod_kubelet_addrs_v6" +) + +var mapNames = []string{ + TCPEventsMap, + PodKubeletAddrsMapv4, + PodKubeletAddrsMapv6, +} + +const ( + probeBinaryNameVariable = "probe_binary_name" + probeBinaryNameMaxLength = 16 +) // LoadEBPFTracker loads the eBPF program and attaches the kretprobe to track // connections system-wide. -func LoadEBPFTracker() (func() error, error) { +func LoadEBPFTracker(probeBinaryName string) (Tracker, func() error, error) { // Allow the current process to lock memory for eBPF resources. if err := rlimit.RemoveMemlock(); err != nil { - return nil, err + return nil, nil, err } pinPath := activator.MapsPath() if err := os.MkdirAll(pinPath, os.ModePerm); err != nil { - return nil, fmt.Errorf("failed to create bpf fs subpath: %w", err) + return nil, nil, fmt.Errorf("failed to create bpf fs subpath: %w", err) } // Load pre-compiled programs and maps into the kernel. objs := bpfObjects{} - if err := loadBpfObjects(&objs, &ebpf.CollectionOptions{ + collectionOpts := &ebpf.CollectionOptions{ Maps: ebpf.MapOptions{ // Pin the map to the BPF filesystem and configure the // library to automatically re-write it in the BPF @@ -41,24 +59,63 @@ func LoadEBPFTracker() (func() error, error) { // create it if not. 
 			PinPath: pinPath,
 		},
-	}); err != nil {
-		return nil, fmt.Errorf("loading objects: %w", err)
+	}
+
+	spec, err := loadBpf()
+	if err != nil {
+		return nil, nil, fmt.Errorf("loading bpf objects: %w", err)
+	}
+
+	if len([]byte(probeBinaryName)) > probeBinaryNameMaxLength {
+		return nil, nil, fmt.Errorf(
+			"probe binary name %s is too long (%d bytes), max is %d bytes",
+			probeBinaryName, len([]byte(probeBinaryName)), probeBinaryNameMaxLength,
+		)
+	}
+	binName := [probeBinaryNameMaxLength]byte{}
+	copy(binName[:], probeBinaryName)
+	if err := spec.Variables[probeBinaryNameVariable].Set(binName); err != nil {
+		return nil, nil, fmt.Errorf("setting probe binary variable: %w", err)
+	}
+
+	if err := spec.LoadAndAssign(&objs, collectionOpts); err != nil {
+		if !errors.Is(err, ebpf.ErrMapIncompatible) {
+			return nil, nil, fmt.Errorf("loading objects: %w", err)
+		}
+		// try to unpin the maps and load again, reusing the already configured
+		// spec so the probe_binary_name variable set above is not lost
+		for _, mapName := range mapNames {
+			if err := os.Remove(filepath.Join(pinPath, mapName)); err != nil && !os.IsNotExist(err) {
+				return nil, nil, fmt.Errorf("removing map after incompatibility: %w", err)
+			}
+		}
+		if err := spec.LoadAndAssign(&objs, collectionOpts); err != nil {
+			return nil, nil, fmt.Errorf("loading objects: %w", err)
+		}
 	}
 
 	// in the past we used inet_sock_set_state here but we now use a
 	// kretprobe with inet_csk_accept as inet_sock_set_state is not giving us
 	// reliable PIDs. https://github.com/iovisor/bcc/issues/2304
-	kp, err := link.Kretprobe("inet_csk_accept", objs.KretprobeInetCskAccept, &link.KprobeOptions{})
+	tracker, err := link.Kretprobe("inet_csk_accept", objs.KretprobeInetCskAccept, &link.KprobeOptions{})
 	if err != nil {
-		return nil, fmt.Errorf("linking kprobe: %w", err)
+		return nil, nil, fmt.Errorf("linking kprobe: %w", err)
+	}
+	kubeletDetector, err := link.Kprobe("tcp_rcv_state_process", objs.TcpRcvStateProcess, &link.KprobeOptions{})
+	if err != nil {
+		return nil, nil, fmt.Errorf("linking tcp rcv kprobe: %w", err)
 	}
 
-	return func() error {
+	t, err := NewEBPFTracker()
+	return t, func() error {
+		errs := []error{}
 		if err := objs.Close(); err != nil {
-			return err
+			errs = append(errs, err)
+		}
+		if err := kubeletDetector.Close(); err != nil {
+			errs = append(errs, err)
 		}
-		return kp.Close()
-	}, nil
+		return errors.Join(append(errs, tracker.Close())...)
+	}, err
 }
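With the new `LoadEBPFTracker` signature in place, a caller might wire it up roughly as below. This is a minimal sketch, not code from this change: the probe binary name and scale-down duration are made-up values, and the fallback to `NewNoopTracker` assumes `NoopTracker` satisfies the `Tracker` interface (which is its purpose in `socket/noop.go`).

```go
package main

import (
	"log"
	"os"
	"time"

	"github.com/ctrox/zeropod/socket"
)

func main() {
	// "kubelet" is an assumed probe binary name; zeropod would supply this
	// from its own configuration. Loading requires the BPF filesystem to be
	// mounted and sufficient privileges.
	tracker, cleanup, err := socket.LoadEBPFTracker("kubelet")
	if err != nil {
		// Fall back to the no-op tracker, which reports activity based on the
		// scale-down duration and needs no cleanup.
		log.Printf("eBPF tracker unavailable, falling back to noop: %v", err)
		tracker = socket.NewNoopTracker(10 * time.Second)
		cleanup = func() error { return nil }
	}
	defer func() {
		if err := cleanup(); err != nil {
			log.Printf("closing tracker: %v", err)
		}
	}()

	// Track TCP accepts of this process so LastActivity reflects incoming
	// connections that were not identified as kubelet probes.
	if err := tracker.TrackPid(uint32(os.Getpid())); err != nil {
		log.Fatal(err)
	}
}
```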
 
 // NewEBPFTracker returns a TCP connection tracker that will keep track of the
@@ -74,11 +131,20 @@ func NewEBPFTracker() (Tracker, error) {
 		resolver = hostResolver{}
 	}
 
+	podKubeletAddrsv4, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), PodKubeletAddrsMapv4), &ebpf.LoadPinOptions{})
+	if err != nil {
+		return nil, err
+	}
+	podKubeletAddrsv6, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), PodKubeletAddrsMapv6), &ebpf.LoadPinOptions{})
+	if err != nil {
+		return nil, err
+	}
 	tcpEvents, err := ebpf.LoadPinnedMap(filepath.Join(activator.MapsPath(), TCPEventsMap), &ebpf.LoadPinOptions{})
 	return &EBPFTracker{
-		PIDResolver: resolver,
-		tcpEvents:   tcpEvents,
+		PIDResolver:       resolver,
+		tcpEvents:         tcpEvents,
+		podKubeletAddrsv4: podKubeletAddrsv4,
+		podKubeletAddrsv6: podKubeletAddrsv6,
 	}, err
 }
@@ -105,7 +171,9 @@ func (err NoActivityRecordedErr) Error() string {
 
 type EBPFTracker struct {
 	PIDResolver
-	tcpEvents *ebpf.Map
+	tcpEvents         *ebpf.Map
+	podKubeletAddrsv4 *ebpf.Map
+	podKubeletAddrsv6 *ebpf.Map
 }
 
 // TrackPid puts the pid into the TcpEvents map meaning tcp events of the
@@ -147,6 +215,47 @@ func (c *EBPFTracker) Close() error {
 	return c.tcpEvents.Close()
 }
 
+// PutPodIP adds the pod IP to the tracker unless it already exists.
+func (c *EBPFTracker) PutPodIP(ip netip.Addr) error {
+	if ip.Is4() {
+		val := uint32(0)
+		ipv4 := ip.As4()
+		uIP := binary.NativeEndian.Uint32(ipv4[:])
+		if err := c.podKubeletAddrsv4.Update(&uIP, &val, ebpf.UpdateNoExist); err != nil &&
+			!errors.Is(err, ebpf.ErrKeyExist) {
+			return fmt.Errorf("unable to put ipv4 %s into bpf map: %w", ip, err)
+		}
+		return nil
+	}
+
+	val := bpfIpv6Addr{}
+	bpfIP := bpfIpv6Addr{U6Addr8: ip.As16()}
+	if err := c.podKubeletAddrsv6.Update(&bpfIP, &val, ebpf.UpdateNoExist); err != nil &&
+		!errors.Is(err, ebpf.ErrKeyExist) {
+		return fmt.Errorf("unable to put ipv6 %s into bpf map: %w", ip, err)
+	}
+
+	return nil
+}
+
+// RemovePodIP removes the pod IP from the tracker.
+func (c *EBPFTracker) RemovePodIP(ip netip.Addr) error {
+	if ip.Is4() {
+		ipv4 := ip.As4()
+		uIP := binary.NativeEndian.Uint32(ipv4[:])
+		if err := c.podKubeletAddrsv4.Delete(&uIP); err != nil && !errors.Is(err, ebpf.ErrKeyNotExist) {
+			return fmt.Errorf("unable to delete ipv4 %s from bpf map: %w", ip, err)
+		}
+		return nil
+	}
+
+	bpfIP := bpfIpv6Addr{U6Addr8: ip.As16()}
+	if err := c.podKubeletAddrsv6.Delete(&bpfIP); err != nil && !errors.Is(err, ebpf.ErrKeyNotExist) {
+		return fmt.Errorf("unable to delete ipv6 %s from bpf map: %w", ip, err)
+	}
+	return nil
+}
+
 // convertBPFTime takes the value of bpf_ktime_get_ns and converts it to a
 // time.Time.
 func convertBPFTime(t uint64) (time.Time, error) {
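Before the kernel side of the change below, a sketch of how the new `PutPodIP`/`RemovePodIP` pair could be driven from pod lifecycle events. This is hypothetical wiring, not part of the change; only the `Tracker` methods and the `netip` usage come from the code above.

```go
package example // illustrative only, not part of zeropod

import (
	"fmt"
	"net/netip"

	"github.com/ctrox/zeropod/socket"
)

// onPodStarted registers the pod IP so connections from the kubelet to this
// pod (liveness/readiness probes) are ignored by the activity tracker.
func onPodStarted(tracker socket.Tracker, podIP string) error {
	ip, err := netip.ParseAddr(podIP)
	if err != nil {
		return fmt.Errorf("parsing pod IP %q: %w", podIP, err)
	}
	return tracker.PutPodIP(ip)
}

// onPodStopped removes the pod IP again once the pod goes away.
func onPodStopped(tracker socket.Tracker, podIP string) error {
	ip, err := netip.ParseAddr(podIP)
	if err != nil {
		return fmt.Errorf("parsing pod IP %q: %w", podIP, err)
	}
	return tracker.RemovePodIP(ip)
}
```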
diff --git a/socket/kprobe.c b/socket/kprobe.c
index 5c7bf94..4a5068f 100644
--- a/socket/kprobe.c
+++ b/socket/kprobe.c
@@ -1,46 +1,140 @@
 //go:build ignore
 
 #include "vmlinux.h"
+#include "ptregs.h"
 #include "bpf_helpers.h"
+#include "bpf_tracing.h"
+#include "bpf_core_read.h"
+#include "bpf_endian.h"
 
 char __license[] SEC("license") = "Dual MIT/GPL";
 
+#define AF_INET 2
+
 struct {
-    __uint(type, BPF_MAP_TYPE_HASH);
-    __uint(max_entries, 1024); // should be enough pids?
-    __type(key, __u32); // pid
-    __type(value, __u64); // ktime ns of the last tracked event
-    __uint(pinning, LIBBPF_PIN_BY_NAME);
+    __uint(type, BPF_MAP_TYPE_LRU_HASH);
+    __uint(max_entries, 1024); // should be enough pids?
+    __type(key, __u32); // pid
+    __type(value, __u64); // ktime ns of the last tracked event
+    __uint(pinning, LIBBPF_PIN_BY_NAME);
 } tcp_events SEC(".maps");
 
-SEC("kretprobe/inet_csk_accept")
-int kretprobe__inet_csk_accept(struct pt_regs *ctx)
-{
-    // TODO: we don't check if the protocol is actually TCP here as this seems quite messy:
-    // https://github.com/iovisor/bcc/blob/71b5141659aaaf4a7c2172c73a802bd86a256ecd/tools/tcpaccept.py#L118
-    // does this matter? Which other protocols make use of inet_csk_accept?
+struct {
+    __uint(type, BPF_MAP_TYPE_LRU_HASH);
+    __uint(max_entries, 1024);
+    __type(key, __be32); // pod addr
+    __type(value, __be32); // kubelet addr
+    __uint(pinning, LIBBPF_PIN_BY_NAME);
+} pod_kubelet_addrs_v4 SEC(".maps");
 
-    struct task_struct* task = (struct task_struct*)bpf_get_current_task_btf();
-    // we use the tgid as our pid as it represents the pid from userspace
-    __u32 pid = task->tgid;
+struct ipv6_addr {
+    __u8 u6_addr8[16];
+};
 
-    void *tcp_event = &tcp_events;
-    void* found_pid = bpf_map_lookup_elem(tcp_event, &pid);
+struct {
+    __uint(type, BPF_MAP_TYPE_LRU_HASH);
+    __uint(max_entries, 1024);
+    __type(key, struct ipv6_addr); // pod addr
+    __type(value, struct ipv6_addr); // kubelet addr
+    __uint(pinning, LIBBPF_PIN_BY_NAME);
+} pod_kubelet_addrs_v6 SEC(".maps");
+
+const volatile char probe_binary_name[TASK_COMM_LEN] = "";
 
-    if (!found_pid) {
-        // try ppid, our process might have forks
-        pid = task->real_parent->tgid;
+static __always_inline bool ipv6_addr_equal(const struct ipv6_addr *a1, const struct ipv6_addr *a2) {
+    #pragma unroll
+    for (int i = 0; i < 16; i++) {
+        if (a1->u6_addr8[i] != a2->u6_addr8[i]) {
+            return false;
+        }
+    }
+    return true;
+}
 
-        void* found_ppid = bpf_map_lookup_elem(tcp_event, &pid);
-        if (!found_ppid) {
-            return 0;
-        }
-    }
+SEC("kretprobe/inet_csk_accept")
+int kretprobe__inet_csk_accept(struct pt_regs *ctx) {
+    struct task_struct* task = (struct task_struct*)bpf_get_current_task_btf();
+    // we use the tgid as our pid as it represents the pid from userspace
+    __u32 pid = task->tgid;
 
-    __u64 time = bpf_ktime_get_ns();
+    void *tcp_event = &tcp_events;
+    void *found_pid = bpf_map_lookup_elem(tcp_event, &pid);
 
-    // const char fmt_str[] = "%d: accept found on pid %d\n";
-    // bpf_trace_printk(fmt_str, sizeof(fmt_str), time, pid);
+    if (!found_pid) {
+        // try ppid, our process might have forks
+        pid = task->real_parent->tgid;
 
-    return bpf_map_update_elem(tcp_event, &pid, &time, BPF_ANY);
+        void *found_ppid = bpf_map_lookup_elem(tcp_event, &pid);
+        if (!found_ppid) {
+            return 0;
+        }
+    }
+
+    struct sock *sk = (struct sock *)PT_REGS_RC(ctx);
+    if (sk == NULL) {
+        return 0;
+    }
+
+    if (BPF_CORE_READ(sk, __sk_common.skc_family) == AF_INET) {
+        __be32 pod_addr = 0;
+        __be32 daddr = 0;
+        BPF_CORE_READ_INTO(&pod_addr, sk, __sk_common.skc_rcv_saddr);
+        BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_daddr);
+        void *addrs = &pod_kubelet_addrs_v4;
+        __be32 *kubelet_addr = bpf_map_lookup_elem(addrs, &pod_addr);
+        if (kubelet_addr && *kubelet_addr == daddr) {
+            return 0;
+        }
+    } else {
+        struct ipv6_addr pod_addr;
+        struct ipv6_addr daddr;
+        BPF_CORE_READ_INTO(&pod_addr, sk, __sk_common.skc_v6_rcv_saddr.in6_u);
+        BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_v6_daddr.in6_u);
+        void *addrs = &pod_kubelet_addrs_v6;
+        struct ipv6_addr *kubelet_addr = bpf_map_lookup_elem(addrs, &pod_addr);
+        if (kubelet_addr && ipv6_addr_equal(kubelet_addr, &daddr)) {
+            return 0;
+        }
+    }
+
+    __u64 time = bpf_ktime_get_ns();
+    return bpf_map_update_elem(tcp_event, &pid,
&time, BPF_ANY); }; + +static int find_potential_kubelet_ip(struct sock *sk) { + char comm[TASK_COMM_LEN]; + bpf_get_current_comm(&comm, sizeof(comm)); + if (bpf_strncmp(comm, TASK_COMM_LEN, (char *)probe_binary_name) == 0) { + if (BPF_CORE_READ(sk, __sk_common.skc_family) == AF_INET) { + __be32 saddr = 0; + __be32 daddr = 0; + BPF_CORE_READ_INTO(&saddr, sk, __sk_common.skc_rcv_saddr); + BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_daddr); + void *addrs = &pod_kubelet_addrs_v4; + __be32 *kubelet_addr = bpf_map_lookup_elem(addrs, &daddr); + if (kubelet_addr) { + bpf_map_update_elem(addrs, &daddr, &saddr, 0); + } + } else { + struct ipv6_addr saddr; + struct ipv6_addr daddr; + BPF_CORE_READ_INTO(&saddr, sk, __sk_common.skc_v6_rcv_saddr.in6_u); + BPF_CORE_READ_INTO(&daddr, sk, __sk_common.skc_v6_daddr.in6_u); + void *addrs = &pod_kubelet_addrs_v6; + struct ipv6_addr *kubelet_addr = bpf_map_lookup_elem(addrs, &daddr); + if (kubelet_addr) { + bpf_map_update_elem(addrs, &daddr, &saddr, 0); + } + } + } + return 0; +} + +SEC("kprobe/tcp_rcv_state_process") +int BPF_KPROBE(tcp_rcv_state_process, struct sock *sk) { + if (BPF_CORE_READ(sk, __sk_common.skc_state) != TCP_SYN_SENT) { + return 0; + } + + return find_potential_kubelet_ip(sk); +} diff --git a/socket/noop.go b/socket/noop.go index f882791..de62a4c 100644 --- a/socket/noop.go +++ b/socket/noop.go @@ -1,6 +1,9 @@ package socket -import "time" +import ( + "net/netip" + "time" +) func NewNoopTracker(scaleDownDuration time.Duration) NoopTracker { return NoopTracker{ @@ -29,3 +32,11 @@ func (n NoopTracker) LastActivity(pid uint32) (time.Time, error) { func (n NoopTracker) Close() error { return nil } + +func (n NoopTracker) PutPodIP(ip netip.Addr) error { + return nil +} + +func (n NoopTracker) RemovePodIP(ip netip.Addr) error { + return nil +} diff --git a/socket/ptregs.h b/socket/ptregs.h new file mode 100644 index 0000000..69ac655 --- /dev/null +++ b/socket/ptregs.h @@ -0,0 +1,8 @@ +#if defined(__TARGET_ARCH_arm64) +struct user_pt_regs { + __u64 regs[31]; + __u64 sp; + __u64 pc; + __u64 pstate; +}; +#endif diff --git a/socket/tracker.go b/socket/tracker.go index a2a23fb..a4b2351 100644 --- a/socket/tracker.go +++ b/socket/tracker.go @@ -1,6 +1,9 @@ package socket -import "time" +import ( + "net/netip" + "time" +) type Tracker interface { PIDResolver @@ -13,4 +16,9 @@ type Tracker interface { LastActivity(pid uint32) (time.Time, error) // Close the activity tracker. Close() error + // PutPodIP inserts a pod IP into the pod-to-kubelet map, helping with + // ignoring probes coming from kubelet within the tracker. + PutPodIP(ip netip.Addr) error + // RemovePodIP removes a pod IP from the pod-to-kubelet map. 
+ RemovePodIP(ip netip.Addr) error } diff --git a/socket/tracker_test.go b/socket/tracker_test.go index 504cedc..181abec 100644 --- a/socket/tracker_test.go +++ b/socket/tracker_test.go @@ -4,11 +4,14 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/netip" "os" + "path/filepath" "testing" "time" "github.com/ctrox/zeropod/activator" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,12 +21,11 @@ import ( func TestEBPFTracker(t *testing.T) { require.NoError(t, activator.MountBPFFS(activator.BPFFSPath)) - clean, err := LoadEBPFTracker() + name, err := os.Executable() require.NoError(t, err) - defer func() { require.NoError(t, clean()) }() - - tracker, err := NewEBPFTracker() + tracker, clean, err := LoadEBPFTracker(filepath.Base(name)) require.NoError(t, err) + defer func() { require.NoError(t, clean()) }() pid := uint32(os.Getpid()) require.NoError(t, tracker.TrackPid(pid)) @@ -32,21 +34,44 @@ func TestEBPFTracker(t *testing.T) { fmt.Fprintln(w, "ok") })) - require.Eventually(t, func() bool { - _, err = http.Get(ts.URL) - return err == nil - }, time.Millisecond*100, time.Millisecond, "waiting for http server to reply") + for name, tc := range map[string]struct { + ip netip.Addr + expectLastActivity bool + }{ + "activity tracked": { + ip: netip.MustParseAddr("10.0.0.1"), + expectLastActivity: true, + }, + "activity ignored": { + // use 127.0.0.1 as that's where our test program connects from + ip: netip.MustParseAddr("127.0.0.1"), + expectLastActivity: false, + }, + } { + t.Run(name, func(t *testing.T) { + require.NoError(t, tracker.PutPodIP(tc.ip)) + defer func() { assert.NoError(t, tracker.RemovePodIP(tc.ip)) }() + + require.Eventually(t, func() bool { + _, err = http.Get(ts.URL) + return err == nil + }, time.Millisecond*100, time.Millisecond, "waiting for http server to reply") - require.Eventually(t, func() bool { - activity, err := tracker.LastActivity(pid) - if err != nil { - return false - } + require.Eventually(t, func() bool { + activity, err := tracker.LastActivity(pid) + if err != nil { + return !tc.expectLastActivity + } - if time.Since(activity) > time.Millisecond*100 { - t.Fatalf("last activity was %s ago, expected it to be within the last 100ms", time.Since(activity)) - } + if time.Since(activity) > time.Millisecond*100 { + if tc.expectLastActivity { + t.Fatalf("last activity was %s ago, expected it to be within the last 100ms", time.Since(activity)) + } + } - return true - }, time.Millisecond*100, time.Millisecond, "waiting for last tcp activity") + return true + }, time.Millisecond*100, time.Millisecond, "waiting for last tcp activity") + time.Sleep(time.Millisecond * 200) + }) + } } diff --git a/socket/vmlinux.h.gz b/socket/vmlinux.h.gz index 2ecf2fa..de90ebd 100644 Binary files a/socket/vmlinux.h.gz and b/socket/vmlinux.h.gz differ
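Finally, a debugging-only sketch (not part of the change) that dumps the learned pod to kubelet IPv4 pairs from the pinned map. The map name, pin path and key encoding come from `socket/ebpf.go` and `socket/kprobe.c` above; the program itself and its use are assumptions.

```go
package main // debugging sketch only

import (
	"encoding/binary"
	"fmt"
	"log"
	"net/netip"
	"path/filepath"

	"github.com/cilium/ebpf"
	"github.com/ctrox/zeropod/activator"
	"github.com/ctrox/zeropod/socket"
)

func main() {
	// Keys and values are the 4 address bytes reinterpreted as a native-endian
	// uint32, matching the __be32 layout written by kprobe.c and PutPodIP.
	m, err := ebpf.LoadPinnedMap(
		filepath.Join(activator.MapsPath(), socket.PodKubeletAddrsMapv4),
		&ebpf.LoadPinOptions{},
	)
	if err != nil {
		log.Fatal(err)
	}
	defer m.Close()

	var pod, kubelet uint32
	iter := m.Iterate()
	for iter.Next(&pod, &kubelet) {
		var p, k [4]byte
		binary.NativeEndian.PutUint32(p[:], pod)
		binary.NativeEndian.PutUint32(k[:], kubelet)
		// a kubelet address of 0.0.0.0 means the pod is registered but no
		// kubelet connection has been learned yet
		fmt.Printf("pod %s -> kubelet %s\n", netip.AddrFrom4(p), netip.AddrFrom4(k))
	}
	if err := iter.Err(); err != nil {
		log.Fatal(err)
	}
}
```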