diff --git a/api/types/maintenance.go b/api/types/maintenance.go index 31a48472e6aa8..5d98a3ec311cb 100644 --- a/api/types/maintenance.go +++ b/api/types/maintenance.go @@ -26,13 +26,17 @@ import ( ) const ( - // UpgraderKindKuberController is a short name used to identify the kube-controller-based + // UpgraderKindKubeController is a short name used to identify the kube-controller-based // external upgrader variant. UpgraderKindKubeController = "kube" // UpgraderKindSystemdUnit is a short name used to identify the systemd-unit-based // external upgrader variant. UpgraderKindSystemdUnit = "unit" + + // UpgraderKindTeleportUpdate is a short name used to identify the teleport-update + // external upgrader variant. + UpgraderKindTeleportUpdate = "binary" ) var validWeekdays = [7]time.Weekday{ diff --git a/constants.go b/constants.go index 30adf42be55ca..42955b6f5f138 100644 --- a/constants.go +++ b/constants.go @@ -286,7 +286,7 @@ const ( // ComponentProxySecureGRPC represents a secure gRPC server running on Proxy (used for Kube). ComponentProxySecureGRPC = "proxy:secure-grpc" - // ComponentUpdater represents the agent updater. + // ComponentUpdater represents the teleport-update binary. ComponentUpdater = "updater" // ComponentGit represents git proxy related services. diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 0f6c85be8a5e7..96552aeae7778 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -6858,7 +6858,7 @@ func (a *Server) ExportUpgradeWindows(ctx context.Context, req proto.ExportUpgra } switch req.UpgraderKind { - case "": + case "", types.UpgraderKindTeleportUpdate: rsp.CanonicalSchedule = cached.CanonicalSchedule.Clone() case types.UpgraderKindKubeController: rsp.KubeControllerSchedule = cached.KubeControllerSchedule diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go new file mode 100644 index 0000000000000..fe10d84808d11 --- /dev/null +++ b/lib/autoupdate/agent/config.go @@ -0,0 +1,251 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/lib/autoupdate" +) + +const ( + // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update. + updateConfigName = "update.yaml" + + // UpdateConfig metadata + updateConfigVersion = "v1" + updateConfigKind = "update_config" +) + +// UpdateConfig describes the update.yaml file schema. +type UpdateConfig struct { + // Version of the configuration file + Version string `yaml:"version"` + // Kind of configuration file (always "update_config") + Kind string `yaml:"kind"` + // Spec contains user-specified configuration. + Spec UpdateSpec `yaml:"spec"` + // Status contains state configuration. + Status UpdateStatus `yaml:"status"` +} + +// UpdateSpec describes the spec field in update.yaml. +type UpdateSpec struct { + // Proxy address + Proxy string `yaml:"proxy"` + // Path is the location the Teleport binaries are linked into. + Path string `yaml:"path"` + // Group specifies the update group identifier for the agent. + Group string `yaml:"group,omitempty"` + // BaseURL is CDN base URL used for the Teleport tgz download URL. + BaseURL string `yaml:"base_url,omitempty"` + // Enabled controls whether auto-updates are enabled. + Enabled bool `yaml:"enabled"` + // Pinned controls whether the active_version is pinned. + Pinned bool `yaml:"pinned"` +} + +// UpdateStatus describes the status field in update.yaml. +type UpdateStatus struct { + // Active is the currently active revision of Teleport. + Active Revision `yaml:"active"` + // Backup is the last working revision of Teleport. + Backup *Revision `yaml:"backup,omitempty"` + // Skip is the skipped revision of Teleport. + // Skipped revisions are not applied because they + // are known to crash. + Skip *Revision `yaml:"skip,omitempty"` +} + +// Revision is a version and edition of Teleport. +type Revision struct { + // Version is the version of Teleport. + Version string `yaml:"version" json:"version"` + // Flags describe the edition of Teleport. + Flags autoupdate.InstallFlags `yaml:"flags,flow,omitempty" json:"flags,omitempty"` +} + +// NewRevision create a Revision. +// If version is not set, no flags are returned. +// This ensures that all Revisions without versions are zero-valued. +func NewRevision(version string, flags autoupdate.InstallFlags) Revision { + if version != "" { + return Revision{ + Version: version, + Flags: flags, + } + } + return Revision{} +} + +// NewRevisionFromDir translates a directory path containing Teleport into a Revision. +func NewRevisionFromDir(dir string) (Revision, error) { + parts := strings.Split(dir, "_") + var out Revision + if len(parts) == 0 { + return out, trace.Errorf("dir name empty") + } + out.Version = parts[0] + if out.Version == "" { + return out, trace.Errorf("version missing in dir %s", dir) + } + switch flags := parts[1:]; len(flags) { + case 2: + if flags[1] != autoupdate.FlagFIPS.DirFlag() { + break + } + out.Flags |= autoupdate.FlagFIPS + fallthrough + case 1: + if flags[0] != autoupdate.FlagEnterprise.DirFlag() { + break + } + out.Flags |= autoupdate.FlagEnterprise + fallthrough + case 0: + return out, nil + } + return out, trace.Errorf("invalid flag in %s", dir) +} + +// Dir returns the directory path name of a Revision. +func (r Revision) Dir() string { + // Do not change the order of these statements. + // Otherwise, installed versions will no longer match update.yaml. + var suffix string + if r.Flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0 { + suffix += "_" + autoupdate.FlagEnterprise.DirFlag() + } + if r.Flags&autoupdate.FlagFIPS != 0 { + suffix += "_" + autoupdate.FlagFIPS.DirFlag() + } + return r.Version + suffix +} + +// String returns a human-readable description of a Teleport revision. +func (r Revision) String() string { + if flags := r.Flags.Strings(); len(flags) > 0 { + return fmt.Sprintf("%s+%s", r.Version, strings.Join(flags, "+")) + } + return r.Version +} + +// readConfig reads UpdateConfig from a file. +func readConfig(path string) (*UpdateConfig, error) { + f, err := os.Open(path) + if errors.Is(err, fs.ErrNotExist) { + return &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + }, nil + } + if err != nil { + return nil, trace.Wrap(err, "failed to open") + } + defer f.Close() + var cfg UpdateConfig + if err := yaml.NewDecoder(f).Decode(&cfg); err != nil { + return nil, trace.Wrap(err, "failed to parse") + } + if k := cfg.Kind; k != updateConfigKind { + return nil, trace.Errorf("invalid kind %s", k) + } + if v := cfg.Version; v != updateConfigVersion { + return nil, trace.Errorf("invalid version %s", v) + } + return &cfg, nil +} + +// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted. +func writeConfig(filename string, cfg *UpdateConfig) error { + opts := []renameio.Option{ + renameio.WithPermissions(configFileMode), + renameio.WithExistingPermissions(), + renameio.WithTempDir(filepath.Dir(filename)), + } + t, err := renameio.NewPendingFile(filename, opts...) + if err != nil { + return trace.Wrap(err) + } + defer t.Cleanup() + err = yaml.NewEncoder(t).Encode(cfg) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(t.CloseAtomicallyReplace()) +} + +func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error { + if override.Proxy != "" { + spec.Proxy = override.Proxy + } + if override.Path != "" { + spec.Path = override.Path + } + if override.Group != "" { + spec.Group = override.Group + } + switch override.BaseURL { + case "": + case "default": + spec.BaseURL = "" + default: + spec.BaseURL = override.BaseURL + } + if spec.BaseURL != "" && + !strings.HasPrefix(strings.ToLower(spec.BaseURL), "https://") { + return trace.Errorf("Teleport download base URL %s must use TLS (https://)", spec.BaseURL) + } + if override.Enabled { + spec.Enabled = true + } + if override.Pinned { + spec.Pinned = true + } + return nil +} + +// Status of the agent auto-updates system. +type Status struct { + UpdateSpec `yaml:",inline"` + UpdateStatus `yaml:",inline"` + FindResp `yaml:",inline"` +} + +// FindResp summarizes the auto-update status response from cluster. +type FindResp struct { + // Target revision of Teleport to install + Target Revision `yaml:"target"` + // InWindow is true when the install should happen now. + InWindow bool `yaml:"in_window"` + // Jitter duration before an automated install + Jitter time.Duration `yaml:"jitter"` + // AGPL installations cannot use the official CDN. + AGPL bool `yaml:"agpl,omitempty"` +} diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go new file mode 100644 index 0000000000000..39d318cd6ee4c --- /dev/null +++ b/lib/autoupdate/agent/config_test.go @@ -0,0 +1,127 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/autoupdate" +) + +func TestNewRevisionFromDir(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + dir string + rev Revision + errMatch string + }{ + { + name: "version", + dir: "1.2.3", + rev: Revision{ + Version: "1.2.3", + }, + }, + { + name: "full", + dir: "1.2.3_ent_fips", + rev: Revision{ + Version: "1.2.3", + Flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS, + }, + }, + { + name: "ent", + dir: "1.2.3_ent", + rev: Revision{ + Version: "1.2.3", + Flags: autoupdate.FlagEnterprise, + }, + }, + { + name: "empty", + errMatch: "missing", + }, + { + name: "trailing", + dir: "1.2.3_", + errMatch: "invalid", + }, + { + name: "more trailing", + dir: "1.2.3___", + errMatch: "invalid", + }, + { + name: "no version", + dir: "_fips", + errMatch: "missing", + }, + { + name: "fips no ent", + dir: "1.2.3_fips", + errMatch: "invalid", + }, + { + name: "unknown start fips", + dir: "1.2.3_test_fips", + errMatch: "invalid", + }, + { + name: "unknown start ent", + dir: "1.2.3_test_ent", + errMatch: "invalid", + }, + { + name: "unknown end fips", + dir: "1.2.3_fips_test", + errMatch: "invalid", + }, + { + name: "unknown end ent", + dir: "1.2.3_ent_test", + errMatch: "invalid", + }, + { + name: "bad order", + dir: "1.2.3_fips_ent", + errMatch: "invalid", + }, + { + name: "underscore", + dir: "_", + errMatch: "missing", + }, + } { + t.Run(tt.name, func(t *testing.T) { + rev, err := NewRevisionFromDir(tt.dir) + if tt.errMatch != "" { + require.ErrorContains(t, err, tt.errMatch) + return + } + require.NoError(t, err) + require.Equal(t, tt.rev, rev) + require.Equal(t, tt.dir, rev.Dir()) + }) + } +} diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index e31813866eacf..29afb6d9da7aa 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -29,19 +29,37 @@ import ( "log/slog" "net/http" "os" + "path" "path/filepath" - "runtime" - "text/template" + "syscall" "time" + "github.com/google/renameio/v2" "github.com/gravitational/trace" + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/autoupdate" "github.com/gravitational/teleport/lib/utils" ) const ( - checksumType = "sha256" + // checksumType for Teleport tgzs + checksumType = "sha256" + // checksumHexLen is the length of the Teleport checksum. checksumHexLen = sha256.Size * 2 // bytes to hex + // maxServiceFileSize is the maximum size allowed for a systemd service file. + maxServiceFileSize = 1_000_000 // 1 MB + // configFileMode is the mode used for new configuration files. + configFileMode = 0644 + // systemDirMode is the mode used for new directories. + systemDirMode = 0755 +) + +const ( + // serviceDir contains the relative path to the Teleport SystemD service dir. + serviceDir = "lib/systemd/system" + // serviceName contains the upstream name of the Teleport SystemD service file. + serviceName = "teleport.service" ) // LocalInstaller manages the creation and removal of installations @@ -49,6 +67,12 @@ const ( type LocalInstaller struct { // InstallDir contains each installation, named by version. InstallDir string + // TargetServiceFile contains a copy of the linked installation's systemd service. + TargetServiceFile string + // SystemBinDir contains binaries for the system (packaged) install of Teleport. + SystemBinDir string + // SystemServiceFile contains the systemd service file for the system (packaged) install of Teleport. + SystemServiceFile string // HTTP is an HTTP client for downloading Teleport. HTTP *http.Client // Log contains a logger. @@ -57,17 +81,31 @@ type LocalInstaller struct { ReservedFreeTmpDisk uint64 // ReservedFreeInstallDisk is the amount of disk that must remain free in the install directory. ReservedFreeInstallDisk uint64 + // TransformService transforms the systemd service during copying. + TransformService func([]byte) []byte + // ValidateBinary returns true if a file is a linkable binary, or + // false if a file should not be linked. + ValidateBinary func(ctx context.Context, path string) (bool, error) + // Template is download URI Template of Teleport packages. + Template string } // Remove a Teleport version directory from InstallDir. // This function is idempotent. -func (li *LocalInstaller) Remove(ctx context.Context, version string) error { - versionDir := filepath.Join(li.InstallDir, version) - sumPath := filepath.Join(versionDir, checksumType) +// See Installer interface for additional specs. +func (li *LocalInstaller) Remove(ctx context.Context, rev Revision) error { + // os.RemoveAll is dangerous because it can remove an entire directory tree. + // We must validate the version to ensure that we remove only a single path + // element under the InstallDir, and not InstallDir or its parents. + // revisionDir performs these validations. + versionDir, err := li.revisionDir(rev) + if err != nil { + return trace.Wrap(err) + } // invalidate checksum first, to protect against partially-removed // directory with valid checksum. - err := os.Remove(sumPath) + err = os.Remove(filepath.Join(versionDir, checksumType)) if err != nil && !errors.Is(err, os.ErrNotExist) { return trace.Wrap(err) } @@ -79,12 +117,16 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error { // Install a Teleport version directory in InstallDir. // This function is idempotent. -func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { - versionDir := filepath.Join(li.InstallDir, version) +// See Installer interface for additional specs. +func (li *LocalInstaller) Install(ctx context.Context, rev Revision, baseURL string, force bool) (err error) { + versionDir, err := li.revisionDir(rev) + if err != nil { + return trace.Wrap(err) + } sumPath := filepath.Join(versionDir, checksumType) - // generate download URI from template - uri, err := makeURL(template, version, flags) + // generate download URI from Template + uri, err := autoupdate.MakeURL(li.Template, baseURL, autoupdate.DefaultPackage, rev.Version, rev.Flags) if err != nil { return trace.Wrap(err) } @@ -94,33 +136,37 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, checksumURI := uri + "." + checksumType newSum, err := li.getChecksum(ctx, checksumURI) if err != nil { - return trace.Errorf("failed to download checksum from %s: %w", checksumURI, err) + return trace.Wrap(err, "failed to download checksum from %s", checksumURI) } oldSum, err := readChecksum(sumPath) - if err == nil { - if bytes.Equal(oldSum, newSum) { - li.Log.InfoContext(ctx, "Version already present.", "version", version) - return nil - } - li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", version) - if err := li.Remove(ctx, version); err != nil { - return trace.Wrap(err) - } + versionPresent := err == nil + if versionPresent && bytes.Equal(oldSum, newSum) { + li.Log.InfoContext(ctx, "Version already present.", "version", rev) + return nil + } + if versionPresent { + li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", rev) } else if !errors.Is(err, os.ErrNotExist) { - li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", version, "error", err) - if err := li.Remove(ctx, version); err != nil { - return trace.Wrap(err) + li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", rev, "error", err) + } + if versionPresent || !errors.Is(err, os.ErrNotExist) { + if force { + if err := li.Remove(ctx, rev); err != nil { + return trace.Wrap(err) + } + } else { + return trace.Errorf("refusing to remove linked installation of Teleport") } } // Verify that we have enough free temp space, then download tgz freeTmp, err := utils.FreeDiskWithReserve(os.TempDir(), li.ReservedFreeTmpDisk) if err != nil { - return trace.Errorf("failed to calculate free disk: %w", err) + return trace.Wrap(err, "failed to calculate free disk") } f, err := os.CreateTemp("", "teleport-update-") if err != nil { - return trace.Errorf("failed to create temporary file: %w", err) + return trace.Wrap(err, "failed to create temporary file") } defer func() { _ = f.Close() // data never read after close @@ -130,13 +176,19 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, }() pathSum, err := li.download(ctx, f, int64(freeTmp), uri) if err != nil { - return trace.Errorf("failed to download teleport: %w", err) + return trace.Wrap(err, "failed to download teleport") } - // Seek to the start of the tgz file after writing if _, err := f.Seek(0, io.SeekStart); err != nil { - return trace.Errorf("failed seek to start of download: %w", err) + return trace.Wrap(err, "failed seek to start of download") } + + // If interrupted, close the file immediately to stop extracting. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + context.AfterFunc(ctx, func() { + _ = f.Close() // safe to close file multiple times + }) // Check integrity before decompression if !bytes.Equal(newSum, pathSum) { return trace.Errorf("mismatched checksum, download possibly corrupt") @@ -144,47 +196,33 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string, // Get uncompressed size of the tgz n, err := uncompressedSize(f) if err != nil { - return trace.Errorf("failed to determine uncompressed size: %w", err) + return trace.Wrap(err, "failed to determine uncompressed size") } // Seek to start of tgz after reading size if _, err := f.Seek(0, io.SeekStart); err != nil { - return trace.Errorf("failed seek to start: %w", err) + return trace.Wrap(err, "failed seek to start") } - if err := li.extract(ctx, versionDir, f, n); err != nil { - return trace.Errorf("failed to extract teleport: %w", err) + + // If there's an error after we start extracting, delete the version dir. + defer func() { + if err != nil { + if err := os.RemoveAll(versionDir); err != nil { + li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir) + } + } + }() + + // Extract tgz into version directory. + if err := li.extract(ctx, versionDir, f, n, rev.Flags); err != nil { + return trace.Wrap(err, "failed to extract teleport") } // Write the checksum last. This marks the version directory as valid. - err = os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), 0755) - if err != nil { - return trace.Errorf("failed to write checksum: %w", err) + if err := os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), configFileMode); err != nil { + return trace.Wrap(err, "failed to write checksum") } return nil } -// makeURL to download the Teleport tgz. -func makeURL(uriTmpl, version string, flags InstallFlags) (string, error) { - tmpl, err := template.New("uri").Parse(uriTmpl) - if err != nil { - return "", trace.Wrap(err) - } - var uriBuf bytes.Buffer - params := struct { - OS, Version, Arch string - FIPS, Enterprise bool - }{ - OS: runtime.GOOS, - Version: version, - Arch: runtime.GOARCH, - FIPS: flags&FlagFIPS != 0, - Enterprise: flags&(FlagEnterprise|FlagFIPS) != 0, - } - err = tmpl.Execute(&uriBuf, params) - if err != nil { - return "", trace.Wrap(err) - } - return uriBuf.String(), nil -} - // readChecksum from the version directory. func readChecksum(path string) ([]byte, error) { f, err := os.Open(path) @@ -208,6 +246,7 @@ func readChecksum(path string) ([]byte, error) { func (li *LocalInstaller) getChecksum(ctx context.Context, url string) ([]byte, error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, trace.Wrap(err) @@ -241,6 +280,7 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, if err != nil { return nil, trace.Wrap(err) } + startTime := time.Now() resp, err := li.HTTP.Do(req) if err != nil { return nil, trace.Wrap(err) @@ -264,42 +304,71 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, } // Calculate checksum concurrently with download. shaReader := sha256.New() - n, err := io.CopyN(w, io.TeeReader(resp.Body, shaReader), size) + tee := io.TeeReader(resp.Body, shaReader) + tee = io.TeeReader(tee, &progressLogger{ + ctx: ctx, + log: li.Log, + level: slog.LevelInfo, + name: path.Base(resp.Request.URL.Path), + max: int(resp.ContentLength), + lines: 5, + }) + n, err := io.CopyN(w, tee, size) if err != nil { return nil, trace.Wrap(err) } if resp.ContentLength >= 0 && n != resp.ContentLength { return nil, trace.Errorf("mismatch in Teleport download size") } + li.Log.InfoContext(ctx, "Download complete.", "duration", time.Since(startTime), "size", n) return shaReader.Sum(nil), nil } -func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64) error { - if err := os.MkdirAll(dstDir, 0755); err != nil { +func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64, flags autoupdate.InstallFlags) error { + if err := os.MkdirAll(dstDir, systemDirMode); err != nil { return trace.Wrap(err) } free, err := utils.FreeDiskWithReserve(dstDir, li.ReservedFreeInstallDisk) if err != nil { - return trace.Errorf("failed to calculate free disk in %q: %w", dstDir, err) + return trace.Wrap(err, "failed to calculate free disk in %s", dstDir) } // Bail if there's not enough free disk space at the target if d := int64(free) - max; d < 0 { - return trace.Errorf("%q needs %d additional bytes of disk space for decompression", dstDir, -d) + return trace.Errorf("%s needs %d additional bytes of disk space for decompression", dstDir, -d) } zr, err := gzip.NewReader(src) if err != nil { - return trace.Errorf("requires gzip-compressed body: %v", err) + return trace.Wrap(err, "requires gzip-compressed body") } li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max) - // TODO(sclevine): add variadic arg to Extract to extract teleport/ subdir into bin/. - err = utils.Extract(zr, dstDir) + err = utils.Extract(zr, dstDir, tgzExtractPaths(flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0)...) if err != nil { return trace.Wrap(err) } return nil } +// tgzExtractPaths describes how to extract the Teleport tgz. +// See utils.Extract for more details on how this list is parsed. +// Paths must use tarball-style / separators (not filepath). +func tgzExtractPaths(ent bool) []utils.ExtractPath { + prefix := "teleport" + if ent { + prefix += "-ent" + } + return []utils.ExtractPath{ + {Src: path.Join(prefix, "examples/systemd/teleport.service"), Dst: filepath.Join(serviceDir, serviceName), DirMode: systemDirMode}, + {Src: path.Join(prefix, "examples"), Skip: true, DirMode: systemDirMode}, + {Src: path.Join(prefix, "install"), Skip: true, DirMode: systemDirMode}, + {Src: path.Join(prefix, "README.md"), Dst: "share/README.md", DirMode: systemDirMode}, + {Src: path.Join(prefix, "CHANGELOG.md"), Dst: "share/CHANGELOG.md", DirMode: systemDirMode}, + {Src: path.Join(prefix, "VERSION"), Dst: "share/VERSION", DirMode: systemDirMode}, + {Src: path.Join(prefix, "LICENSE-community"), Dst: "share/LICENSE-community", DirMode: systemDirMode}, + {Src: prefix, Dst: "bin", DirMode: systemDirMode}, + } +} + func uncompressedSize(f io.Reader) (int64, error) { // NOTE: The gzip length trailer is very unreliable, // but we could optimize this in the future if @@ -315,3 +384,500 @@ func uncompressedSize(f io.Reader) (int64, error) { } return n, nil } + +// List installed versions of Teleport. +func (li *LocalInstaller) List(ctx context.Context) (revs []Revision, err error) { + entries, err := os.ReadDir(li.InstallDir) + if err != nil { + return nil, trace.Wrap(err) + } + for _, entry := range entries { + if !entry.IsDir() { + continue + } + rev, err := NewRevisionFromDir(entry.Name()) + if err != nil { + return nil, trace.Wrap(err) + } + revs = append(revs, rev) + } + return revs, nil +} + +// Link the specified version into pathDir and TargetServiceFile. +// The revert function restores the previous linking. +// If force is true, Link will overwrite files that are not symlinks. +// See Installer interface for additional specs. +func (li *LocalInstaller) Link(ctx context.Context, rev Revision, pathDir string, force bool) (revert func(context.Context) bool, err error) { + revert = func(context.Context) bool { return true } + versionDir, err := li.revisionDir(rev) + if err != nil { + return revert, trace.Wrap(err) + } + revert, err = li.forceLinks(ctx, + filepath.Join(versionDir, "bin"), + filepath.Join(versionDir, serviceDir, serviceName), + pathDir, + force, + ) + if err != nil { + return revert, trace.Wrap(err) + } + return revert, nil +} + +// LinkSystem links the system (package) version into defaultPathDir and TargetServiceFile. +// This prevents namespaced installations in /opt/teleport from linking to the system package. +// The revert function restores the previous linking. +// See Installer interface for additional specs. +func (li *LocalInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) { + revert, err = li.forceLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, false) + return revert, trace.Wrap(err) +} + +// TryLink links the specified version into pathDir, but only in the case that +// no installation of Teleport is already linked or partially linked. +// See Installer interface for additional specs. +func (li *LocalInstaller) TryLink(ctx context.Context, revision Revision, pathDir string) error { + versionDir, err := li.revisionDir(revision) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(li.tryLinks(ctx, + filepath.Join(versionDir, "bin"), + filepath.Join(versionDir, serviceDir, serviceName), + pathDir, + )) +} + +// TryLinkSystem links the system installation to defaultPathDir, but only in the case that +// no installation of Teleport is already linked or partially linked. +// See Installer interface for additional specs. +func (li *LocalInstaller) TryLinkSystem(ctx context.Context) error { + return trace.Wrap(li.tryLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir)) +} + +// Unlink unlinks a version from pathDir and TargetServiceFile. +// See Installer interface for additional specs. +func (li *LocalInstaller) Unlink(ctx context.Context, rev Revision, pathDir string) error { + versionDir, err := li.revisionDir(rev) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(li.removeLinks(ctx, + filepath.Join(versionDir, "bin"), + filepath.Join(versionDir, serviceDir, serviceName), + pathDir, + )) +} + +// UnlinkSystem unlinks the system (package) version from defaultPathDir and TargetServiceFile. +// See Installer interface for additional specs. +func (li *LocalInstaller) UnlinkSystem(ctx context.Context) error { + return trace.Wrap(li.removeLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir)) +} + +// symlink from oldname to newname +type symlink struct { + oldname, newname string +} + +// smallFile is a file small enough to be stored in memory. +type smallFile struct { + name string + data []byte + mode os.FileMode +} + +// forceLinks replaces binary links and service files using files in binDir and svcDir. +// Existing links and files are replaced, but mismatched links and files will result in error. +// forceLinks will revert any overridden links or files if it hits an error. +// If successful, forceLinks may also be reverted after it returns by calling revert. +// The revert function returns true if reverting succeeds. +// If force is true, non-link files will be overwritten. +func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, force bool) (revert func(context.Context) bool, err error) { + // setup revert function + var ( + revertLinks []symlink + revertFiles []smallFile + ) + revert = func(ctx context.Context) bool { + // This function is safe to call repeatedly. + // Returns true only when all changes are successfully reverted. + var ( + keepLinks []symlink + keepFiles []smallFile + ) + for _, l := range revertLinks { + err := renameio.Symlink(l.oldname, l.newname) + if err != nil { + keepLinks = append(keepLinks, l) + li.Log.ErrorContext(ctx, "Failed to revert symlink", "oldname", l.oldname, "newname", l.newname, errorKey, err) + } + } + for _, f := range revertFiles { + err := writeFileAtomicWithinDir(f.name, f.data, f.mode) + if err != nil { + keepFiles = append(keepFiles, f) + li.Log.ErrorContext(ctx, "Failed to revert files", "name", f.name, errorKey, err) + } + } + revertLinks = keepLinks + revertFiles = keepFiles + return len(revertLinks) == 0 && len(revertFiles) == 0 + } + // revert immediately on error, so caller can ignore revert arg + defer func() { + if err != nil { + revert(ctx) + } + }() + + // ensure source directory exists + entries, err := os.ReadDir(srcBinDir) + if errors.Is(err, os.ErrNotExist) { + return revert, trace.Wrap(ErrNoBinaries) + } + if err != nil { + return revert, trace.Wrap(err, "failed to read Teleport binary directory") + } + + // ensure target directories exist before trying to create links + err = os.MkdirAll(dstBinDir, systemDirMode) + if err != nil { + return revert, trace.Wrap(err) + } + err = os.MkdirAll(filepath.Dir(li.TargetServiceFile), systemDirMode) + if err != nil { + return revert, trace.Wrap(err) + } + + // create binary links + var linked int + for _, entry := range entries { + if entry.IsDir() { + continue + } + oldname := filepath.Join(srcBinDir, entry.Name()) + newname := filepath.Join(dstBinDir, entry.Name()) + exec, err := li.ValidateBinary(ctx, oldname) + if err != nil { + return revert, trace.Wrap(err) + } + if !exec { + continue + } + orig, err := forceLink(oldname, newname, force) + if err != nil && !errors.Is(err, os.ErrExist) { + return revert, trace.Wrap(err, "failed to create symlink for %s", entry.Name()) + } + if orig != "" { + revertLinks = append(revertLinks, symlink{ + oldname: orig, + newname: newname, + }) + } + linked++ + } + if linked == 0 { + return revert, trace.Wrap(ErrNoBinaries) + } + + // create systemd service file + + orig, err := li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize) + if err != nil && !errors.Is(err, os.ErrExist) { + return revert, trace.Wrap(err, "failed to copy service") + } + if orig != nil { + revertFiles = append(revertFiles, *orig) + } + return revert, nil +} + +// forceCopyService uses forceCopy to copy a systemd service file from src to dst. +// The contents of both src and dst must be smaller than n. +// See forceCopy for more details. +func (li *LocalInstaller) forceCopyService(dst, src string, n int64) (orig *smallFile, err error) { + srcData, err := readFileAtMost(src, n) + if err != nil { + return nil, trace.Wrap(err) + } + return forceCopy(dst, li.TransformService(srcData), n) +} + +// forceLink attempts to create a symlink, atomically replacing an existing link if already present. +// If a non-symlink file or directory exists in newname already and force is false, forceLink errors with ErrFilePresent. +// If the link is already present with the desired oldname, forceLink returns os.ErrExist. +func forceLink(oldname, newname string, force bool) (orig string, err error) { + orig, err = os.Readlink(newname) + if errors.Is(err, os.ErrInvalid) || + errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper + if force { + return "", trace.Wrap(renameio.Symlink(oldname, newname)) + } + // important: do not attempt to replace a non-linked install of Teleport without force + return "", trace.Wrap(ErrFilePresent, "refusing to replace file at %s", newname) + } else if err != nil && !errors.Is(err, os.ErrNotExist) { + return "", trace.Wrap(err) + } + if orig == oldname { + return "", trace.Wrap(os.ErrExist) + } + err = renameio.Symlink(oldname, newname) + if err != nil { + return "", trace.Wrap(err) + } + return orig, nil +} + +// forceCopy atomically copies a file from srcData to dst, replacing an existing file at dst if needed. +// The contents of dst must be smaller than n. +// forceCopy returns the original file path, mode, and contents as orig. +// If an irregular file, too large file, or directory exists in dst already, forceCopy errors. +// If the file is already present with the desired contents, forceCopy returns os.ErrExist. +func forceCopy(dst string, srcData []byte, n int64) (orig *smallFile, err error) { + fi, err := os.Lstat(dst) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, trace.Wrap(err) + } + if err == nil { + orig = &smallFile{ + name: dst, + mode: fi.Mode(), + } + if !orig.mode.IsRegular() { + return nil, trace.Errorf("refusing to replace irregular file at %s", dst) + } + orig.data, err = readFileAtMost(dst, n) + if err != nil { + return nil, trace.Wrap(err) + } + if bytes.Equal(srcData, orig.data) { + return nil, trace.Wrap(os.ErrExist) + } + } + err = writeFileAtomicWithinDir(dst, srcData, configFileMode) + if err != nil { + return nil, trace.Wrap(err) + } + return orig, nil +} + +// writeFileAtomicWithinDir atomically creates a new file with renameio, while ensuring that temporary +// files use the same directory as the target file (with format: .[base][randints]). +// This ensures that SELinux contexts for important files are set correctly. +func writeFileAtomicWithinDir(filename string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(filename) + err := renameio.WriteFile(filename, data, perm, renameio.WithTempDir(dir)) + return trace.Wrap(err) +} + +// readFileAtMost reads a file up to n, or errors if it is too large. +func readFileAtMost(name string, n int64) ([]byte, error) { + f, err := os.Open(name) + if err != nil { + return nil, err + } + defer f.Close() + data, err := utils.ReadAtMost(f, n) + return data, trace.Wrap(err) +} + +func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error { + removeService := false + entries, err := os.ReadDir(srcBinDir) + if err != nil { + return trace.Wrap(err, "failed to find Teleport binary directory") + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + oldname := filepath.Join(srcBinDir, entry.Name()) + newname := filepath.Join(dstBinDir, entry.Name()) + v, err := os.Readlink(newname) + if errors.Is(err, os.ErrNotExist) || + errors.Is(err, os.ErrInvalid) || + errors.Is(err, syscall.EINVAL) { + li.Log.DebugContext(ctx, "Link not present.", "oldname", oldname, "newname", newname) + continue + } + if err != nil { + return trace.Wrap(err, "error reading link for %s", filepath.Base(newname)) + } + if v != oldname { + li.Log.DebugContext(ctx, "Skipping link to different binary.", "oldname", oldname, "newname", newname) + continue + } + if err := os.Remove(newname); err != nil { + li.Log.ErrorContext(ctx, "Unable to remove link.", "oldname", oldname, "newname", newname, errorKey, err) + continue + } + if filepath.Base(newname) == teleport.ComponentTeleport { + removeService = true + } + } + // only remove service if teleport was removed + if !removeService { + li.Log.DebugContext(ctx, "Teleport binary not unlinked. Skipping removal of teleport.service.") + return nil + } + srcBytes, err := readFileAtMost(srcSvcFile, maxServiceFileSize) + if err != nil { + return trace.Wrap(err) + } + dstBytes, err := readFileAtMost(li.TargetServiceFile, maxServiceFileSize) + if errors.Is(err, os.ErrNotExist) { + li.Log.DebugContext(ctx, "Service not present.", "path", li.TargetServiceFile) + return nil + } + if err != nil { + return trace.Wrap(err) + } + if !bytes.Equal(li.TransformService(srcBytes), dstBytes) { + li.Log.WarnContext(ctx, "Removed teleport binary link, but skipping removal of custom teleport.service: the service file does not match the reference file for this version. The file might have been manually edited.") + return nil + } + if err := os.Remove(li.TargetServiceFile); err != nil { + return trace.Wrap(err, "error removing copy of %s", filepath.Base(li.TargetServiceFile)) + } + return nil +} + +// tryLinks create binary and service links for files in binDir and svcDir if links are not already present. +// Existing links that point to files outside binDir or svcDir, as well as existing non-link files, will error. +// tryLinks will not attempt to create any links if linking could result in an error. +// However, concurrent changes to links may result in an error with partially-complete linking. +func (li *LocalInstaller) tryLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error { + // ensure source directory exists + entries, err := os.ReadDir(srcBinDir) + if errors.Is(err, os.ErrNotExist) { + return trace.Wrap(ErrNoBinaries) + } + if err != nil { + return trace.Wrap(err, "failed to read Teleport binary directory") + } + + // ensure target directories exist before trying to create links + err = os.MkdirAll(dstBinDir, systemDirMode) + if err != nil { + return trace.Wrap(err) + } + err = os.MkdirAll(filepath.Dir(li.TargetServiceFile), systemDirMode) + if err != nil { + return trace.Wrap(err) + } + + // validate that we can link all system binaries before attempting linking + var links []symlink + var linked int + for _, entry := range entries { + if entry.IsDir() { + continue + } + oldname := filepath.Join(srcBinDir, entry.Name()) + newname := filepath.Join(dstBinDir, entry.Name()) + exec, err := li.ValidateBinary(ctx, oldname) + if err != nil { + return trace.Wrap(err) + } + if !exec { + continue + } + ok, err := needsLink(oldname, newname) + if err != nil { + return trace.Wrap(err, "error evaluating link for %s", filepath.Base(oldname)) + } + if ok { + links = append(links, symlink{oldname, newname}) + } + linked++ + } + // bail if no binaries can be linked + if linked == 0 { + return trace.Wrap(ErrNoBinaries) + } + + // link binaries that are missing links + for _, link := range links { + if err := os.Symlink(link.oldname, link.newname); err != nil { + return trace.Wrap(err, "failed to create symlink for %s", filepath.Base(link.oldname)) + } + } + + // if any binaries are linked from srcBinDir, always link the service from svcDir + _, err = li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize) + if err != nil && !errors.Is(err, os.ErrExist) { + return trace.Wrap(err, "failed to copy service") + } + + return nil +} + +// needsLink returns true when a symlink from oldname to newname needs to be created, or false if it exists. +// If a non-symlink file or directory exists at newname, needsLink errors with ErrFilePresent. +// If a symlink to a different location exists, needsLink errors with ErrLinked. +func needsLink(oldname, newname string) (ok bool, err error) { + orig, err := os.Readlink(newname) + if errors.Is(err, os.ErrInvalid) || + errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper + // important: do not attempt to replace a non-linked install of Teleport + return false, trace.Wrap(ErrFilePresent, "refusing to replace file at %s", newname) + } + if errors.Is(err, os.ErrNotExist) { + return true, nil + } + if err != nil { + return false, trace.Wrap(err) + } + if orig != oldname { + return false, trace.Wrap(ErrLinked, "refusing to replace link at %s", newname) + } + return false, nil +} + +// revisionDir returns the storage directory for a Teleport revision. +// revisionDir will fail if the revision cannot be used to construct the directory name. +// For example, it ensures that ".." cannot be provided to return a system directory. +func (li *LocalInstaller) revisionDir(rev Revision) (string, error) { + installDir, err := filepath.Abs(li.InstallDir) + if err != nil { + return "", trace.Wrap(err) + } + versionDir := filepath.Join(installDir, rev.Dir()) + if filepath.Dir(versionDir) != filepath.Clean(installDir) { + return "", trace.Errorf("refusing to link directory outside of version directory") + } + return versionDir, nil +} + +// IsLinked returns true if any binaries for Revision rev are linked to pathDir. +// Returns os.ErrNotExist error if the revision does not exist. +func (li *LocalInstaller) IsLinked(ctx context.Context, rev Revision, pathDir string) (bool, error) { + versionDir, err := li.revisionDir(rev) + if err != nil { + return false, trace.Wrap(err) + } + binDir := filepath.Join(versionDir, "bin") + entries, err := os.ReadDir(binDir) + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + if err != nil { + return false, trace.Wrap(err) + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + v, err := os.Readlink(filepath.Join(pathDir, entry.Name())) + if err != nil { + continue + } + if filepath.Clean(v) == filepath.Join(binDir, entry.Name()) { + return true, nil + } + } + return false, nil +} diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go index be778f7bcf16a..7fbee63c8aa73 100644 --- a/lib/autoupdate/agent/installer_test.go +++ b/lib/autoupdate/agent/installer_test.go @@ -32,15 +32,18 @@ import ( "os" "path/filepath" "runtime" + "slices" "strconv" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/autoupdate" ) -func TestTeleportInstaller_Install(t *testing.T) { +func TestLocalInstaller_Install(t *testing.T) { t.Parallel() const version = "new-version" @@ -51,7 +54,8 @@ func TestTeleportInstaller_Install(t *testing.T) { reservedTmp uint64 reservedInstall uint64 existingSum string - flags InstallFlags + flags autoupdate.InstallFlags + force bool errMatch string }{ @@ -65,10 +69,18 @@ func TestTeleportInstaller_Install(t *testing.T) { { name: "mismatched checksum", existingSum: hex.EncodeToString(sha256.New().Sum(nil)), + force: true, }, { name: "unreadable checksum", existingSum: "bad", + force: true, + }, + { + name: "unreadable checksum, force false", + existingSum: "bad", + force: false, + errMatch: "refusing", }, { name: "out of space in /tmp", @@ -84,12 +96,13 @@ func TestTeleportInstaller_Install(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() + err := os.MkdirAll(filepath.Join(dir, version), os.ModePerm) + require.NoError(t, err) if tt.existingSum != "" { - err := os.WriteFile(filepath.Join(dir, checksumType), []byte(tt.existingSum), os.ModePerm) + err := os.WriteFile(filepath.Join(dir, version, checksumType), []byte(tt.existingSum), os.ModePerm) require.NoError(t, err) } @@ -122,9 +135,10 @@ func TestTeleportInstaller_Install(t *testing.T) { Log: slog.Default(), ReservedFreeTmpDisk: tt.reservedTmp, ReservedFreeInstallDisk: tt.reservedInstall, + Template: "{{.BaseURL}}/{{.Package}}-{{.OS}}/{{.Arch}}/{{.Version}}", } ctx := context.Background() - err := installer.Install(ctx, version, server.URL+"/{{.OS}}/{{.Arch}}/{{.Version}}", tt.flags) + err = installer.Install(ctx, NewRevision(version, tt.flags), server.URL, tt.force) if tt.errMatch != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errMatch) @@ -132,17 +146,24 @@ func TestTeleportInstaller_Install(t *testing.T) { } require.NoError(t, err) - const expectedPath = "/" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version - require.Equal(t, expectedPath, dlPath) + const expectedPath = "/teleport-" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version require.Equal(t, expectedPath+"."+checksumType, shaPath) - teleportVersion, err := os.ReadFile(filepath.Join(dir, version, "teleport")) - require.NoError(t, err) - require.Equal(t, version, string(teleportVersion)) + if tt.existingSum == testSum { + return + } - tshVersion, err := os.ReadFile(filepath.Join(dir, version, "tsh")) - require.NoError(t, err) - require.Equal(t, version, string(tshVersion)) + require.Equal(t, expectedPath, dlPath) + + for _, p := range []string{ + filepath.Join(dir, version, "lib", "systemd", "system", "teleport.service"), + filepath.Join(dir, version, "bin", "teleport"), + filepath.Join(dir, version, "bin", "tsh"), + } { + v, err := os.ReadFile(p) + require.NoError(t, err) + require.Equal(t, version, string(v)) + } sum, err := os.ReadFile(filepath.Join(dir, version, checksumType)) require.NoError(t, err) @@ -163,8 +184,9 @@ func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) { var files = []struct { Name, Body string }{ - {"teleport", version}, - {"tsh", version}, + {"teleport/examples/systemd/teleport.service", version}, + {"teleport/teleport", version}, + {"teleport/tsh", version}, } for _, file := range files { hdr := &tar.Header{ @@ -187,3 +209,955 @@ func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) { } return &buf, hex.EncodeToString(sha.Sum(nil)) } + +func TestLocalInstaller_Link(t *testing.T) { + t.Parallel() + const version = "new-version" + servicePath := filepath.Join(serviceDir, serviceName) + + tests := []struct { + name string + installDirs []string + installFiles []string + installFileMode os.FileMode + existingLinks []string + existingFiles []string + force bool + + resultLinks []string + resultServices []string + errMatch string + }{ + { + name: "present with new links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + + resultLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + resultServices: []string{ + "lib/systemd/system/teleport.service", + }, + }, + { + name: "present with non-executable files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: 0644, + + errMatch: ErrNoBinaries.Error(), + }, + { + name: "present with existing links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + }, + + resultLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + resultServices: []string{ + "lib/systemd/system/teleport.service", + }, + }, + { + name: "conflicting systemd files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + "lib/systemd/system/teleport.service", + }, + + errMatch: "refusing", + }, + { + name: "conflicting bin files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + "bin/tsh", + }, + + errMatch: ErrFilePresent.Error(), + }, + { + name: "overwriting bin files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + "bin/tsh", + }, + force: true, + + resultLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + resultServices: []string{ + "lib/systemd/system/teleport.service", + }, + }, + { + name: "no links", + installFiles: []string{"README"}, + installDirs: []string{"bin"}, + + errMatch: ErrNoBinaries.Error(), + }, + { + name: "no bin directory", + installFiles: []string{"README"}, + + errMatch: ErrNoBinaries.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + versionDir := filepath.Join(versionsDir, version) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + + // setup files in version directory + for _, d := range tt.installDirs { + err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range tt.installFiles { + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode) + require.NoError(t, err) + } + + // setup files in system links directory + linkDir := t.TempDir() + for _, n := range tt.existingLinks { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n)) + require.NoError(t, err) + } + for _, n := range tt.existingFiles { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + + validator := Validator{Log: slog.Default()} + installer := &LocalInstaller{ + InstallDir: versionsDir, + TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName), + Log: slog.Default(), + TransformService: func(b []byte) []byte { + return []byte("[transform]" + string(b)) + }, + ValidateBinary: validator.IsExecutable, + Template: autoupdate.DefaultCDNURITemplate, + } + ctx := context.Background() + revert, err := installer.Link(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"), tt.force) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + + // verify automatic revert + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } + + // ensure revert still succeeds + ok := revert(ctx) + require.True(t, ok) + return + } + require.NoError(t, err) + + // verify links + for _, link := range tt.resultLinks { + v, err := os.ReadFile(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link), string(v)) + } + for _, svc := range tt.resultServices { + v, err := os.ReadFile(filepath.Join(linkDir, svc)) + require.NoError(t, err) + require.Equal(t, "[transform]"+filepath.Base(svc), string(v)) + } + + // verify manual revert + ok := revert(ctx) + require.True(t, ok) + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } + }) + } +} + +func TestLocalInstaller_TryLink(t *testing.T) { + t.Parallel() + const version = "new-version" + servicePath := filepath.Join(serviceDir, serviceName) + + tests := []struct { + name string + installDirs []string + installFiles []string + installFileMode os.FileMode + existingLinks []string + existingFiles []string + + resultLinks []string + resultServices []string + errMatch string + }{ + { + name: "present with new links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + + resultLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + resultServices: []string{ + "lib/systemd/system/teleport.service", + }, + }, + { + name: "present with non-executable files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: 0644, + + errMatch: ErrNoBinaries.Error(), + }, + { + name: "present with existing links", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + }, + existingFiles: []string{ + "lib/systemd/system/teleport.service", + }, + + errMatch: "refusing", + }, + { + name: "conflicting systemd files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingLinks: []string{ + "lib/systemd/system/teleport.service", + }, + + errMatch: "replace irregular file", + }, + { + name: "conflicting bin files", + installDirs: []string{ + "bin", + "bin/somedir", + "lib", + "lib/systemd", + "lib/systemd/system", + "somedir", + }, + installFiles: []string{ + "bin/teleport", + "bin/tsh", + "bin/tbot", + servicePath, + "README", + }, + installFileMode: os.ModePerm, + existingFiles: []string{ + "bin/tsh", + }, + + errMatch: ErrFilePresent.Error(), + }, + { + name: "no links", + installFiles: []string{"README"}, + installDirs: []string{"bin"}, + + errMatch: ErrNoBinaries.Error(), + }, + { + name: "no bin directory", + installFiles: []string{"README"}, + + errMatch: ErrNoBinaries.Error(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + versionDir := filepath.Join(versionsDir, version) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + + // setup files in version directory + for _, d := range tt.installDirs { + err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range tt.installFiles { + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode) + require.NoError(t, err) + } + + // setup files in system links directory + linkDir := t.TempDir() + for _, n := range tt.existingLinks { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n)) + require.NoError(t, err) + } + for _, n := range tt.existingFiles { + err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + + validator := Validator{Log: slog.Default()} + installer := &LocalInstaller{ + InstallDir: versionsDir, + TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName), + Log: slog.Default(), + TransformService: func(b []byte) []byte { + return []byte("[transform]" + string(b)) + }, + ValidateBinary: validator.IsExecutable, + } + ctx := context.Background() + err = installer.TryLink(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin")) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + + // verify no changes + for _, link := range tt.existingLinks { + v, err := os.Readlink(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link)+".old", v) + } + for _, n := range tt.existingFiles { + v, err := os.ReadFile(filepath.Join(linkDir, n)) + require.NoError(t, err) + require.Equal(t, filepath.Base(n), string(v)) + } + return + } + require.NoError(t, err) + + // verify links + for _, link := range tt.resultLinks { + v, err := os.ReadFile(filepath.Join(linkDir, link)) + require.NoError(t, err) + require.Equal(t, filepath.Base(link), string(v)) + } + for _, svc := range tt.resultServices { + v, err := os.ReadFile(filepath.Join(linkDir, svc)) + require.NoError(t, err) + require.Equal(t, "[transform]"+filepath.Base(svc), string(v)) + } + + }) + } +} + +func TestLocalInstaller_Remove(t *testing.T) { + t.Parallel() + const version = "existing-version" + + tests := []struct { + name string + dirs []string + files []string + createVersion string + removeVersion string + + errMatch string + }{ + { + name: "present", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: version, + }, + { + name: "present missing checksum", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{"bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: version, + }, + { + name: "not present", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: "missing-version", + }, + { + name: "version empty", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: "", + + errMatch: "outside", + }, + { + name: "version has path", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: "one/two", + + errMatch: "outside", + }, + { + name: "version is ..", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: "..", + + errMatch: "outside", + }, + { + name: "version is .", + dirs: []string{"bin/somedir", "somedir"}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"}, + createVersion: version, + removeVersion: ".", + + errMatch: "outside", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + versionDir := filepath.Join(versionsDir, tt.createVersion) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + + for _, d := range tt.dirs { + err := os.MkdirAll(filepath.Join(versionDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range tt.files { + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + + linkDir := t.TempDir() + + validator := Validator{Log: slog.Default()} + installer := &LocalInstaller{ + InstallDir: versionsDir, + TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName), + Log: slog.Default(), + TransformService: func(b []byte) []byte { + return []byte("[transform]" + string(b)) + }, + ValidateBinary: validator.IsExecutable, + } + ctx := context.Background() + err = installer.Remove(ctx, NewRevision(tt.removeVersion, 0)) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + _, err = os.Stat(filepath.Join(versionDir, "bin", tt.removeVersion)) + require.ErrorIs(t, err, os.ErrNotExist) + }) + } +} + +func TestLocalInstaller_IsLinked(t *testing.T) { + t.Parallel() + const version = "existing-version" + servicePath := filepath.Join(serviceDir, serviceName) + + tests := []struct { + name string + dirs []string + files []string + createVersion string + linkVersion string + checkVersion string + + result bool + errMatch string + }{ + { + name: "linked", + dirs: []string{"bin/somedir", "somedir", serviceDir}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README", servicePath}, + createVersion: version, + linkVersion: version, + checkVersion: version, + result: true, + }, + { + name: "other linked", + dirs: []string{"bin/somedir", "somedir", serviceDir}, + files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README", servicePath}, + createVersion: version, + linkVersion: version, + checkVersion: "other", + result: false, + }, + { + name: "not linked", + checkVersion: version, + result: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + linkDir := t.TempDir() + + validator := Validator{Log: slog.Default()} + installer := &LocalInstaller{ + InstallDir: versionsDir, + TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName), + Log: slog.Default(), + TransformService: func(b []byte) []byte { + return []byte("[transform]" + string(b)) + }, + ValidateBinary: validator.IsExecutable, + } + ctx := context.Background() + if tt.createVersion != "" { + versionDir := filepath.Join(versionsDir, tt.createVersion) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + + for _, d := range tt.dirs { + err := os.MkdirAll(filepath.Join(versionDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range tt.files { + err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + } + if tt.linkVersion != "" { + _, err := installer.Link(ctx, NewRevision(tt.linkVersion, 0), linkDir, false) + require.NoError(t, err) + } + result, err := installer.IsLinked(ctx, NewRevision(tt.checkVersion, 0), linkDir) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + require.Equal(t, tt.result, result) + }) + } +} + +func TestLocalInstaller_Unlink(t *testing.T) { + t.Parallel() + const version = "existing-version" + servicePath := filepath.Join(serviceDir, serviceName) + + tests := []struct { + name string + bins []string + svcOrig []byte + + links []symlink + svcCopy []byte + + remaining []string + errMatch string + }{ + { + name: "normal", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + svcCopy: []byte("[transform]orig"), + }, + { + name: "different services", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + svcCopy: []byte("custom"), + remaining: []string{servicePath}, + }, + { + name: "missing target service", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + }, + { + name: "missing source service", + bins: []string{"teleport", "tsh"}, + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + svcCopy: []byte("custom"), + remaining: []string{servicePath}, + errMatch: "no such", + }, + { + name: "missing teleport link", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + svcCopy: []byte("[transform]orig"), + remaining: []string{servicePath}, + }, + { + name: "missing other link", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + }, + svcCopy: []byte("[transform]orig"), + }, + { + name: "wrong teleport link", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "other", newname: "bin/teleport"}, + {oldname: "bin/tsh", newname: "bin/tsh"}, + }, + svcCopy: []byte("[transform]orig"), + remaining: []string{servicePath, "bin/teleport"}, + }, + { + name: "wrong other link", + bins: []string{"teleport", "tsh"}, + svcOrig: []byte("orig"), + links: []symlink{ + {oldname: "bin/teleport", newname: "bin/teleport"}, + {oldname: "wrong", newname: "bin/tsh"}, + }, + svcCopy: []byte("[transform]orig"), + remaining: []string{"bin/tsh"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + versionsDir := t.TempDir() + versionDir := filepath.Join(versionsDir, version) + err := os.MkdirAll(versionDir, 0o755) + require.NoError(t, err) + linkDir := t.TempDir() + + var files []smallFile + for _, n := range tt.bins { + files = append(files, smallFile{ + name: filepath.Join(versionDir, "bin", n), + data: []byte("binary"), + mode: os.ModePerm, + }) + } + if tt.svcOrig != nil { + files = append(files, smallFile{ + name: filepath.Join(versionDir, servicePath), + data: tt.svcOrig, + mode: os.ModePerm, + }) + } + if tt.svcCopy != nil { + files = append(files, smallFile{ + name: filepath.Join(linkDir, servicePath), + data: tt.svcCopy, + mode: os.ModePerm, + }) + } + + for _, n := range files { + err = os.MkdirAll(filepath.Dir(n.name), os.ModePerm) + require.NoError(t, err) + err = os.WriteFile(n.name, n.data, n.mode) + require.NoError(t, err) + } + for _, n := range tt.links { + newname := filepath.Join(linkDir, n.newname) + oldname := filepath.Join(versionDir, n.oldname) + err = os.MkdirAll(filepath.Dir(newname), os.ModePerm) + require.NoError(t, err) + err = os.Symlink(oldname, newname) + require.NoError(t, err) + } + + installer := &LocalInstaller{ + InstallDir: versionsDir, + TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName), + Log: slog.Default(), + TransformService: func(b []byte) []byte { + return []byte("[transform]" + string(b)) + }, + } + ctx := context.Background() + err = installer.Unlink(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin")) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + for _, n := range tt.remaining { + _, err = os.Lstat(filepath.Join(linkDir, n)) + require.NoError(t, err) + } + for _, n := range tt.links { + if slices.Contains(tt.remaining, n.newname) { + continue + } + _, err = os.Lstat(filepath.Join(linkDir, n.newname)) + require.ErrorIs(t, err, os.ErrNotExist) + } + if !slices.Contains(tt.remaining, servicePath) { + _, err = os.Lstat(filepath.Join(linkDir, servicePath)) + require.ErrorIs(t, err, os.ErrNotExist) + } + }) + } +} + +func TestLocalInstaller_List(t *testing.T) { + installDir := t.TempDir() + versions := []string{"v1", "v2"} + + for _, d := range versions { + err := os.Mkdir(filepath.Join(installDir, d), os.ModePerm) + require.NoError(t, err) + } + for _, n := range []string{"file1", "file2"} { + err := os.WriteFile(filepath.Join(installDir, n), []byte(filepath.Base(n)), os.ModePerm) + require.NoError(t, err) + } + installer := &LocalInstaller{ + InstallDir: installDir, + Log: slog.Default(), + } + ctx := context.Background() + revisions, err := installer.List(ctx) + require.NoError(t, err) + require.Equal(t, []Revision{ + NewRevision("v1", 0), + NewRevision("v2", 0), + }, revisions) +} diff --git a/lib/autoupdate/agent/logger.go b/lib/autoupdate/agent/logger.go new file mode 100644 index 0000000000000..bd50ee50859a4 --- /dev/null +++ b/lib/autoupdate/agent/logger.go @@ -0,0 +1,108 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "fmt" + "log/slog" + + "github.com/gravitational/trace" +) + +// progressLogger logs progress of any data written as it approaches max. +// max(lines, call_count(Write)) lines are written for each multiple of max. +// progressLogger uses the variability of chunk size as a proxy for speed, and avoids +// logging extraneous lines that do not improve UX for waiting humans. +type progressLogger struct { + ctx context.Context + log *slog.Logger + level slog.Level + name string + max int + lines int + + l int + n int +} + +func (w *progressLogger) Write(p []byte) (n int, err error) { + w.n += len(p) + if w.n >= w.max*(w.l+1)/w.lines { + w.log.Log(w.ctx, w.level, "Downloading", + "file", w.name, + "progress", fmt.Sprintf("%d%%", w.n*100/w.max), + ) + w.l++ + } + return len(p), nil +} + +// lineLogger logs each line written to it. +type lineLogger struct { + ctx context.Context + log *slog.Logger + level slog.Level + prefix string + + last bytes.Buffer +} + +func (w *lineLogger) out(s string) { + w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant +} + +func (w *lineLogger) Write(p []byte) (n int, err error) { + lines := bytes.Split(p, []byte("\n")) + // Finish writing line + if len(lines) > 0 { + n, err = w.last.Write(lines[0]) + lines = lines[1:] + } + // Quit if no newline + if len(lines) == 0 || err != nil { + return n, trace.Wrap(err) + } + + // Newline found, log line + w.out(w.last.String()) + n += 1 + w.last.Reset() + + // Log lines that are already newline-terminated + for _, line := range lines[:len(lines)-1] { + w.out(string(line)) + n += len(line) + 1 + } + + // Store remaining line non-newline-terminated line. + n2, err := w.last.Write(lines[len(lines)-1]) + n += n2 + return n, trace.Wrap(err) +} + +// Flush logs any trailing bytes that were never terminated with a newline. +func (w *lineLogger) Flush() { + if w.last.Len() == 0 { + return + } + w.out(w.last.String()) + w.last.Reset() +} diff --git a/lib/autoupdate/agent/logger_test.go b/lib/autoupdate/agent/logger_test.go new file mode 100644 index 0000000000000..2a8430ef8cf44 --- /dev/null +++ b/lib/autoupdate/agent/logger_test.go @@ -0,0 +1,186 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLineLogger(t *testing.T) { + t.Parallel() + + out := &bytes.Buffer{} + ll := lineLogger{ + ctx: context.Background(), + log: slog.New(slog.NewTextHandler(out, + &slog.HandlerOptions{ReplaceAttr: msgOnly}, + )), + } + + for _, e := range []struct { + v string + n int + }{ + {v: "", n: 0}, + {v: "a", n: 1}, + {v: "b\n", n: 2}, + {v: "c\nd", n: 3}, + {v: "e\nf\ng", n: 5}, + {v: "h", n: 1}, + {v: "", n: 0}, + {v: "\n", n: 1}, + {v: "i\n", n: 2}, + {v: "j", n: 1}, + } { + n, err := ll.Write([]byte(e.v)) + require.NoError(t, err) + require.Equal(t, e.n, n) + } + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String()) + ll.Flush() + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String()) +} + +func msgOnly(_ []string, a slog.Attr) slog.Attr { + switch a.Key { + case "time", "level": + return slog.Attr{} + } + return slog.Attr{Key: a.Key, Value: a.Value} +} + +func TestProgressLogger(t *testing.T) { + t.Parallel() + + type write struct { + n int + out string + } + for _, tt := range []struct { + name string + max, lines int + writes []write + }{ + { + name: "even", + max: 100, + lines: 5, + writes: []write{ + {n: 10}, + {n: 10, out: "20%"}, + {n: 10}, + {n: 10, out: "40%"}, + {n: 10}, + {n: 10, out: "60%"}, + {n: 10}, + {n: 10, out: "80%"}, + {n: 10}, + {n: 10, out: "100%"}, + {n: 10}, + {n: 10, out: "120%"}, + }, + }, + { + name: "fast", + max: 100, + lines: 5, + writes: []write{ + {n: 100, out: "100%"}, + {n: 100, out: "200%"}, + }, + }, + { + name: "over fast", + max: 100, + lines: 5, + writes: []write{ + {n: 200, out: "200%"}, + }, + }, + { + name: "slow down when uneven", + max: 100, + lines: 5, + writes: []write{ + {n: 50, out: "50%"}, + {n: 10, out: "60%"}, + {n: 10, out: "70%"}, + {n: 10, out: "80%"}, + {n: 10}, + {n: 10, out: "100%"}, + {n: 10}, + {n: 10, out: "120%"}, + }, + }, + { + name: "slow down when very uneven", + max: 100, + lines: 5, + writes: []write{ + {n: 50, out: "50%"}, + {n: 1, out: "51%"}, + {n: 1}, + {n: 20, out: "72%"}, + {n: 10, out: "82%"}, + {n: 10}, + {n: 10, out: "102%"}, + }, + }, + { + name: "close", + max: 1000, + lines: 5, + writes: []write{ + {n: 999, out: "99%"}, + {n: 1, out: "100%"}, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + out := &bytes.Buffer{} + ll := progressLogger{ + ctx: context.Background(), + log: slog.New(slog.NewTextHandler(out, + &slog.HandlerOptions{ReplaceAttr: msgOnly}, + )), + name: "test", + max: tt.max, + lines: tt.lines, + } + for _, e := range tt.writes { + n, err := ll.Write(make([]byte, e.n)) + require.NoError(t, err) + require.Equal(t, e.n, n) + v, err := io.ReadAll(out) + require.NoError(t, err) + if len(v) > 0 { + e.out = fmt.Sprintf(`msg=Downloading file=test progress=%s`+"\n", e.out) + } + require.Equal(t, e.out, string(v)) + } + }) + } +} diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go new file mode 100644 index 0000000000000..a0b937c6b0b6e --- /dev/null +++ b/lib/autoupdate/agent/process.go @@ -0,0 +1,516 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "errors" + "log/slog" + "net" + "os" + "os/exec" + "strconv" + "syscall" + "time" + + "github.com/gravitational/trace" + "golang.org/x/sync/errgroup" + + "github.com/gravitational/teleport/lib/client/debug" +) + +// process monitoring consts +const ( + // monitorTimeout is the timeout for determining whether the process has started. + monitorTimeout = 1 * time.Minute + // monitorInterval is the polling interval for determining whether the process has started. + monitorInterval = 2 * time.Second + // minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detected + // before the service is determined stable. + minRunningIntervalsBeforeStable = 6 + // maxCrashesBeforeFailure is the number of total crashes detected before the service is marked as crash-looping. + maxCrashesBeforeFailure = 2 +) + +// log keys +const ( + unitKey = "unit" +) + +// SystemdService manages a systemd service (e.g., teleport or teleport-update). +type SystemdService struct { + // ServiceName specifies the systemd service name. + ServiceName string + // PIDFile is a path to a file containing the service's PID. + PIDFile string + // Ready is a readiness checker. + Ready ReadyChecker + // Log contains a logger. + Log *slog.Logger +} + +// ReadyChecker returns the systemd service readiness status. +type ReadyChecker interface { + GetReadiness(ctx context.Context) (debug.Readiness, error) +} + +// Reload the systemd service. +// Attempts a graceful reload before a hard restart. +// See Process interface for more details. +func (s SystemdService) Reload(ctx context.Context) error { + // TODO(sclevine): allow server to force restart instead of reload + + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + + // Command error codes < 0 indicate that we are unable to run the command. + // Errors from s.systemctl are logged along with stderr and stdout (debug only). + + // If the service is not running, return ErrNotNeeded. + // Note systemctl reload returns an error if the unit is not active, and + // try-reload-or-restart is too recent of an addition for centos7. + code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName) + switch { + case code < 0: + return trace.Errorf("unable to determine if systemd service is active") + case code > 0: + s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName) + return trace.Wrap(ErrNotNeeded) + } + + // Get initial PID for crash monitoring. + + initPID, err := readInt(s.PIDFile) + if errors.Is(err, os.ErrNotExist) { + s.Log.InfoContext(ctx, "No existing process detected. Skipping crash monitoring.", unitKey, s.ServiceName) + } else if err != nil { + s.Log.ErrorContext(ctx, "Error reading initial PID value. Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err) + } + + // Attempt graceful reload of running service. + code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName) + switch { + case code < 0: + return trace.Errorf("unable to reload systemd service") + case code > 0: + // Graceful reload fails, try hard restart. + code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName) + if code != 0 { + return trace.Errorf("hard restart of systemd service failed") + } + s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName) + default: + s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName) + } + // monitor logs all relevant errors, so we filter for a few outcomes + err = s.monitor(ctx, initPID) + if errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) { + return trace.Wrap(err) + } + if err != nil { + return trace.Errorf("failed to monitor process") + } + return nil +} + +// monitor for a started, healthy process. +// monitor logs all errors that should be displayed to the user. +func (s SystemdService) monitor(ctx context.Context, initPID int) error { + ctx, cancel := context.WithTimeout(ctx, monitorTimeout) + defer cancel() + ticker := time.NewTicker(monitorInterval) + defer ticker.Stop() + + newPID := 0 + if initPID != 0 { + s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName) + var err error + newPID, err = s.monitorPID(ctx, initPID, ticker.C) + if errors.Is(err, context.DeadlineExceeded) { + s.Log.ErrorContext(ctx, "Timed out monitoring for crashing PID.", unitKey, s.ServiceName) + return trace.Wrap(err) + } + if err != nil { + s.Log.ErrorContext(ctx, "Error monitoring for crashing PID.", errorKey, err, unitKey, s.ServiceName) + return trace.Wrap(err) + } + } + + s.Log.InfoContext(ctx, "Monitoring diagnostic socket to detect readiness.", unitKey, s.ServiceName) + ticker = time.NewTicker(monitorInterval) + defer ticker.Stop() + err := s.waitForReady(ctx, newPID, ticker.C) + if errors.Is(err, context.DeadlineExceeded) { + s.Log.ErrorContext(ctx, "Timed out monitoring for process readiness.", unitKey, s.ServiceName) + return trace.Wrap(err) + } + if err != nil { + s.Log.ErrorContext(ctx, "Error monitoring for process readiness.", errorKey, err, unitKey, s.ServiceName) + return trace.Wrap(err) + } + return nil +} + +// monitorPID for the started process to ensure it's running by polling PIDFile. +// This function detects several types of crashes while minimizing its own runtime during updates. +// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID), +// or getting stuck on quit (no change in PID). +// initPID is the PID before the restart operation has been issued. +// The final PID is returned. +func (s SystemdService) monitorPID(ctx context.Context, initPID int, tickC <-chan time.Time) (int, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + pidC := make(chan int) + var g errgroup.Group + g.Go(func() error { + return tickFile(ctx, s.PIDFile, pidC, tickC) + }) + stablePID, err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure, + initPID, pidC, func(pid int) error { + p, err := os.FindProcess(pid) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(p.Signal(syscall.Signal(0))) + }) + cancel() + if err := g.Wait(); err != nil { + s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName) + } + return stablePID, trace.Wrap(err) +} + +// waitForStablePID monitors a service's PID via pidC and determines whether the service is crashing. +// verifyPID must be passed so that waitForStablePID can determine whether the process is running. +// verifyPID must return os.ErrProcessDone in the case that the PID cannot be found, or nil otherwise. +// baselinePID is the initial PID before any operation that might cause the process to start crashing. +// minStable is the number of times pidC must return the same running PID before waitForStablePID returns nil. +// minCrashes is the number of times pidC conveys a process crash or bad state before waitForStablePID returns an error. +// The last reported PID is returned. +func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, verifyPID func(pid int) error) (int, error) { + pid := baselinePID + var last, stale int + var crashes int + for stable := 0; stable < minStable; stable++ { + select { + case <-ctx.Done(): + return pid, ctx.Err() + case p := <-pidC: + last = pid + pid = p + } + // A "crash" is defined as a transition away from a new (non-baseline) PID, or + // an interval where the current PID remains non-running (stale) since the last check. + if (last != 0 && pid != last && last != baselinePID) || + (stale != 0 && pid == stale && last == stale) { + crashes++ + } + if crashes > maxCrashes { + return pid, trace.Errorf("detected crashing process") + } + + // PID can only be stable if it is a real PID that is not new, + // has changed at least once, and hasn't been observed as missing. + if pid == 0 || + pid == baselinePID || + pid == stale || + pid != last { + stable = -1 + continue + } + err := verifyPID(pid) + // A stale PID most likely indicates that the process forked and crashed without systemd noticing. + // There is a small chance that we read the PID file before systemd removed it. + // Note: we only perform this check on PIDs that survive one iteration. + if errors.Is(err, os.ErrProcessDone) || + errors.Is(err, syscall.ESRCH) { + if pid != stale && + pid != baselinePID { + stale = pid + s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale) + } + stable = -1 + continue + } + if err != nil { + return pid, trace.Wrap(err) + } + } + return pid, nil +} + +// readInt reads an integer from a file. +func readInt(path string) (int, error) { + p, err := readFileAtMost(path, 32) + if err != nil { + return 0, trace.Wrap(err) + } + i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64) + if err != nil { + return 0, trace.Wrap(err) + } + return int(i), nil +} + +// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick. +// If the path cannot be read, tickFile sends 0 on ch. +// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist. +func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error { + var err error + for { + // two select statements -> never skip reads + select { + case <-tickC: + case <-ctx.Done(): + return err + } + var t int + t, err = readInt(path) + if errors.Is(err, os.ErrNotExist) { + err = nil + } + select { + case ch <- t: + case <-ctx.Done(): + return err + } + } +} + +// waitForReady polls the SocketPath unix domain socket with HTTP requests. +// If one request returns 200 before the timeout, the service is considered ready. +func (s SystemdService) waitForReady(ctx context.Context, pid int, tickC <-chan time.Time) error { + var lastErr error + var readiness debug.Readiness + for { + resp, err := s.Ready.GetReadiness(ctx) + if err == nil && + resp.Ready && + equalOrZero(resp.PID, pid) { + return nil + } + // If the Readiness check fails to due to intervention, we must not interpret + // the error as a disabled socket, which results in a passing check. + if !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + lastErr = err + readiness = resp + } + select { + case <-ctx.Done(): + if errors.Is(lastErr, os.ErrNotExist) || + errors.Is(lastErr, syscall.EINVAL) || + errors.Is(lastErr, os.ErrInvalid) || + errors.As(lastErr, new(net.Error)) { + s.Log.WarnContext(ctx, "Socket appears to be disabled. Proceeding without check.", unitKey, s.ServiceName) + s.Log.DebugContext(ctx, "Found error after timeout polling socket.", unitKey, s.ServiceName, errorKey, lastErr) + return nil + } + if lastErr != nil { + s.Log.WarnContext(ctx, "Unexpected error after timeout polling socket. Proceeding without check.", unitKey, s.ServiceName, errorKey, lastErr) + return nil + } + if readiness.Status != "" { + s.Log.ErrorContext(ctx, "Process not ready by deadline.", unitKey, s.ServiceName, "status", readiness.Status) + } + if !equalOrZero(readiness.PID, pid) { + s.Log.ErrorContext(ctx, "Readiness PID response does not match PID file.", unitKey, s.ServiceName, "file_pid", pid, "ready_pid", readiness.PID) + } else { + s.Log.DebugContext(ctx, "PIDs are not mismatched.", unitKey, s.ServiceName, "file_pid", pid, "ready_pid", readiness.PID) + } + return ctx.Err() + case <-tickC: + } + } +} + +// equalOrZero returns true if a and b are equal, or if either has the zero-value. +func equalOrZero[T comparable](a, b T) bool { + var empty T + if a == empty || b == empty { + return true + } + return a == b +} + +// Sync systemd service configuration by running systemctl daemon-reload. +// See Process interface for more details. +func (s SystemdService) Sync(ctx context.Context) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + code := s.systemctl(ctx, slog.LevelError, "daemon-reload") + if code != 0 { + return trace.Errorf("unable to reload systemd configuration") + } + s.Log.InfoContext(ctx, "Systemd configuration synced.", unitKey, s.ServiceName) + return nil +} + +// Enable the systemd service. +func (s SystemdService) Enable(ctx context.Context, now bool) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + args := []string{"enable", s.ServiceName} + if now { + args = append(args, "--now") + } + code := s.systemctl(ctx, slog.LevelInfo, args...) + if code != 0 { + return trace.Errorf("unable to enable systemd service") + } + s.Log.InfoContext(ctx, "Service enabled.", unitKey, s.ServiceName) + return nil +} + +// Disable the systemd service. +func (s SystemdService) Disable(ctx context.Context, now bool) error { + if err := s.checkSystem(ctx); err != nil { + return trace.Wrap(err) + } + args := []string{"disable", s.ServiceName} + if now { + args = append(args, "--now") + } + code := s.systemctl(ctx, slog.LevelInfo, args...) + if code != 0 { + return trace.Errorf("unable to disable systemd service") + } + s.Log.InfoContext(ctx, "Systemd service disabled.", unitKey, s.ServiceName) + return nil +} + +// IsEnabled returns true if the service is enabled. +func (s SystemdService) IsEnabled(ctx context.Context) (bool, error) { + if err := s.checkSystem(ctx); err != nil { + return false, trace.Wrap(err) + } + code := s.systemctl(ctx, slog.LevelDebug, "is-enabled", "--quiet", s.ServiceName) + switch { + case code < 0: + return false, trace.Errorf("unable to determine if systemd service %s is enabled", s.ServiceName) + case code == 0: + return true, nil + } + return false, nil +} + +// IsActive returns true if the service is active. +func (s SystemdService) IsActive(ctx context.Context) (bool, error) { + if err := s.checkSystem(ctx); err != nil { + return false, trace.Wrap(err) + } + code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName) + switch { + case code < 0: + return false, trace.Errorf("unable to determine if systemd service %s is active", s.ServiceName) + case code == 0: + return true, nil + } + return false, nil +} + +// IsPresent returns true if the service exists. +func (s SystemdService) IsPresent(ctx context.Context) (bool, error) { + if err := s.checkSystem(ctx); err != nil { + return false, trace.Wrap(err) + } + code := s.systemctl(ctx, slog.LevelDebug, "list-unit-files", "--quiet", s.ServiceName) + if code < 0 { + return false, trace.Errorf("unable to determine if systemd service %s is present", s.ServiceName) + } + return code == 0, nil +} + +// checkSystem returns an error if the system is not compatible with this process manager. +func (s SystemdService) checkSystem(ctx context.Context) error { + present, err := hasSystemD() + if err != nil { + return trace.Wrap(err) + } + if !present { + return trace.Wrap(ErrNotSupported) + } + return nil +} + +// hasSystemD returns true if the system uses the SystemD process manager. +func hasSystemD() (bool, error) { + _, err := os.Stat("/run/systemd/system") + if errors.Is(err, os.ErrNotExist) { + return false, nil + } + if err != nil { + return false, trace.Wrap(err) + } + return true, nil +} + +// systemctl returns a systemctl subcommand, converting the output to logs. +// Output sent to stdout is logged at debug level. +// Output sent to stderr is logged at the level specified by errLevel. +func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int { + cmd := &localExec{ + Log: s.Log, + ErrLevel: errLevel, + OutLevel: slog.LevelDebug, + } + code, err := cmd.Run(ctx, "systemctl", args...) + if err == nil { + return code + } + if code >= 0 { + s.Log.Log(ctx, errLevel, "Non-zero exit code or error running systemctl.", + "args", args, "code", code) + return code + } + s.Log.Log(ctx, errLevel, "Unable to run systemctl.", + "args", args, "code", code, errorKey, err) + return code +} + +// localExec runs a command locally, logging any output. +type localExec struct { + // Log contains a slog logger. + // Defaults to slog.Default() if nil. + Log *slog.Logger + // ErrLevel is the log level for stderr. + ErrLevel slog.Level + // OutLevel is the log level for stdout. + OutLevel slog.Level +} + +// Run the command. Same arguments as exec.CommandContext. +// Outputs the status code, or -1 if out-of-range or unstarted. +func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) { + cmd := exec.CommandContext(ctx, name, args...) + stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel, prefix: "[stderr] "} + stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel, prefix: "[stdout] "} + cmd.Stderr = stderr + cmd.Stdout = stdout + err := cmd.Run() + stderr.Flush() + stdout.Flush() + code := cmd.ProcessState.ExitCode() + return code, trace.Wrap(err) +} diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go new file mode 100644 index 0000000000000..c6db70944be5d --- /dev/null +++ b/lib/autoupdate/agent/process_test.go @@ -0,0 +1,313 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "context" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWaitForStablePID(t *testing.T) { + t.Parallel() + + svc := &SystemdService{ + Log: slog.Default(), + } + + for _, tt := range []struct { + name string + ticks []int + baseline int + minStable int + maxCrashes int + findErrs map[int]error + + finalPID int + errored bool + canceled bool + }{ + { + name: "immediate restart", + ticks: []int{2, 2}, + baseline: 1, + minStable: 1, + maxCrashes: 1, + finalPID: 2, + }, + { + name: "zero stable", + }, + { + name: "immediate crash", + ticks: []int{2, 3}, + baseline: 1, + minStable: 1, + maxCrashes: 0, + errored: true, + finalPID: 3, + }, + { + name: "no changes times out", + ticks: []int{1, 1, 1, 1}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + canceled: true, + finalPID: 1, + }, + { + name: "baseline restart", + ticks: []int{2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + finalPID: 2, + }, + { + name: "one restart then stable", + ticks: []int{1, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + finalPID: 2, + }, + { + name: "two restarts then stable", + ticks: []int{1, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + finalPID: 3, + }, + { + name: "three restarts then stable", + ticks: []int{1, 2, 3, 4, 4, 4, 4}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + finalPID: 4, + }, + { + name: "too many restarts excluding baseline", + ticks: []int{1, 2, 3, 4, 5}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + errored: true, + finalPID: 5, + }, + { + name: "too many restarts including baseline", + ticks: []int{1, 2, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + finalPID: 4, + }, + { + name: "too many restarts slow", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + finalPID: 4, + }, + { + name: "too many restarts after stable", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + finalPID: 3, + }, + { + name: "stable after too many restarts", + ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + errored: true, + finalPID: 4, + }, + { + name: "cancel", + ticks: []int{1, 1, 1}, + baseline: 0, + minStable: 3, + maxCrashes: 2, + canceled: true, + finalPID: 1, + }, + { + name: "stale PID crash", + ticks: []int{2, 2, 2, 2, 2}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + errored: true, + finalPID: 2, + }, + { + name: "stale PID but fixed", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: os.ErrProcessDone, + }, + finalPID: 3, + }, + { + name: "error PID", + ticks: []int{2, 2, 3, 3, 3, 3}, + baseline: 1, + minStable: 3, + maxCrashes: 2, + findErrs: map[int]error{ + 2: errors.New("bad"), + }, + errored: true, + finalPID: 2, + }, + } { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + ch := make(chan int) + go func() { + defer cancel() // always quit after last tick + for _, tick := range tt.ticks { + ch <- tick + } + }() + pid, err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes, + tt.baseline, ch, func(pid int) error { + return tt.findErrs[pid] + }) + require.Equal(t, tt.finalPID, pid) + require.Equal(t, tt.canceled, errors.Is(err, context.Canceled)) + if !tt.canceled { + require.Equal(t, tt.errored, err != nil) + } + }) + } +} + +func TestTickFile(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + ticks []int + errored bool + }{ + { + name: "consistent", + ticks: []int{1, 1, 1}, + errored: false, + }, + { + name: "divergent", + ticks: []int{1, 2, 3}, + errored: false, + }, + { + name: "start error", + ticks: []int{-1, 1, 1}, + errored: false, + }, + { + name: "ephemeral error", + ticks: []int{1, -1, 1}, + errored: false, + }, + { + name: "end error", + ticks: []int{1, 1, -1}, + errored: true, + }, + { + name: "start missing", + ticks: []int{0, 1, 1}, + errored: false, + }, + { + name: "ephemeral missing", + ticks: []int{1, 0, 1}, + errored: false, + }, + { + name: "end missing", + ticks: []int{1, 1, 0}, + errored: false, + }, + { + name: "cancel-only", + errored: false, + }, + } { + t.Run(tt.name, func(t *testing.T) { + filePath := filepath.Join(t.TempDir(), "file") + + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + defer cancel() + tickC := make(chan time.Time) + ch := make(chan int) + + go func() { + defer cancel() // always quit after last tick or fail + for _, tick := range tt.ticks { + _ = os.RemoveAll(filePath) + switch { + case tick > 0: + err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm) + require.NoError(t, err) + case tick < 0: + err := os.Mkdir(filePath, os.ModePerm) + require.NoError(t, err) + } + tickC <- time.Now() + res := <-ch + if tick < 0 { + tick = 0 + } + require.Equal(t, tick, res) + } + }() + err := tickFile(ctx, filePath, ch, tickC) + require.Equal(t, tt.errored, err != nil) + }) + } +} diff --git a/lib/autoupdate/agent/setup.go b/lib/autoupdate/agent/setup.go new file mode 100644 index 0000000000000..8c4806251fcd5 --- /dev/null +++ b/lib/autoupdate/agent/setup.go @@ -0,0 +1,480 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "errors" + "io/fs" + "log/slog" + "os" + "path/filepath" + "regexp" + "strings" + "text/template" + + "github.com/google/renameio/v2" + "github.com/gravitational/trace" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/lib/defaults" + libdefaults "github.com/gravitational/teleport/lib/defaults" + libutils "github.com/gravitational/teleport/lib/utils" +) + +// Base paths for constructing namespaced directories. +const ( + defaultInstallDir = "/opt/teleport" + defaultPathDir = "/usr/local/bin" + systemdAdminDir = "/etc/systemd/system" + systemdPIDDir = "/run" + needrestartConfDir = "/etc/needrestart/conf.d" + versionsDirName = "versions" + lockFileName = "update.lock" + defaultNamespace = "default" + systemNamespace = "system" +) + +const ( + // deprecatedTimerName is the timer for the deprecated upgrader should be disabled on setup. + deprecatedTimerName = "teleport-upgrade.timer" +) + +const ( + updateServiceTemplate = `# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update service + +[Service] +Type=oneshot +ExecStart={{.UpdaterBinary}} --install-suffix={{.InstallSuffix}} --install-dir="{{escape .InstallDir}}" update +` + updateTimerTemplate = `# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy={{.TeleportService}} +` + teleportDropInTemplate = `# teleport-update +# DO NOT EDIT THIS FILE +[Service] +Environment="TELEPORT_UPDATE_CONFIG_FILE={{escape .UpdaterConfigFile}}" +Environment="TELEPORT_UPDATE_INSTALL_DIR={{escape .InstallDir}}" +` + // This configuration sets the default value for needrestart-trigger automatic restarts for teleport.service to disabled. + // Users may still choose to enable needrestart for teleport.service when installing packaging interactively (or via dpkg config), + // but doing so will result in a hard restart that disconnects the agent whenever any dependent libraries are updated. + // Other network services, like openvpn, follow this pattern. + // It is possible to configure needrestart to trigger a soft restart (via restart.d script), but given that Teleport subprocesses + // can use a wide variety of installed binaries (when executed by the user), this could trigger many unexpected reloads. + needrestartConfTemplate = `$nrconf{override_rc}{qr(^{{replace .TeleportService "." "\\."}})} = 0; +` +) + +type confParams struct { + TeleportService string + UpdaterBinary string + InstallSuffix string + InstallDir string + Path string + UpdaterConfigFile string +} + +// Namespace represents a namespace within various system paths for a isolated installation of Teleport. +type Namespace struct { + log *slog.Logger + // name of namespace + name string + // installDir for Teleport namespaces (/opt/teleport) + installDir string + // defaultPathDir for Teleport binaries (ns: /opt/teleport/myns/bin) + defaultPathDir string + // dataDir parsed from teleport.yaml, if present + dataDir string + // defaultProxyAddr parsed from teleport.yaml, if present + defaultProxyAddr string + // serviceFile for the Teleport systemd service (ns: /etc/systemd/system/teleport_myns.service) + serviceFile string + // configFile for Teleport config (ns: /etc/teleport_myns.yaml) + configFile string + // pidFile for Teleport (ns: /run/teleport_myns.pid) + pidFile string + // updaterServiceFile is the systemd service path for the updater + updaterServiceFile string + // updaterTimerFile is the systemd timer path for the updater + updaterTimerFile string + // dropInFile is the Teleport systemd drop-in path extending Teleport + dropInFile string + // needrestartConfFile is the path to needrestart configuration for Teleport + needrestartConfFile string +} + +var alphanum = regexp.MustCompile("^[a-zA-Z0-9-]*$") + +// NewNamespace validates and returns a Namespace. +// Namespaces must be alphanumeric + `-`. +// defaultPathDir overrides the destination directory for namespace setup (i.e., /usr/local) +func NewNamespace(ctx context.Context, log *slog.Logger, name, installDir string) (ns *Namespace, err error) { + defer ns.overrideFromConfig(ctx) + + if name == defaultNamespace || + name == systemNamespace { + return nil, trace.Errorf("namespace %s is reserved", name) + } + if !alphanum.MatchString(name) { + return nil, trace.Errorf("invalid namespace name %s, must be alphanumeric", name) + } + if installDir == "" { + installDir = defaultInstallDir + } + if name == "" { + linkDir := defaultPathDir + return &Namespace{ + log: log, + name: name, + installDir: installDir, + defaultPathDir: linkDir, + dataDir: defaults.DataDir, + serviceFile: filepath.Join("/", serviceDir, serviceName), + configFile: defaults.ConfigFilePath, + pidFile: filepath.Join(systemdPIDDir, "teleport.pid"), + updaterServiceFile: filepath.Join(systemdAdminDir, BinaryName+".service"), + updaterTimerFile: filepath.Join(systemdAdminDir, BinaryName+".timer"), + dropInFile: filepath.Join(systemdAdminDir, "teleport.service.d", BinaryName+".conf"), + needrestartConfFile: filepath.Join(needrestartConfDir, BinaryName+".conf"), + }, nil + } + + prefix := "teleport_" + name + linkDir := filepath.Join(installDir, name, "bin") + return &Namespace{ + log: log, + name: name, + installDir: installDir, + defaultPathDir: linkDir, + dataDir: filepath.Join(filepath.Dir(defaults.DataDir), prefix), + serviceFile: filepath.Join(systemdAdminDir, prefix+".service"), + configFile: filepath.Join(filepath.Dir(defaults.ConfigFilePath), prefix+".yaml"), + pidFile: filepath.Join(systemdPIDDir, prefix+".pid"), + updaterServiceFile: filepath.Join(systemdAdminDir, BinaryName+"_"+name+".service"), + updaterTimerFile: filepath.Join(systemdAdminDir, BinaryName+"_"+name+".timer"), + dropInFile: filepath.Join(systemdAdminDir, prefix+".service.d", BinaryName+"_"+name+".conf"), + needrestartConfFile: filepath.Join(needrestartConfDir, BinaryName+"_"+name+".conf"), + }, nil +} + +func (ns *Namespace) Dir() string { + name := ns.name + if name == "" { + name = defaultNamespace + } + return filepath.Join(ns.installDir, name) +} + +// Init create the initial directory structure and returns the lockfile for a Namespace. +func (ns *Namespace) Init() (lockFile string, err error) { + if err := os.MkdirAll(filepath.Join(ns.Dir(), versionsDirName), systemDirMode); err != nil { + return "", trace.Wrap(err) + } + return filepath.Join(ns.Dir(), lockFileName), nil +} + +// Setup installs service and timer files for the teleport-update binary. +// Afterwords, Setup reloads systemd and enables the timer with --now. +func (ns *Namespace) Setup(ctx context.Context, path string) error { + if ok, err := hasSystemD(); err == nil && !ok { + ns.log.WarnContext(ctx, "Systemd is not running, skipping updater installation.") + return nil + } + + err := ns.writeConfigFiles(ctx, path) + if err != nil { + return trace.Wrap(err, "failed to write teleport-update systemd config files") + } + timer := &SystemdService{ + ServiceName: filepath.Base(ns.updaterTimerFile), + Log: ns.log, + } + if err := timer.Sync(ctx); err != nil { + return trace.Wrap(err, "failed to sync systemd config") + } + if err := timer.Enable(ctx, true); err != nil { + return trace.Wrap(err, "failed to enable teleport-update systemd timer") + } + if ns.name == "" { + oldTimer := &SystemdService{ + ServiceName: deprecatedTimerName, + Log: ns.log, + } + // If the old teleport-upgrade script is detected, disable it to ensure they do not interfere. + // Note that the schedule is also set to nop by the Teleport agent -- this just prevents restarts. + enabled, err := isActiveOrEnabled(ctx, oldTimer) + if err != nil { + return trace.Wrap(err, "failed to determine if deprecated teleport-upgrade systemd timer is enabled") + } + if enabled { + if err := oldTimer.Disable(ctx, true); err != nil { + ns.log.ErrorContext(ctx, "The deprecated teleport-ent-updater package is installed on this server, and it cannot be disabled due to an error. You must remove the teleport-ent-updater package after verifying that teleport-update is working.", errorKey, err) + } else { + ns.log.WarnContext(ctx, "The deprecated teleport-ent-updater package is installed on this server. This package has been disabled to prevent conflicts. Please remove the teleport-ent-updater package after verifying that teleport-update is working.") + } + } + } + return nil +} + +// Teardown removes all traces of the auto-updater, including its configuration. +func (ns *Namespace) Teardown(ctx context.Context) error { + if ok, err := hasSystemD(); err == nil && !ok { + ns.log.WarnContext(ctx, "Systemd is not running, skipping updater removal.") + if err := os.RemoveAll(ns.Dir()); err != nil { + return trace.Wrap(err, "failed to remove versions directory") + } + return nil + } + + svc := &SystemdService{ + ServiceName: filepath.Base(ns.updaterTimerFile), + Log: ns.log, + } + if err := svc.Disable(ctx, true); err != nil { + ns.log.WarnContext(ctx, "Unable to disable teleport-update systemd timer before removing.", errorKey, err) + } + for _, p := range []string{ + ns.updaterServiceFile, + ns.updaterTimerFile, + ns.dropInFile, + ns.needrestartConfFile, + } { + if err := os.Remove(p); err != nil && !errors.Is(err, fs.ErrNotExist) { + return trace.Wrap(err, "failed to remove %s", filepath.Base(p)) + } + } + if err := svc.Sync(ctx); err != nil { + return trace.Wrap(err, "failed to sync systemd config") + } + if err := os.RemoveAll(ns.Dir()); err != nil { + return trace.Wrap(err, "failed to remove versions directory") + } + if ns.name == "" { + oldTimer := &SystemdService{ + ServiceName: deprecatedTimerName, + Log: ns.log, + } + // If the old upgrader exists, attempt to re-enable it automatically + present, err := oldTimer.IsPresent(ctx) + if err != nil { + return trace.Wrap(err, "failed to determine if deprecated teleport-upgrade systemd timer is present") + } + if present { + if err := oldTimer.Enable(ctx, true); err != nil { + ns.log.ErrorContext(ctx, "The deprecated teleport-ent-updater package is installed on this server, and it cannot be re-enabled due to an error. Please fix the teleport-ent-updater package if you intend to use the deprecated updater.", errorKey, err) + } else { + ns.log.WarnContext(ctx, "The deprecated teleport-ent-updater package is installed on this server. This package has been re-enabled to ensure continued updates. To disable automatic updates entirely, please remove the teleport-ent-updater package.") + } + } + } + return nil +} + +func (ns *Namespace) writeConfigFiles(ctx context.Context, path string) error { + teleportService := filepath.Base(ns.serviceFile) + params := confParams{ + TeleportService: teleportService, + UpdaterBinary: filepath.Join(path, BinaryName), + InstallSuffix: ns.name, + InstallDir: ns.installDir, + Path: path, + UpdaterConfigFile: filepath.Join(ns.Dir(), updateConfigName), + } + err := writeSystemTemplate(ns.updaterServiceFile, updateServiceTemplate, params) + if err != nil { + return trace.Wrap(err) + } + err = writeSystemTemplate(ns.updaterTimerFile, updateTimerTemplate, params) + if err != nil { + return trace.Wrap(err) + } + err = writeSystemTemplate(ns.dropInFile, teleportDropInTemplate, params) + if err != nil { + return trace.Wrap(err) + } + // Needrestart config is non-critical for updater functionality. + _, err = os.Stat(filepath.Dir(ns.needrestartConfFile)) + if os.IsNotExist(err) { + return nil // needrestart is not present + } + if err != nil { + ns.log.ErrorContext(ctx, "Unable to disable needrestart.", errorKey, err) + return nil + } + ns.log.InfoContext(ctx, "Disabling needrestart.", unitKey, teleportService) + err = writeSystemTemplate(ns.needrestartConfFile, needrestartConfTemplate, params) + if err != nil { + ns.log.ErrorContext(ctx, "Unable to disable needrestart.", errorKey, err) + return nil + } + return nil +} + +// writeSystemTemplate atomically writes a template to a system file, creating any needed directories. +// Temporarily files are stored in the target path to ensure the file has needed SELinux contexts. +func writeSystemTemplate(path, t string, values any) error { + dir, file := filepath.Split(path) + if err := os.MkdirAll(dir, systemDirMode); err != nil { + return trace.Wrap(err) + } + opts := []renameio.Option{ + renameio.WithPermissions(configFileMode), + renameio.WithExistingPermissions(), + renameio.WithTempDir(dir), + } + f, err := renameio.NewPendingFile(path, opts...) + if err != nil { + return trace.Wrap(err) + } + defer f.Cleanup() + + tmpl, err := template.New(file).Funcs(template.FuncMap{ + "replace": func(s, old, new string) string { + return strings.ReplaceAll(s, old, new) + }, + // escape is a best-effort function for escaping quotes in systemd service templates. + // Paths that are escaped with this method should not be advertised to the user as + // configurable until a more robust escaping mechanism is shipped. + // See: https://www.freedesktop.org/software/systemd/man/latest/systemd.syntax.html + "escape": func(s string) string { + replacer := strings.NewReplacer( + `"`, `\"`, + `\`, `\\`, + ) + return replacer.Replace(s) + }, + }).Parse(t) + if err != nil { + return trace.Wrap(err) + } + err = tmpl.Execute(f, values) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(f.CloseAtomicallyReplace()) +} + +// replaceTeleportService replaces the default paths in the Teleport service config with namespaced paths. +func (ns *Namespace) replaceTeleportService(cfg []byte) []byte { + for _, rep := range []struct { + old, new string + }{ + { + old: "/usr/local/bin/", + new: ns.defaultPathDir + "/", + }, + { + old: "/etc/teleport.yaml", + new: ns.configFile, + }, + { + old: "/run/teleport.pid", + new: ns.pidFile, + }, + } { + cfg = bytes.ReplaceAll(cfg, []byte(rep.old), []byte(rep.new)) + } + return cfg +} + +func (ns *Namespace) LogWarning(ctx context.Context) { + ns.log.WarnContext(ctx, "Custom install suffix specified. Teleport data_dir must be configured in the config file.", + "data_dir", ns.dataDir, + "path", ns.defaultPathDir, + "config", ns.configFile, + "service", filepath.Base(ns.serviceFile), + "pid", ns.pidFile, + ) +} + +// unversionedConfig is used to read all versions of teleport.yaml, including +// versions that may now be unsupported. +type unversionedConfig struct { + Teleport unversionedTeleport `yaml:"teleport"` +} + +type unversionedTeleport struct { + AuthServers []string `yaml:"auth_servers"` + AuthServer string `yaml:"auth_server"` + ProxyServer string `yaml:"proxy_server"` + DataDir string `yaml:"data_dir"` +} + +// overrideFromConfig loads fields from teleport.yaml into the namespace, overriding any defaults. +func (ns *Namespace) overrideFromConfig(ctx context.Context) { + if ns == nil || ns.configFile == "" { + return + } + path := ns.configFile + f, err := libutils.OpenFileAllowingUnsafeLinks(path) + if err != nil { + ns.log.DebugContext(ctx, "Unable to open Teleport config to read proxy or data dir", "config", path, errorKey, err) + return + } + defer f.Close() + var cfg unversionedConfig + if err := yaml.NewDecoder(f).Decode(&cfg); err != nil { + ns.log.DebugContext(ctx, "Unable to parse Teleport config to read proxy or data dir", "config", path, errorKey, err) + return + } + if cfg.Teleport.DataDir != "" { + ns.dataDir = cfg.Teleport.DataDir + } + + // Any implicitly defaulted port in teleport.yaml is explicitly defaulted (to 3080). + + var addr string + var port int + switch t := cfg.Teleport; { + case t.ProxyServer != "": + addr = t.ProxyServer + port = libdefaults.HTTPListenPort + case t.AuthServer != "": + addr = t.AuthServer + port = libdefaults.AuthListenPort + case len(t.AuthServers) > 0: + addr = t.AuthServers[0] + port = libdefaults.AuthListenPort + default: + ns.log.DebugContext(ctx, "Unable to find proxy in Teleport config", "config", path, errorKey, err) + return + } + netaddr, err := libutils.ParseHostPortAddr(addr, port) + if err != nil { + ns.log.DebugContext(ctx, "Unable to parse proxy in Teleport config", "config", path, "proxy_addr", addr, "proxy_port", port, errorKey, err) + return + } + ns.defaultProxyAddr = netaddr.String() +} diff --git a/lib/autoupdate/agent/setup_test.go b/lib/autoupdate/agent/setup_test.go new file mode 100644 index 0000000000000..bebda789cadb7 --- /dev/null +++ b/lib/autoupdate/agent/setup_test.go @@ -0,0 +1,365 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" + + "github.com/gravitational/teleport/lib/config" + "github.com/gravitational/teleport/lib/utils/testutils/golden" +) + +func TestNewNamespace(t *testing.T) { + for _, p := range []struct { + name string + namespace string + installDir string + errMatch string + ns *Namespace + }{ + { + name: "no namespace", + ns: &Namespace{ + dataDir: "/var/lib/teleport", + installDir: "/opt/teleport", + defaultPathDir: "/usr/local/bin", + serviceFile: "/lib/systemd/system/teleport.service", + configFile: "/etc/teleport.yaml", + pidFile: "/run/teleport.pid", + updaterServiceFile: "/etc/systemd/system/teleport-update.service", + updaterTimerFile: "/etc/systemd/system/teleport-update.timer", + dropInFile: "/etc/systemd/system/teleport.service.d/teleport-update.conf", + needrestartConfFile: "/etc/needrestart/conf.d/teleport-update.conf", + }, + }, + { + name: "no namespace with dirs", + installDir: "/install", + ns: &Namespace{ + dataDir: "/var/lib/teleport", + installDir: "/install", + defaultPathDir: "/usr/local/bin", + serviceFile: "/lib/systemd/system/teleport.service", + configFile: "/etc/teleport.yaml", + pidFile: "/run/teleport.pid", + updaterServiceFile: "/etc/systemd/system/teleport-update.service", + updaterTimerFile: "/etc/systemd/system/teleport-update.timer", + dropInFile: "/etc/systemd/system/teleport.service.d/teleport-update.conf", + needrestartConfFile: "/etc/needrestart/conf.d/teleport-update.conf", + }, + }, + { + name: "test namespace", + namespace: "test", + ns: &Namespace{ + name: "test", + dataDir: "/var/lib/teleport_test", + installDir: "/opt/teleport", + defaultPathDir: "/opt/teleport/test/bin", + serviceFile: "/etc/systemd/system/teleport_test.service", + configFile: "/etc/teleport_test.yaml", + pidFile: "/run/teleport_test.pid", + updaterServiceFile: "/etc/systemd/system/teleport-update_test.service", + updaterTimerFile: "/etc/systemd/system/teleport-update_test.timer", + dropInFile: "/etc/systemd/system/teleport_test.service.d/teleport-update_test.conf", + needrestartConfFile: "/etc/needrestart/conf.d/teleport-update_test.conf", + }, + }, + { + name: "test namespace with dirs", + namespace: "test", + installDir: "/install", + ns: &Namespace{ + name: "test", + dataDir: "/var/lib/teleport_test", + installDir: "/install", + defaultPathDir: "/install/test/bin", + configFile: "/etc/teleport_test.yaml", + pidFile: "/run/teleport_test.pid", + serviceFile: "/etc/systemd/system/teleport_test.service", + updaterServiceFile: "/etc/systemd/system/teleport-update_test.service", + updaterTimerFile: "/etc/systemd/system/teleport-update_test.timer", + dropInFile: "/etc/systemd/system/teleport_test.service.d/teleport-update_test.conf", + needrestartConfFile: "/etc/needrestart/conf.d/teleport-update_test.conf", + }, + }, + { + name: "reserved default", + namespace: defaultNamespace, + errMatch: "reserved", + }, + { + name: "reserved system", + namespace: systemNamespace, + errMatch: "reserved", + }, + } { + t.Run(p.name, func(t *testing.T) { + log := slog.Default() + ctx := context.Background() + ns, err := NewNamespace(ctx, log, p.namespace, p.installDir) + if p.errMatch != "" { + require.Error(t, err) + require.Contains(t, err.Error(), p.errMatch) + return + } + require.NoError(t, err) + ns.log = nil + require.Equal(t, p.ns, ns) + }) + } +} + +func TestWriteConfigFiles(t *testing.T) { + for _, p := range []struct { + name string + namespace string + }{ + { + name: "no namespace", + }, + { + name: "test namespace", + namespace: "test", + }, + } { + t.Run(p.name, func(t *testing.T) { + log := slog.Default() + linkDir := t.TempDir() + ctx := context.Background() + ns, err := NewNamespace(ctx, log, p.namespace, "") + require.NoError(t, err) + ns.updaterServiceFile = filepath.Join(linkDir, serviceDir, filepath.Base(ns.updaterServiceFile)) + ns.updaterTimerFile = filepath.Join(linkDir, serviceDir, filepath.Base(ns.updaterTimerFile)) + ns.dropInFile = filepath.Join(linkDir, serviceDir, filepath.Base(filepath.Dir(ns.dropInFile)), filepath.Base(ns.dropInFile)) + ns.needrestartConfFile = filepath.Join(linkDir, filepath.Base(ns.dropInFile)) + err = ns.writeConfigFiles(ctx, linkDir) + require.NoError(t, err) + + for _, tt := range []struct { + name string + path string + }{ + {name: "service", path: ns.updaterServiceFile}, + {name: "timer", path: ns.updaterTimerFile}, + {name: "dropin", path: ns.dropInFile}, + {name: "needrestart", path: ns.needrestartConfFile}, + } { + t.Run(tt.name, func(t *testing.T) { + data, err := os.ReadFile(tt.path) + require.NoError(t, err) + data = replaceValues(data, map[string]string{ + defaultPathDir: linkDir, + }) + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } + }) + } +} + +func replaceValues(data []byte, m map[string]string) []byte { + for k, v := range m { + data = bytes.ReplaceAll(data, []byte(v), []byte(k)) + } + return data +} + +func TestNamespace_overrideFromConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *unversionedTeleport + want Namespace + }{ + { + name: "default", + cfg: &unversionedTeleport{ + ProxyServer: "example.com", + DataDir: "/data", + }, + want: Namespace{ + defaultProxyAddr: "example.com:3080", + dataDir: "/data", + }, + }, + { + name: "empty", + cfg: &unversionedTeleport{}, + want: Namespace{ + defaultProxyAddr: "default.example.com", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "full proxy", + cfg: &unversionedTeleport{ + ProxyServer: "https://example.com:8080", + }, + want: Namespace{ + defaultProxyAddr: "example.com:8080", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "protocol and host", + cfg: &unversionedTeleport{ + ProxyServer: "https://example.com", + }, + want: Namespace{ + defaultProxyAddr: "example.com:3080", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "host and port", + cfg: &unversionedTeleport{ + ProxyServer: "example.com:443", + }, + want: Namespace{ + defaultProxyAddr: "example.com:443", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "host", + cfg: &unversionedTeleport{ + ProxyServer: "example.com", + }, + want: Namespace{ + defaultProxyAddr: "example.com:3080", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "auth server (v3)", + cfg: &unversionedTeleport{ + AuthServer: "example.com", + }, + want: Namespace{ + defaultProxyAddr: "example.com:3025", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "auth server (v1/2)", + cfg: &unversionedTeleport{ + AuthServers: []string{ + "one.example.com", + "two.example.com", + }, + }, + want: Namespace{ + defaultProxyAddr: "one.example.com:3025", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "proxy priority", + cfg: &unversionedTeleport{ + ProxyServer: "one.example.com", + AuthServer: "two.example.com", + AuthServers: []string{"three.example.com"}, + }, + want: Namespace{ + defaultProxyAddr: "one.example.com:3080", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "auth priority", + cfg: &unversionedTeleport{ + AuthServer: "two.example.com", + AuthServers: []string{"three.example.com"}, + }, + want: Namespace{ + defaultProxyAddr: "two.example.com:3025", + dataDir: "/var/lib/teleport", + }, + }, + { + name: "missing", + want: Namespace{ + defaultProxyAddr: "default.example.com", + dataDir: "/var/lib/teleport", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ns := &Namespace{ + log: slog.Default(), + configFile: filepath.Join(t.TempDir(), "teleport.yaml"), + defaultProxyAddr: "default.example.com", + dataDir: "/var/lib/teleport", + } + if tt.cfg != nil { + out, err := yaml.Marshal(unversionedConfig{Teleport: *tt.cfg}) + require.NoError(t, err) + err = os.WriteFile(ns.configFile, out, os.ModePerm) + require.NoError(t, err) + } + ctx := context.Background() + ns.overrideFromConfig(ctx) + ns.configFile = "" + ns.log = nil + require.Equal(t, &tt.want, ns) + }) + } +} + +// In the future, the latest version of the updater may need to read a version of teleport.yaml that has +// an unsupported version which is supported by the updater-managed version of Teleport. +// This test will break if Teleport removes a field that the updater reads. +func TestUnversionedTeleportConfig(t *testing.T) { + in := unversionedConfig{ + Teleport: unversionedTeleport{ + ProxyServer: "proxy.example.com", + AuthServer: "auth.example.com", + AuthServers: []string{"auth1.example.com", "auth2.example.com"}, + DataDir: "example_dir", + }, + } + var inB bytes.Buffer + err := yaml.NewEncoder(&inB).Encode(in) + require.NoError(t, err) + fc, err := config.ReadConfig(&inB) + require.NoError(t, err) + + var outB bytes.Buffer + err = yaml.NewEncoder(&outB).Encode(fc) + require.NoError(t, err) + + var out unversionedConfig + err = yaml.NewDecoder(&outB).Decode(&out) + require.NoError(t, err) + require.Equal(t, in, out) +} diff --git a/lib/autoupdate/agent/telemetry.go b/lib/autoupdate/agent/telemetry.go new file mode 100644 index 0000000000000..bd6bb887cad40 --- /dev/null +++ b/lib/autoupdate/agent/telemetry.go @@ -0,0 +1,105 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "os" + "path/filepath" + "strings" + + "github.com/gravitational/trace" +) + +const installDirEnvVar = "TELEPORT_UPDATE_INSTALL_DIR" + +// IsManagedByUpdater returns true if the local Teleport binary is managed by teleport-update. +// Note that true may be returned even if auto-updates is disabled or the version is pinned. +// The binary is considered managed if it lives under /opt/teleport, but not within the package +// path at /opt/teleport/system. +func IsManagedByUpdater() (bool, error) { + systemd, err := hasSystemD() + if err != nil { + return false, trace.Wrap(err) + } + if !systemd { + return false, nil + } + teleportPath, err := os.Readlink("/proc/self/exe") + if err != nil { + return false, trace.Wrap(err, "cannot find Teleport binary") + } + installDir := os.Getenv(installDirEnvVar) + if installDir == "" { + installDir = defaultInstallDir + } + // Check if current binary is under the updater-managed path. + managed, err := hasParentDir(teleportPath, installDir) + if err != nil { + return false, trace.Wrap(err) + } + if !managed { + return false, nil + } + // Return false if the binary is under the updater-managed path, but in the system prefix reserved for the package. + system, err := hasParentDir(teleportPath, packageSystemDir) + return !system, trace.Wrap(err) +} + +// IsManagedAndDefault returns true if the local Teleport binary is both managed by teleport-update +// and the default installation (with teleport.service as the unit file name). +// The binary is considered managed and default if it lives within /opt/teleport/default. +func IsManagedAndDefault() (bool, error) { + systemd, err := hasSystemD() + if err != nil { + return false, trace.Wrap(err) + } + if !systemd { + return false, nil + } + teleportPath, err := os.Readlink("/proc/self/exe") + if err != nil { + return false, trace.Wrap(err, "cannot find Teleport binary") + } + installDir := os.Getenv(installDirEnvVar) + if installDir == "" { + installDir = defaultInstallDir + } + isDefault, err := hasParentDir(teleportPath, filepath.Join(installDir, defaultNamespace)) + return isDefault, trace.Wrap(err) +} + +// hasParentDir returns true if dir is any parent directory of parent. +// hasParentDir does not resolve symlinks, and requires that files be represented the same way in dir and parent. +func hasParentDir(dir, parent string) (bool, error) { + // Note that os.Stat + os.SameFile would be more reliable, + // but does not work well for arbitrarily nested subdirectories. + absDir, err := filepath.Abs(dir) + if err != nil { + return false, trace.Wrap(err, "cannot get absolute path for directory %s", dir) + } + absParent, err := filepath.Abs(parent) + if err != nil { + return false, trace.Wrap(err, "cannot get absolute path for parent directory %s", dir) + } + sep := string(filepath.Separator) + if !strings.HasSuffix(absParent, sep) { + absParent += sep + } + return strings.HasPrefix(absDir, absParent), nil +} diff --git a/lib/autoupdate/agent/telemetry_test.go b/lib/autoupdate/agent/telemetry_test.go new file mode 100644 index 0000000000000..8332657785c04 --- /dev/null +++ b/lib/autoupdate/agent/telemetry_test.go @@ -0,0 +1,109 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHasParentDir(t *testing.T) { + tests := []struct { + name string + path string + parent string + wantResult bool + }{ + { + name: "Has valid parent directory", + path: "/opt/teleport/dir/test", + parent: "/opt/teleport", + wantResult: true, + }, + { + name: "Has valid parent directory with slash", + path: "/opt/teleport/dir/test", + parent: "/opt/teleport/", + wantResult: true, + }, + { + name: "Parent directory is root", + path: "/opt/teleport/dir", + parent: "/", + wantResult: true, + }, + { + name: "Parent is the same as the path", + path: "/opt/teleport/dir", + parent: "/opt/teleport/dir", + wantResult: false, + }, + { + name: "Parent the same as the path but without slash", + path: "/opt/teleport/dir/", + parent: "/opt/teleport/dir", + wantResult: false, + }, + { + name: "Parent the same as the path but with slash", + path: "/opt/teleport/dir", + parent: "/opt/teleport/dir/", + wantResult: false, + }, + { + name: "Parent is substring of the path", + path: "/opt/teleport/dir-place", + parent: "/opt/teleport/dir", + wantResult: false, + }, + { + name: "Parent is in path", + path: "/opt/teleport", + parent: "/opt/teleport/dir", + wantResult: false, + }, + { + name: "Empty parent", + path: "/opt/teleport/dir", + parent: "", + wantResult: false, + }, + { + name: "Empty path", + path: "", + parent: "/opt/teleport", + wantResult: false, + }, + { + name: "Both empty", + path: "", + parent: "", + wantResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := hasParentDir(tt.path, tt.parent) + require.NoError(t, err) + require.Equal(t, tt.wantResult, result) + }) + } +} diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden index e8773a6d88b7f..6e104086250e3 100644 --- a/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden +++ b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden @@ -2,8 +2,9 @@ version: v1 kind: update_config spec: proxy: "" - group: "" - url_template: "" + path: "" enabled: false + pinned: false status: - active_version: "" + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden index e8773a6d88b7f..6e104086250e3 100644 --- a/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden +++ b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden @@ -2,8 +2,9 @@ version: v1 kind: update_config spec: proxy: "" - group: "" - url_template: "" + path: "" enabled: false + pinned: false status: - active_version: "" + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden deleted file mode 100644 index e03f369eb1017..0000000000000 --- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden +++ /dev/null @@ -1,9 +0,0 @@ -version: v1 -kind: update_config -spec: - proxy: localhost - group: "" - url_template: "" - enabled: true -status: - active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden deleted file mode 100644 index b172d858bc55a..0000000000000 --- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden +++ /dev/null @@ -1,9 +0,0 @@ -version: v1 -kind: update_config -spec: - proxy: localhost - group: group - url_template: https://example.com - enabled: true -status: - active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden deleted file mode 100644 index bb9ce8b9d8fa8..0000000000000 --- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden +++ /dev/null @@ -1,9 +0,0 @@ -version: v1 -kind: update_config -spec: - proxy: localhost - group: new-group - url_template: https://example.com/new - enabled: true -status: - active_version: new-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden deleted file mode 100644 index e03f369eb1017..0000000000000 --- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden +++ /dev/null @@ -1,9 +0,0 @@ -version: v1 -kind: update_config -spec: - proxy: localhost - group: "" - url_template: "" - enabled: true -status: - active_version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden new file mode 100644 index 0000000000000..04c32df19ee6f --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden @@ -0,0 +1,11 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 + flags: [Enterprise, FIPS] diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden new file mode 100644 index 0000000000000..6e104086250e3 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden new file mode 100644 index 0000000000000..07eedcd8cadbb --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 + backup: + version: backup-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden new file mode 100644 index 0000000000000..fb70091fe8d78 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden new file mode 100644 index 0000000000000..36c71be81e379 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden new file mode 100644 index 0000000000000..1918a2a3434d1 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden @@ -0,0 +1,14 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /path + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden new file mode 100644 index 0000000000000..25cd918d5af79 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden @@ -0,0 +1,14 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /path + group: new-group + base_url: https://example.com/new + enabled: true + pinned: false +status: + active: + version: new-version + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden new file mode 100644 index 0000000000000..fb70091fe8d78 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden new file mode 100644 index 0000000000000..69b08b9c83c9d --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden @@ -0,0 +1,11 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + base_url: http://example.com + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden new file mode 100644 index 0000000000000..6e104086250e3 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden new file mode 100644 index 0000000000000..0c3dcaac8edbd --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden @@ -0,0 +1,10 @@ +version: "" +kind: "" +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden new file mode 100644 index 0000000000000..36c71be81e379 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden new file mode 100644 index 0000000000000..36c71be81e379 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden new file mode 100644 index 0000000000000..fb70091fe8d78 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden new file mode 100644 index 0000000000000..067fcf60bd527 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden new file mode 100644 index 0000000000000..36c71be81e379 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: false + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden new file mode 100644 index 0000000000000..6e104086250e3 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden new file mode 100644 index 0000000000000..6e104086250e3 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden @@ -0,0 +1,10 @@ +version: v1 +kind: update_config +spec: + proxy: "" + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden new file mode 100644 index 0000000000000..f13285e6d27cb --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden @@ -0,0 +1,15 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + flags: [Enterprise, FIPS] + backup: + version: old-version + flags: [Enterprise, FIPS] diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden new file mode 100644 index 0000000000000..d50417064563b --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + enabled: true + pinned: false +status: + active: + version: old-version + backup: + version: backup-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden new file mode 100644 index 0000000000000..a153751d1854a --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden @@ -0,0 +1,13 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden new file mode 100644 index 0000000000000..d257cd6c30282 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden @@ -0,0 +1,13 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: backup-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden new file mode 100644 index 0000000000000..a153751d1854a --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden @@ -0,0 +1,13 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden new file mode 100644 index 0000000000000..297b00ce4ecf8 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden @@ -0,0 +1,11 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: http://example.com + enabled: true + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden similarity index 53% rename from lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden rename to lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden index e03f369eb1017..c68ebdf570843 100644 --- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden @@ -2,8 +2,9 @@ version: v1 kind: update_config spec: proxy: localhost - group: "" - url_template: "" + path: /usr/local/bin enabled: true + pinned: false status: - active_version: 16.3.0 + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden new file mode 100644 index 0000000000000..a771da1fc9e25 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden @@ -0,0 +1,10 @@ +version: "" +kind: "" +spec: + proxy: localhost + path: "" + enabled: false + pinned: false +status: + active: + version: "" diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden new file mode 100644 index 0000000000000..f29735ca0230b --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: "" + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden new file mode 100644 index 0000000000000..501506cf96c78 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden @@ -0,0 +1,13 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: true +status: + active: + version: old-version + backup: + version: backup-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden new file mode 100644 index 0000000000000..10b36430ffed1 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden @@ -0,0 +1,15 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: old-version + backup: + version: backup-version + skip: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden new file mode 100644 index 0000000000000..10b36430ffed1 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden @@ -0,0 +1,15 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: old-version + backup: + version: backup-version + skip: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden new file mode 100644 index 0000000000000..b6b43595a5903 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: false + pinned: false +status: + active: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden new file mode 100644 index 0000000000000..b6b43595a5903 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: false + pinned: false +status: + active: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden new file mode 100644 index 0000000000000..d7fb9f94d1354 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden @@ -0,0 +1,14 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden new file mode 100644 index 0000000000000..d7fb9f94d1354 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden @@ -0,0 +1,14 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden new file mode 100644 index 0000000000000..d7fb9f94d1354 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden @@ -0,0 +1,14 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden new file mode 100644 index 0000000000000..a4cac37b8733c --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden @@ -0,0 +1,12 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + group: group + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden new file mode 100644 index 0000000000000..926667a2e7fc0 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden @@ -0,0 +1,11 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden new file mode 100644 index 0000000000000..926667a2e7fc0 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden @@ -0,0 +1,11 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden new file mode 100644 index 0000000000000..a153751d1854a --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden @@ -0,0 +1,13 @@ +version: v1 +kind: update_config +spec: + proxy: localhost + path: /usr/local/bin + base_url: https://example.com + enabled: true + pinned: false +status: + active: + version: 16.3.0 + backup: + version: old-version diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden new file mode 100644 index 0000000000000..cb09143fe9fdf --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden @@ -0,0 +1,5 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Service] +Environment="TELEPORT_UPDATE_CONFIG_FILE=/opt/teleport/default/update.yaml" +Environment="TELEPORT_UPDATE_INSTALL_DIR=/opt/teleport" diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden new file mode 100644 index 0000000000000..b5d6a74435cb2 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden @@ -0,0 +1 @@ +$nrconf{override_rc}{qr(^teleport\.service)} = 0; diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden new file mode 100644 index 0000000000000..b71ce08a9c4bb --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden @@ -0,0 +1,8 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update service + +[Service] +Type=oneshot +ExecStart=/usr/local/bin/teleport-update --install-suffix= --install-dir="/opt/teleport" update diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden new file mode 100644 index 0000000000000..d14a43d679e53 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden @@ -0,0 +1,12 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport.service diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden new file mode 100644 index 0000000000000..dc6445dc6e7f9 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden @@ -0,0 +1,5 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Service] +Environment="TELEPORT_UPDATE_CONFIG_FILE=/opt/teleport/test/update.yaml" +Environment="TELEPORT_UPDATE_INSTALL_DIR=/opt/teleport" diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden new file mode 100644 index 0000000000000..ad6bd606a74cb --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden @@ -0,0 +1 @@ +$nrconf{override_rc}{qr(^teleport_test\.service)} = 0; diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden new file mode 100644 index 0000000000000..832658a8d5a61 --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden @@ -0,0 +1,8 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update service + +[Service] +Type=oneshot +ExecStart=/usr/local/bin/teleport-update --install-suffix=test --install-dir="/opt/teleport" update diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden new file mode 100644 index 0000000000000..f57a3c08055bc --- /dev/null +++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden @@ -0,0 +1,12 @@ +# teleport-update +# DO NOT EDIT THIS FILE +[Unit] +Description=Teleport auto-update timer unit + +[Timer] +OnActiveSec=1m +OnUnitActiveSec=5m +RandomizedDelaySec=1m + +[Install] +WantedBy=teleport_test.service diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go index 59df5f0b3ba85..5114ef7a223c8 100644 --- a/lib/autoupdate/agent/updater.go +++ b/lib/autoupdate/agent/updater.go @@ -23,76 +23,55 @@ import ( "crypto/tls" "crypto/x509" "errors" - "io/fs" "log/slog" "net/http" "os" + "os/exec" "path/filepath" - "strings" + "runtime" + "slices" "time" - "github.com/google/renameio/v2" "github.com/gravitational/trace" - "gopkg.in/yaml.v3" "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/lib/autoupdate" + "github.com/gravitational/teleport/lib/client/debug" libdefaults "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/modules" libutils "github.com/gravitational/teleport/lib/utils" ) const ( - // cdnURITemplate is the default template for the Teleport tgz download. - cdnURITemplate = "https://cdn.teleport.dev/teleport{{if .Enterprise}}-ent{{end}}-v{{.Version}}-{{.OS}}-{{.Arch}}{{if .FIPS}}-fips{{end}}-bin.tar.gz" + // BinaryName specifies the name of the updater binary. + BinaryName = "teleport-update" +) + +const ( + // packageSystemDir is the location where packaged Teleport binaries and services are installed. + packageSystemDir = "/opt/teleport/system" // reservedFreeDisk is the minimum required free space left on disk during downloads. // TODO(sclevine): This value is arbitrary and could be replaced by, e.g., min(1%, 200mb) in the future // to account for a range of disk sizes. - reservedFreeDisk = 10_000_000 // 10 MB + reservedFreeDisk = 10_000_000 + // debugSocketFileName is the name of Teleport's debug socket in the data dir. + debugSocketFileName = "debug.sock" // 10 MB ) +// Log keys const ( - // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update. - updateConfigName = "update.yaml" - - // UpdateConfig metadata - updateConfigVersion = "v1" - updateConfigKind = "update_config" + targetKey = "target_version" + activeKey = "active_version" + backupKey = "backup_version" + errorKey = "error" ) -// UpdateConfig describes the update.yaml file schema. -type UpdateConfig struct { - // Version of the configuration file - Version string `yaml:"version"` - // Kind of configuration file (always "update_config") - Kind string `yaml:"kind"` - // Spec contains user-specified configuration. - Spec UpdateSpec `yaml:"spec"` - // Status contains state configuration. - Status UpdateStatus `yaml:"status"` -} - -// UpdateSpec describes the spec field in update.yaml. -type UpdateSpec struct { - // Proxy address - Proxy string `yaml:"proxy"` - // Group specifies the update group identifier for the agent. - Group string `yaml:"group"` - // URLTemplate for the Teleport tgz download URL. - URLTemplate string `yaml:"url_template"` - // Enabled controls whether auto-updates are enabled. - Enabled bool `yaml:"enabled"` -} - -// UpdateStatus describes the status field in update.yaml. -type UpdateStatus struct { - // ActiveVersion is the currently active Teleport version. - ActiveVersion string `yaml:"active_version"` -} - // NewLocalUpdater returns a new Updater that auto-updates local // installations of the Teleport agent. // The AutoUpdater uses an HTTP client with sane defaults for downloads, and // will not fill disk to within 10 MB of available capacity. -func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { +func NewLocalUpdater(cfg LocalUpdaterConfig, ns *Namespace) (*Updater, error) { certPool, err := x509.SystemCertPool() if err != nil { return nil, trace.Wrap(err) @@ -112,19 +91,64 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) { if cfg.Log == nil { cfg.Log = slog.Default() } + if cfg.SystemDir == "" { + cfg.SystemDir = packageSystemDir + } + validator := Validator{Log: cfg.Log} + debugClient := debug.NewClient(filepath.Join(ns.dataDir, debugSocketFileName)) return &Updater{ Log: cfg.Log, Pool: certPool, InsecureSkipVerify: cfg.InsecureSkipVerify, - ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName), + UpdateConfigPath: filepath.Join(ns.Dir(), updateConfigName), + TeleportConfigPath: ns.configFile, + DefaultProxyAddr: ns.defaultProxyAddr, + DefaultPathDir: ns.defaultPathDir, Installer: &LocalInstaller{ - InstallDir: cfg.VersionsDir, - HTTP: client, - Log: cfg.Log, - + InstallDir: filepath.Join(ns.Dir(), versionsDirName), + TargetServiceFile: ns.serviceFile, + SystemBinDir: filepath.Join(cfg.SystemDir, "bin"), + SystemServiceFile: filepath.Join(cfg.SystemDir, serviceDir, serviceName), + HTTP: client, + Log: cfg.Log, ReservedFreeTmpDisk: reservedFreeDisk, ReservedFreeInstallDisk: reservedFreeDisk, + TransformService: ns.replaceTeleportService, + ValidateBinary: validator.IsBinary, + Template: autoupdate.DefaultCDNURITemplate, + }, + Process: &SystemdService{ + ServiceName: filepath.Base(ns.serviceFile), + PIDFile: ns.pidFile, + Ready: debugClient, + Log: cfg.Log, }, + ReexecSetup: func(ctx context.Context, pathDir string, reload bool) error { + name := filepath.Join(pathDir, BinaryName) + if cfg.SelfSetup && runtime.GOOS == constants.LinuxOS { + name = "/proc/self/exe" + } + args := []string{ + "--install-dir", ns.installDir, + "--install-suffix", ns.name, + "--log-format", cfg.LogFormat, + } + if cfg.Debug { + args = append(args, "--debug") + } + args = append(args, "setup", "--path", pathDir) + if reload { + args = append(args, "--reload") + } + cmd := exec.CommandContext(ctx, name, args...) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + cfg.Log.InfoContext(ctx, "Executing new teleport-update binary to update configuration.") + defer cfg.Log.InfoContext(ctx, "Finished executing new teleport-update binary.") + return trace.Wrap(cmd.Run()) + }, + SetupNamespace: ns.Setup, + TeardownNamespace: ns.Teardown, }, nil } @@ -138,8 +162,14 @@ type LocalUpdaterConfig struct { // DownloadTimeout is a timeout for file download requests. // Defaults to no timeout. DownloadTimeout time.Duration - // VersionsDir for installing Teleport (usually /var/lib/teleport/versions). - VersionsDir string + // SystemDir for package-installed Teleport installations (usually /opt/teleport/system). + SystemDir string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. + SelfSetup bool + // Debug logs enabled. + Debug bool + // LogFormat controls the format of logging. Can be either `json` or `text`. + LogFormat string } // Updater implements the agent-local logic for Teleport agent auto-updates. @@ -150,192 +180,859 @@ type Updater struct { Pool *x509.CertPool // InsecureSkipVerify skips TLS verification. InsecureSkipVerify bool - // ConfigPath contains the path to the agent auto-updates configuration. - ConfigPath string + // UpdateConfigPath contains the path to the agent auto-updates configuration. + UpdateConfigPath string + // TeleportConfig contains the path to Teleport's configuration. + TeleportConfigPath string + // DefaultProxyAddr contains Teleport's proxy address. This may differ from the updater's. + DefaultProxyAddr string + // DefaultPathDir contains the default path that Teleport binaries should be installed into. + DefaultPathDir string // Installer manages installations of the Teleport agent. Installer Installer + // Process manages a running instance of Teleport. + Process Process + // ReexecSetup re-execs teleport-update with the setup command. + // This configures the updater service, verifies the installation, and optionally reloads Teleport. + ReexecSetup func(ctx context.Context, path string, reload bool) error + // SetupNamespace configures the Teleport updater service for the current Namespace. + SetupNamespace func(ctx context.Context, path string) error + // TeardownNamespace removes all traces of the updater service in the current Namespace, including Teleport. + TeardownNamespace func(ctx context.Context) error } // Installer provides an API for installing Teleport agents. type Installer interface { - // Install the Teleport agent at version from the download template. - // This function must be idempotent. - Install(ctx context.Context, version, template string, flags InstallFlags) error - // Remove the Teleport agent at version. - // This function must be idempotent. - Remove(ctx context.Context, version string) error + // Install the Teleport agent at revision from the download Template. + // If force is true, Install will remove broken revisions. + // Install must be idempotent. + Install(ctx context.Context, rev Revision, baseURL string, force bool) error + // Link the Teleport agent at the specified revision of Teleport into path. + // The revert function must restore the previous linking, returning false on any failure. + // If force is true, Link will overwrite non-symlinks. + // Link must be idempotent. Link's revert function must be idempotent. + Link(ctx context.Context, rev Revision, pathDir string, force bool) (revert func(context.Context) bool, err error) + // LinkSystem links the system installation of Teleport into the system linking location. + // The revert function must restore the previous linking, returning false on any failure. + // LinkSystem must be idempotent. LinkSystem's revert function must be idempotent. + LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) + // TryLink links the specified revision of Teleport into path. + // Unlike Link, TryLink will fail if existing links to other locations are present. + // TryLink must be idempotent. + TryLink(ctx context.Context, rev Revision, pathDir string) error + // TryLinkSystem links the system (package) installation of Teleport into the system linking location. + // Unlike LinkSystem, TryLinkSystem will fail if existing links to other locations are present. + // TryLinkSystem must be idempotent. + TryLinkSystem(ctx context.Context) error + // Unlink unlinks the specified revision of Teleport from path. + // Unlink must be idempotent. + Unlink(ctx context.Context, rev Revision, pathDir string) error + // UnlinkSystem unlinks the system (package) installation of Teleport from the system linking location. + // UnlinkSystem must be idempotent. + UnlinkSystem(ctx context.Context) error + // List the installed revisions of Teleport. + List(ctx context.Context) (revisions []Revision, err error) + // Remove the Teleport agent at revision. + // Remove must be idempotent. + Remove(ctx context.Context, rev Revision) error + // IsLinked returns true if the revision is linked to path. + IsLinked(ctx context.Context, rev Revision, pathDir string) (bool, error) } -// InstallFlags sets flags for the Teleport installation -type InstallFlags int - -const ( - // FlagEnterprise installs enterprise Teleport - FlagEnterprise InstallFlags = 1 << iota - // FlagFIPS installs FIPS Teleport - FlagFIPS +var ( + // ErrLinked is returned when a linked version cannot be operated on. + ErrLinked = errors.New("version is linked") + // ErrNotNeeded is returned when the operation is not needed. + ErrNotNeeded = errors.New("not needed") + // ErrNotSupported is returned when the operation is not supported on the platform. + ErrNotSupported = errors.New("not supported on this platform") + // ErrNoBinaries is returned when no binaries are available to be linked. + ErrNoBinaries = errors.New("no binaries available to link") + // ErrFilePresent is returned when a file is present. + ErrFilePresent = errors.New("file present") ) +// Process provides an API for interacting with a running Teleport process. +type Process interface { + // Reload must reload the Teleport process as gracefully as possible. + // If the process is not healthy after reloading, Reload must return an error. + // If the process did not require reloading, Reload must return ErrNotNeeded. + // E.g., if the process is not enabled, or it was already reloaded after the last Sync. + // If the type implementing Process does not support the system process manager, + // Reload must return ErrNotSupported. + Reload(ctx context.Context) error + // Sync must validate and synchronize process configuration. + // After the linked Teleport installation is changed, failure to call Sync without + // error before Reload may result in undefined behavior. + // If the type implementing Process does not support the system process manager, + // Sync must return ErrNotSupported. + Sync(ctx context.Context) error + // IsEnabled must return true if the Process is configured to run on system boot. + // If the type implementing Process does not support the system process manager, + // IsEnabled must return ErrNotSupported. + IsEnabled(ctx context.Context) (bool, error) + // IsActive must return true if the Process is currently running. + // If the type implementing Process does not support the system process manager, + // IsActive must return ErrNotSupported. + IsActive(ctx context.Context) (bool, error) + // IsPresent must return true if the Process is installed on the system. + // If the type implementing Process does not support the system process manager, + // IsPresent must return ErrNotSupported. + IsPresent(ctx context.Context) (bool, error) +} + // OverrideConfig contains overrides for individual update operations. // If validated, these overrides may be persisted to disk. type OverrideConfig struct { - // Proxy address, scheme and port optional. - // Overrides existing value if specified. - Proxy string - // Group identifier for updates (e.g., staging) - // Overrides existing value if specified. - Group string - // URLTemplate for the Teleport tgz download URL - // Overrides existing value if specified. - URLTemplate string + UpdateSpec + + // The fields below override the behavior of + // Updater.Install for a single run. + // ForceVersion to the specified version. ForceVersion string + // ForceFlags in installed Teleport. + ForceFlags autoupdate.InstallFlags + // AllowOverwrite of installed binaries. + AllowOverwrite bool } -// Enable enables agent updates and attempts an initial update. -// If the initial update succeeds, auto-updates are enabled and the configuration is persisted. -// Otherwise, the auto-updates configuration is not changed. +func deref[T any](ptr *T) T { + if ptr != nil { + return *ptr + } + var t T + return t +} + +func toPtr[T any](t T) *T { + return &t +} + +// Install attempts an initial installation of Teleport. +// If the initial installation succeeds, the override configuration is persisted. +// Otherwise, the configuration is not changed. // This function is idempotent. -func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error { +func (u *Updater) Install(ctx context.Context, override OverrideConfig) error { // Read configuration from update.yaml and override any new values passed as flags. - cfg, err := u.readConfig(u.ConfigPath) + cfg, err := readConfig(u.UpdateConfigPath) if err != nil { - return trace.Errorf("failed to read %s: %w", updateConfigName, err) + return trace.Wrap(err, "failed to read %s", updateConfigName) + } + if err := validateConfigSpec(&cfg.Spec, override); err != nil { + return trace.Wrap(err) + } + + if cfg.Spec.Proxy == "" { + cfg.Spec.Proxy = u.DefaultProxyAddr + } else if u.DefaultProxyAddr != "" && + !sameProxies(cfg.Spec.Proxy, u.DefaultProxyAddr) { + u.Log.WarnContext(ctx, "Proxy specified in update.yaml does not match teleport.yaml. Unexpected updates may occur.", "update_proxy", cfg.Spec.Proxy, "teleport_proxy", u.DefaultProxyAddr) + } + if cfg.Spec.Path == "" { + cfg.Spec.Path = u.DefaultPathDir } - if override.Proxy != "" { - cfg.Spec.Proxy = override.Proxy + + active := cfg.Status.Active + skip := deref(cfg.Status.Skip) + + // Lookup target version from the proxy. + + resp, err := u.find(ctx, cfg) + if err != nil { + return trace.Wrap(err) } - if override.Group != "" { - cfg.Spec.Group = override.Group + targetVersion := resp.Target.Version + targetFlags := resp.Target.Flags + targetFlags |= override.ForceFlags + if override.ForceVersion != "" { + targetVersion = override.ForceVersion } - if override.URLTemplate != "" { - cfg.Spec.URLTemplate = override.URLTemplate + target := NewRevision(targetVersion, targetFlags) + + switch target.Version { + case "": + return trace.Errorf("agent version not available from Teleport cluster") + case skip.Version: + u.Log.WarnContext(ctx, "Target version was previously marked as broken. Retrying update.", targetKey, target, activeKey, active) + default: + u.Log.InfoContext(ctx, "Initiating installation.", targetKey, target, activeKey, active) } - cfg.Spec.Enabled = true - if err := validateUpdatesSpec(&cfg.Spec); err != nil { + + if err := u.update(ctx, cfg, target, override.AllowOverwrite, resp.AGPL); err != nil { + if errors.Is(err, ErrFilePresent) && !override.AllowOverwrite { + u.Log.WarnContext(ctx, "Use --overwrite to force removal of existing binaries installed via script.") + u.Log.WarnContext(ctx, "If a teleport rpm or deb package is installed, upgrade it to the latest version and retry. DO NOT USE --overwrite.") + } return trace.Wrap(err) } + if target.Version == skip.Version { + cfg.Status.Skip = nil + } - // Lookup target version from the proxy. - addr, err := libutils.ParseAddr(cfg.Spec.Proxy) + // Only write the configuration file if the initial update succeeds. + // Note: skip_version is never set on failed enable, only failed update. + + if err := writeConfig(u.UpdateConfigPath, cfg); err != nil { + return trace.Wrap(err, "failed to write %s", updateConfigName) + } + u.Log.InfoContext(ctx, "Configuration updated.") + return trace.Wrap(u.notices(ctx)) +} + +// sameProxies returns true if both proxies addresses are the same. +// Note that the port is defaulted to 443, which is different from teleport.yaml's default. +func sameProxies(a, b string) bool { + const defaultPort = 443 + if a == b { + return true + } + addrA, err := libutils.ParseAddr(a) if err != nil { - return trace.Errorf("failed to parse proxy server address: %w", err) - } - - desiredVersion := override.ForceVersion - if desiredVersion == "" { - resp, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: addr.Addr, - Insecure: u.InsecureSkipVerify, - Timeout: 30 * time.Second, - //Group: cfg.Spec.Group, // TODO(sclevine): add web API for verssion - Pool: u.Pool, - }) - if err != nil { - return trace.Errorf("failed to request version from proxy: %w", err) + return false + } + addrB, err := libutils.ParseAddr(b) + if err != nil { + return false + } + return addrA.Host() == addrB.Host() && + addrA.Port(defaultPort) == addrB.Port(defaultPort) +} + +// Remove removes everything created by the updater for the given namespace. +// Before attempting this, Remove attempts to gracefully recover the system-packaged version of Teleport (if present). +// This function is idempotent. +func (u *Updater) Remove(ctx context.Context, force bool) error { + cfg, err := readConfig(u.UpdateConfigPath) + if err != nil { + return trace.Wrap(err, "failed to read %s", updateConfigName) + } + if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil { + return trace.Wrap(err) + } + active := cfg.Status.Active + if active.Version == "" { + u.Log.InfoContext(ctx, "No installation of Teleport managed by the updater. Removing updater configuration.") + if err := u.TeardownNamespace(ctx); err != nil { + return trace.Wrap(err) } - desiredVersion, _ = "16.3.0", resp // TODO(sclevine): add web API for version - //desiredVersion := resp.AutoUpdate.AgentVersion + u.Log.InfoContext(ctx, "Automatic update configuration for Teleport successfully uninstalled.") + return nil } - if desiredVersion == "" { - return trace.Errorf("agent version not available from Teleport cluster") + // Do not link system package installation if the installation we are removing + // is not installed into /usr/local/bin. + if filepath.Clean(cfg.Spec.Path) != filepath.Clean(defaultPathDir) { + return u.removeWithoutSystem(ctx, cfg, force) } - // If the active version and target don't match, kick off upgrade. - template := cfg.Spec.URLTemplate - if template == "" { - template = cdnURITemplate + revert, err := u.Installer.LinkSystem(ctx) + if errors.Is(err, ErrNoBinaries) { + return u.removeWithoutSystem(ctx, cfg, force) } - err = u.Installer.Install(ctx, desiredVersion, template, 0) // TODO(sclevine): add web API for flags if err != nil { + return trace.Wrap(err, "failed to link") + } + + u.Log.InfoContext(ctx, "Updater-managed installation of Teleport detected. Restoring packaged version of Teleport before removing.") + + revertConfig := func(ctx context.Context) bool { + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") + return false + } + if err := u.Process.Sync(ctx); err != nil { + u.Log.ErrorContext(ctx, "Failed to revert systemd configuration after failed restart.", errorKey, err) + return false + } + return true + } + + // Sync systemd. + + err = u.Process.Sync(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("sync canceled") + } + if errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Not syncing systemd configuration because systemd is not running.") + } else if err != nil { + // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync. + u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.") + if ok := revertConfig(ctx); ok { + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") + } + return trace.Wrap(err, "failed to validate configuration for system package version of Teleport") + } + + // Restart Teleport. + + u.Log.InfoContext(ctx, "Teleport package successfully restored.") + err = u.Process.Reload(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("reload canceled") + } + if err != nil && + !errors.Is(err, ErrNotNeeded) && // no output if restart not needed + !errors.Is(err, ErrNotSupported) { // already logged above for Sync + + // If reloading Teleport at the new version fails, revert and reload. + u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.") + if ok := revertConfig(ctx); ok { + if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { + u.Log.ErrorContext(ctx, "Failed to reload Teleport after reverting. Installation likely broken.", errorKey, err) + } else { + u.Log.WarnContext(ctx, "Teleport updater detected an error with the new installation and successfully reverted it.") + } + } + return trace.Wrap(err, "failed to start system package version of Teleport") + } + u.Log.InfoContext(ctx, "Auto-updating Teleport removed and replaced by Teleport package.", "version", active) + if err := u.TeardownNamespace(ctx); err != nil { return trace.Wrap(err) } - if cfg.Status.ActiveVersion != desiredVersion { - u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion) + u.Log.InfoContext(ctx, "Auto-update configuration for Teleport successfully uninstalled.") + return nil +} + +func (u *Updater) removeWithoutSystem(ctx context.Context, cfg *UpdateConfig, force bool) error { + if !force { + u.Log.ErrorContext(ctx, "No packaged installation of Teleport was found, and --force was not passed. Refusing to remove Teleport from this system.") + return trace.Errorf("unable to remove Teleport completely without --force") } else { - u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion) + u.Log.WarnContext(ctx, "No packaged installation of Teleport was found, and --force was passed. Teleport will be removed from this system.") } - cfg.Status.ActiveVersion = desiredVersion - - // Always write the configuration file if enable succeeds. - if err := u.writeConfig(u.ConfigPath, cfg); err != nil { - return trace.Errorf("failed to write %s: %w", updateConfigName, err) + u.Log.InfoContext(ctx, "Updater-managed installation of Teleport detected. Attempting to unlink and remove.") + ok, err := isActiveOrEnabled(ctx, u.Process) + if err != nil && !errors.Is(err, ErrNotSupported) { + return trace.Wrap(err) } - u.Log.InfoContext(ctx, "Configuration updated.") + if ok { + return trace.Errorf("refusing to remove active installation of Teleport, please stop and disable Teleport first") + } + if err := u.Installer.Unlink(ctx, cfg.Status.Active, cfg.Spec.Path); err != nil { + return trace.Wrap(err) + } + u.Log.InfoContext(ctx, "Teleport uninstalled.", "version", cfg.Status.Active) + if err := u.TeardownNamespace(ctx); err != nil { + return trace.Wrap(err) + } + u.Log.InfoContext(ctx, "Automatic update configuration for Teleport successfully uninstalled.") return nil } -func validateUpdatesSpec(spec *UpdateSpec) error { - if spec.URLTemplate != "" && - !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") { - return trace.Errorf("Teleport download URL must use TLS (https://)") +// isActiveOrEnabled returns true if the service is active or enabled. +func isActiveOrEnabled(ctx context.Context, s Process) (bool, error) { + enabled, err := s.IsEnabled(ctx) + if err != nil { + return false, trace.Wrap(err) + } + if enabled { + return true, nil + } + active, err := s.IsActive(ctx) + if err != nil { + return false, trace.Wrap(err) } + if active { + return true, nil + } + return false, nil +} - if spec.Proxy == "" { - return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName) +// Status returns all available local and remote fields related to agent auto-updates. +// Status is safe to run concurrently with other Updater commands. +func (u *Updater) Status(ctx context.Context) (Status, error) { + var out Status + // Read configuration from update.yaml. + cfg, err := readConfig(u.UpdateConfigPath) + if err != nil { + return out, trace.Wrap(err, "failed to read %s", updateConfigName) } - return nil + if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil { + return out, trace.Wrap(err) + } + out.UpdateSpec = cfg.Spec + out.UpdateStatus = cfg.Status + + // Lookup target version from the proxy. + resp, err := u.find(ctx, cfg) + if err != nil { + return out, trace.Wrap(err) + } + out.FindResp = resp + return out, nil } // Disable disables agent auto-updates. // This function is idempotent. func (u *Updater) Disable(ctx context.Context) error { - cfg, err := u.readConfig(u.ConfigPath) + cfg, err := readConfig(u.UpdateConfigPath) if err != nil { - return trace.Errorf("failed to read %s: %w", updateConfigName, err) + return trace.Wrap(err, "failed to read %s", updateConfigName) } if !cfg.Spec.Enabled { u.Log.InfoContext(ctx, "Automatic updates already disabled.") return nil } cfg.Spec.Enabled = false - if err := u.writeConfig(u.ConfigPath, cfg); err != nil { - return trace.Errorf("failed to write %s: %w", updateConfigName, err) + if err := writeConfig(u.UpdateConfigPath, cfg); err != nil { + return trace.Wrap(err, "failed to write %s", updateConfigName) } return nil } -// readConfig reads UpdateConfig from a file. -func (*Updater) readConfig(path string) (*UpdateConfig, error) { - f, err := os.Open(path) - if errors.Is(err, fs.ErrNotExist) { - return &UpdateConfig{ - Version: updateConfigVersion, - Kind: updateConfigKind, - }, nil +// Unpin allows the current version to be changed by Update. +// This function is idempotent. +func (u *Updater) Unpin(ctx context.Context) error { + cfg, err := readConfig(u.UpdateConfigPath) + if err != nil { + return trace.Wrap(err, "failed to read %s", updateConfigName) + } + if !cfg.Spec.Pinned { + u.Log.InfoContext(ctx, "Current version not pinned.", activeKey, cfg.Status.Active) + return nil } + cfg.Spec.Pinned = false + if err := writeConfig(u.UpdateConfigPath, cfg); err != nil { + return trace.Wrap(err, "failed to write %s", updateConfigName) + } + return nil +} + +// Update initiates an agent update. +// If the update succeeds, the new installed version is marked as active. +// Otherwise, the auto-updates configuration is not changed. +// Unlike Enable, Update will not validate or repair the current version. +// This function is idempotent. +func (u *Updater) Update(ctx context.Context, now bool) error { + // Read configuration from update.yaml and override any new values passed as flags. + cfg, err := readConfig(u.UpdateConfigPath) if err != nil { - return nil, trace.Errorf("failed to open: %w", err) + return trace.Wrap(err, "failed to read %s", updateConfigName) } - defer f.Close() - var cfg UpdateConfig - if err := yaml.NewDecoder(f).Decode(&cfg); err != nil { - return nil, trace.Errorf("failed to parse: %w", err) + if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil { + return trace.Wrap(err) } - if k := cfg.Kind; k != updateConfigKind { - return nil, trace.Errorf("invalid kind %q", k) + if u.DefaultProxyAddr != "" && + !sameProxies(cfg.Spec.Proxy, u.DefaultProxyAddr) { + u.Log.WarnContext(ctx, "Proxy specified in update.yaml does not match teleport.yaml. Unexpected updates may occur.", "update_proxy", cfg.Spec.Proxy, "teleport_proxy", u.DefaultProxyAddr) } - if v := cfg.Version; v != updateConfigVersion { - return nil, trace.Errorf("invalid version %q", v) + + active := cfg.Status.Active + skip := deref(cfg.Status.Skip) + if !cfg.Spec.Enabled { + u.Log.InfoContext(ctx, "Automatic updates disabled.", activeKey, active) + return nil } - return &cfg, nil -} -// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted. -func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error { - opts := []renameio.Option{ - renameio.WithPermissions(0755), - renameio.WithExistingPermissions(), + if cfg.Spec.Path == "" { + return trace.Errorf("failed to read destination path for binary links from %s", updateConfigName) } - t, err := renameio.NewPendingFile(filename, opts...) + + resp, err := u.find(ctx, cfg) if err != nil { return trace.Wrap(err) } - defer t.Cleanup() - err = yaml.NewEncoder(t).Encode(cfg) + target := resp.Target + + if cfg.Spec.Pinned { + switch target { + case active: + u.Log.InfoContext(ctx, "Teleport is up-to-date. Installation is pinned to prevent future updates.", activeKey, active) + default: + u.Log.InfoContext(ctx, "Teleport version is pinned. Skipping update.", targetKey, target, activeKey, active) + } + return nil + } + + // If a version fails and is marked skip, we ignore any edition changes as well. + // If a cluster is broadcasting a version that failed to start, changing ent/fips is unlikely to fix the issue. + + if !resp.InWindow && !now { + switch { + case target.Version == "": + u.Log.WarnContext(ctx, "Cannot determine target agent version. Waiting for both version and update window.") + case target == active: + u.Log.InfoContext(ctx, "Teleport is up-to-date. Update window is not active.", activeKey, active) + case target.Version == skip.Version: + u.Log.InfoContext(ctx, "Update available, but the new version is marked as broken. Update window is not active.", targetKey, target, activeKey, active) + default: + u.Log.InfoContext(ctx, "Update available, but update window is not active.", targetKey, target, activeKey, active) + } + return nil + } + + switch { + case target.Version == "": + if resp.InWindow { + u.Log.ErrorContext(ctx, "Update window is active, but target version is not available.", activeKey, active) + } + return trace.Errorf("target version missing") + case target == active: + if resp.InWindow { + u.Log.InfoContext(ctx, "Teleport is up-to-date. Update window is active, but no action is needed.", activeKey, active) + } else { + u.Log.InfoContext(ctx, "Teleport is up-to-date. No action is needed.", activeKey, active) + } + return nil + case target.Version == skip.Version: + u.Log.InfoContext(ctx, "Update available, but the new version is marked as broken. Skipping update.", targetKey, target, activeKey, active) + return nil + default: + u.Log.InfoContext(ctx, "Update available. Initiating update.", targetKey, target, activeKey, active) + } + if !now { + time.Sleep(resp.Jitter) + } + + updateErr := u.update(ctx, cfg, target, false, resp.AGPL) + writeErr := writeConfig(u.UpdateConfigPath, cfg) + if writeErr != nil { + writeErr = trace.Wrap(writeErr, "failed to write %s", updateConfigName) + } else { + u.Log.InfoContext(ctx, "Configuration updated.") + } + // Show notices last + if updateErr == nil && now { + updateErr = u.notices(ctx) + } + return trace.NewAggregate(updateErr, writeErr) +} + +func (u *Updater) find(ctx context.Context, cfg *UpdateConfig) (FindResp, error) { + if cfg.Spec.Proxy == "" { + return FindResp{}, trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName) + } + addr, err := libutils.ParseAddr(cfg.Spec.Proxy) + if err != nil { + return FindResp{}, trace.Wrap(err, "failed to parse proxy server address") + } + resp, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: addr.Addr, + Insecure: u.InsecureSkipVerify, + Timeout: 30 * time.Second, + UpdateGroup: cfg.Spec.Group, + Pool: u.Pool, + }) + if err != nil { + return FindResp{}, trace.Wrap(err, "failed to request version from proxy") + } + var flags autoupdate.InstallFlags + var agpl bool + switch resp.Edition { + case modules.BuildEnterprise: + flags |= autoupdate.FlagEnterprise + case modules.BuildCommunity: + case modules.BuildOSS: + agpl = true + default: + agpl = true + u.Log.WarnContext(ctx, "Unknown edition detected, defaulting to OSS.", "edition", resp.Edition) + } + if resp.FIPS { + flags |= autoupdate.FlagFIPS + } + jitterSec := resp.AutoUpdate.AgentUpdateJitterSeconds + return FindResp{ + Target: NewRevision(resp.AutoUpdate.AgentVersion, flags), + InWindow: resp.AutoUpdate.AgentAutoUpdate, + Jitter: time.Duration(jitterSec) * time.Second, + AGPL: agpl, + }, nil +} + +func (u *Updater) removeRevision(ctx context.Context, cfg *UpdateConfig, rev Revision) error { + linked, err := u.Installer.IsLinked(ctx, rev, cfg.Spec.Path) + if err != nil { + return trace.Wrap(err, "failed to determine if linked") + } + if linked { + return trace.Wrap(ErrLinked, "refusing to remove") + } + return trace.Wrap(u.Installer.Remove(ctx, rev)) +} + +func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, target Revision, force, agpl bool) error { + baseURL := cfg.Spec.BaseURL + if baseURL == "" { + if agpl { + return trace.Errorf("--base-url flag must be specified for AGPL edition of Teleport") + } + baseURL = autoupdate.DefaultBaseURL + } + + active := cfg.Status.Active + backup := deref(cfg.Status.Backup) + switch backup { + case Revision{}, target, active: + default: + if target == active { + // Keep backup version if we are only verifying active version + break + } + err := u.removeRevision(ctx, cfg, backup) + if err != nil { + // this could happen if it was already removed due to a failed installation + u.Log.WarnContext(ctx, "Failed to remove backup version of Teleport before new install.", errorKey, err, backupKey, backup) + } + } + + // Install and link the desired version (or validate existing installation) + + linked, err := u.Installer.IsLinked(ctx, target, cfg.Spec.Path) + if err != nil { + return trace.Wrap(err, "failed to determine if linked") + } + err = u.Installer.Install(ctx, target, baseURL, !linked) + if err != nil { + return trace.Wrap(err, "failed to install") + } + + // If the target version has fewer binaries, this will leave old binaries linked. + // This may prevent the installation from being removed. + // Cleanup logic at the end of this function will ensure that they are removed + // eventually. + + revert, err := u.Installer.Link(ctx, target, cfg.Spec.Path, force) + if err != nil { + return trace.Wrap(err, "failed to link") + } + + // If we fail to revert after this point, the next update/enable will + // fix the link to restore the active version. + + revertConfig := func(ctx context.Context) bool { + if target.Version != "" { + cfg.Status.Skip = toPtr(target) + } + if force { + u.Log.ErrorContext(ctx, "Unable to revert Teleport symlinks in overwrite mode. Installation likely broken.") + return false + } + if ok := revert(ctx); !ok { + u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.") + return false + } + if err := u.SetupNamespace(ctx, cfg.Spec.Path); err != nil { + u.Log.ErrorContext(ctx, "Failed to revert configuration after failed restart.", errorKey, err) + return false + } + return true + } + + if cfg.Status.Active != target { + err := u.ReexecSetup(ctx, cfg.Spec.Path, true) + if errors.Is(err, context.Canceled) { + return trace.Errorf("check canceled") + } + if err != nil { + // If reloading Teleport at the new version fails, revert and reload. + u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.") + if ok := revertConfig(ctx); ok { + if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) { + u.Log.ErrorContext(ctx, "Failed to reload Teleport after reverting. Installation likely broken.", errorKey, err) + } else { + u.Log.WarnContext(ctx, "Teleport updater detected an error with the new installation and successfully reverted it.") + } + } + return trace.Wrap(err, "failed to start new version %s of Teleport", target) + } + u.Log.InfoContext(ctx, "Target version successfully installed.", targetKey, target) + + if r := cfg.Status.Active; r.Version != "" { + cfg.Status.Backup = toPtr(r) + } + cfg.Status.Active = target + } else { + err := u.ReexecSetup(ctx, cfg.Spec.Path, false) + if errors.Is(err, context.Canceled) { + return trace.Errorf("check canceled") + } + if err != nil { + // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync. + u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.") + if ok := revertConfig(ctx); ok { + u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.") + } + return trace.Wrap(err, "failed to validate new version %s of Teleport", target) + } + u.Log.InfoContext(ctx, "Target version successfully validated.", targetKey, target) + } + if r := deref(cfg.Status.Backup); r.Version != "" { + u.Log.InfoContext(ctx, "Backup version set.", backupKey, r) + } + + return trace.Wrap(u.cleanup(ctx, cfg, []Revision{ + target, active, backup, + })) +} + +// Setup writes updater configuration and verifies the Teleport installation. +// If restart is true, Setup also restarts Teleport. +// Setup is safe to run concurrently with other Updater commands. +func (u *Updater) Setup(ctx context.Context, path string, restart bool) error { + // Setup teleport-updater configuration and sync systemd. + + err := u.SetupNamespace(ctx, path) + if errors.Is(err, context.Canceled) { + return trace.Errorf("sync canceled") + } + if err != nil { + return trace.Wrap(err, "failed to setup updater") + } + + present, err := u.Process.IsPresent(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("config check canceled") + } + if errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Skipping all systemd setup because systemd is not running.") + return nil + } + if err != nil { + return trace.Wrap(err, "failed to determine if new version of Teleport has an installed systemd service") + } + if !present { + return trace.Errorf("cannot find systemd service for new version of Teleport, check SELinux settings") + } + + // Restart Teleport if necessary. + + if restart { + err = u.Process.Reload(ctx) + if errors.Is(err, context.Canceled) { + return trace.Errorf("reload canceled") + } + if err != nil && + !errors.Is(err, ErrNotNeeded) { // skip if not needed + return trace.Wrap(err, "failed to reload Teleport") + } + } + return nil +} + +// notices displays final notices after install or update. +func (u *Updater) notices(ctx context.Context) error { + enabled, err := u.Process.IsEnabled(ctx) + if errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Teleport is installed, but systemd is not present to start it.") + u.Log.WarnContext(ctx, "After configuring teleport.yaml, your system must also be configured to start Teleport.") + return nil + } + if err != nil { + return trace.Wrap(err, "failed to query Teleport systemd enabled status") + } + active, err := u.Process.IsActive(ctx) + if err != nil { + return trace.Wrap(err, "failed to query Teleport systemd active status") + } + if !enabled && active { + u.Log.WarnContext(ctx, "Teleport is installed and started, but not configured to start on boot.") + u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can enable it with: systemctl enable teleport") + } + if !active && enabled { + u.Log.WarnContext(ctx, "Teleport is installed and enabled at boot, but not running.") + u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can start it with: systemctl start teleport") + } + if !active && !enabled { + u.Log.WarnContext(ctx, "Teleport is installed, but not running or enabled at boot.") + u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can enable and start it with: systemctl enable teleport --now") + } + + return nil +} + +// cleanup orphan installations +func (u *Updater) cleanup(ctx context.Context, cfg *UpdateConfig, keep []Revision) error { + revs, err := u.Installer.List(ctx) + if err != nil { + u.Log.ErrorContext(ctx, "Failed to read installed versions.", errorKey, err) + return nil + } + if len(revs) < 3 { + return nil + } + u.Log.WarnContext(ctx, "More than two versions of Teleport are installed. Removing unused versions.", "count", len(revs)) + for _, v := range revs { + if v.Version == "" || slices.Contains(keep, v) { + continue + } + err := u.removeRevision(ctx, cfg, v) + if errors.Is(err, ErrLinked) { + u.Log.WarnContext(ctx, "Refusing to remove version with orphan links.", "version", v) + continue + } + if err != nil { + u.Log.WarnContext(ctx, "Failed to remove unused version of Teleport.", errorKey, err, "version", v) + continue + } + u.Log.WarnContext(ctx, "Deleted unused version of Teleport.", "version", v) + } + return nil +} + +// LinkPackage creates links from the system (package) installation of Teleport, if they are needed. +// LinkPackage returns nil and warns if an auto-updates version is already linked, but auto-updates is disabled. +// LinkPackage returns an error only if an unknown version of Teleport is present (e.g., manually copied files). +// This function is idempotent. +func (u *Updater) LinkPackage(ctx context.Context) error { + cfg, err := readConfig(u.UpdateConfigPath) if err != nil { + return trace.Wrap(err, "failed to read %s", updateConfigName) + } + if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil { return trace.Wrap(err) } - return trace.Wrap(t.CloseAtomicallyReplace()) + active := cfg.Status.Active + if cfg.Spec.Enabled { + u.Log.InfoContext(ctx, "Automatic updates is enabled. Skipping system package link.", activeKey, active) + return nil + } + if cfg.Spec.Pinned { + u.Log.InfoContext(ctx, "Automatic update version is pinned. Skipping system package link.", activeKey, active) + return nil + } + // If an active version is set, but auto-updates is disabled, try to link the system installation in case the config is stale. + // If any links are present, this will return ErrLinked and not create any system links. + // This state is important to log as a warning, + if err := u.Installer.TryLinkSystem(ctx); errors.Is(err, ErrLinked) { + u.Log.WarnContext(ctx, "Automatic updates is disabled, but a non-package version of Teleport is linked.", activeKey, active) + return nil + } else if err != nil { + return trace.Wrap(err, "failed to link system package installation") + } + + // If syncing succeeds, ensure the installed systemd service can be found via systemctl. + // SELinux contexts can interfere with systemctl's ability to read service files. + if err := u.Process.Sync(ctx); errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Systemd is not installed. Skipping sync.") + } else if err != nil { + return trace.Wrap(err, "failed to sync systemd configuration") + } else { + present, err := u.Process.IsPresent(ctx) + if err != nil { + return trace.Wrap(err, "failed to determine if Teleport has an installed systemd service") + } + if !present { + return trace.Errorf("cannot find systemd service for Teleport, check SELinux settings") + } + } + u.Log.InfoContext(ctx, "Successfully linked system package installation.") + return nil +} + +// UnlinkPackage removes links from the system (package) installation of Teleport, if they are present. +// This function is idempotent. +func (u *Updater) UnlinkPackage(ctx context.Context) error { + if err := u.Installer.UnlinkSystem(ctx); err != nil { + return trace.Wrap(err, "failed to unlink system package installation") + } + if err := u.Process.Sync(ctx); errors.Is(err, ErrNotSupported) { + u.Log.WarnContext(ctx, "Systemd is not installed. Skipping sync.") + } else if err != nil { + return trace.Wrap(err, "failed to sync systemd configuration") + } + u.Log.InfoContext(ctx, "Successfully unlinked system package installation.") + return nil } diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go index 5ec93b43be9cc..cc789c20bca08 100644 --- a/lib/autoupdate/agent/updater_test.go +++ b/lib/autoupdate/agent/updater_test.go @@ -20,12 +20,14 @@ package agent import ( "context" + "encoding/json" "errors" "net/http" "net/http/httptest" "os" "path/filepath" "regexp" + "slices" "strings" "testing" @@ -33,6 +35,8 @@ import ( "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/lib/autoupdate" "github.com/gravitational/teleport/lib/utils/testutils/golden" ) @@ -79,10 +83,1154 @@ func TestUpdater_Disable(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + ns := &Namespace{installDir: dir} + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + err = updater.Disable(context.Background()) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + + data, err := os.ReadFile(cfgPath) + + // If no config is present, disable should not create it + if tt.cfg == nil { + require.ErrorIs(t, err, os.ErrNotExist) + return + } + require.NoError(t, err) + + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +func TestUpdater_Unpin(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + errMatch string + }{ + { + name: "pinned", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Pinned: true, + }, + }, + }, + { + name: "not pinned", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Pinned: false, + }, + }, + }, + { + name: "config does not exist", + }, + { + name: "invalid metadata", + cfg: &UpdateConfig{ + Spec: UpdateSpec{ + Enabled: true, + }, + }, + errMatch: "invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + ns := &Namespace{installDir: dir} + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + err = updater.Unpin(context.Background()) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + return + } + require.NoError(t, err) + + data, err := os.ReadFile(cfgPath) + + // If no config is present, disable should not create it + if tt.cfg == nil { + require.ErrorIs(t, err, os.ErrNotExist) + return + } + require.NoError(t, err) + + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +func TestUpdater_Update(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + flags autoupdate.InstallFlags + inWindow bool + now bool + agpl bool + installErr error + setupErr error + reloadErr error + notActive bool + notEnabled bool + linkedRevisions []Revision + + removedRevisions []Revision + installedRevision Revision + installedBaseURL string + linkedRevision Revision + requestGroup string + reloadCalls int + revertCalls int + setupCalls int + restarted bool + errMatch string + }{ + { + name: "updates enabled during window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + inWindow: true, + + removedRevisions: []Revision{NewRevision("unknown-version", 0)}, + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "group", + setupCalls: 1, + restarted: true, + }, + { + name: "updates enabled now", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + now: true, + + removedRevisions: []Revision{NewRevision("unknown-version", 0)}, + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "group", + setupCalls: 1, + restarted: true, + }, + { + name: "updates enabled now, not started or enabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + now: true, + notEnabled: true, + notActive: true, + + removedRevisions: []Revision{NewRevision("unknown-version", 0)}, + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "group", + setupCalls: 1, + restarted: true, + }, + { + name: "updates disabled during window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: false, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + inWindow: true, + }, + { + name: "missing path during window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Group: "group", + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + inWindow: true, + errMatch: "destination path", + }, + { + name: "updates enabled outside of window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + requestGroup: "group", + }, + { + name: "updates disabled outside of window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Group: "group", + BaseURL: "https://example.com", + Enabled: false, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + }, + { + name: "insecure URL", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "http://example.com", + Enabled: true, + }, + }, + inWindow: true, + + errMatch: "must use TLS", + }, + { + name: "install error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Enabled: true, + }, + }, + inWindow: true, + installErr: errors.New("install error"), + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + errMatch: "install error", + }, + { + name: "version already installed in window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("16.3.0", 0), + }, + }, + inWindow: true, + }, + { + name: "version already installed outside of window", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("16.3.0", 0), + }, + }, + }, + { + name: "version detects as linked", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + }, + }, + linkedRevisions: []Revision{NewRevision("16.3.0", 0)}, + inWindow: true, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + removedRevisions: []Revision{ + NewRevision("unknown-version", 0), + }, + setupCalls: 1, + restarted: true, + }, + { + name: "backup version removed on install", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + removedRevisions: []Revision{ + NewRevision("backup-version", 0), + NewRevision("unknown-version", 0), + }, + setupCalls: 1, + restarted: true, + }, + { + name: "backup version is linked", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + linkedRevisions: []Revision{NewRevision("backup-version", 0)}, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + removedRevisions: []Revision{ + NewRevision("unknown-version", 0), + }, + setupCalls: 1, + restarted: true, + }, + { + name: "backup version kept when no change", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("16.3.0", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + }, + { + name: "config does not exist", + }, + { + name: "FIPS and Enterprise flags", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + Backup: toPtr(NewRevision("backup-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS)), + }, + }, + inWindow: true, + flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS, + + installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + removedRevisions: []Revision{ + NewRevision("backup-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + NewRevision("unknown-version", 0), + }, + setupCalls: 1, + restarted: true, + }, + { + name: "invalid metadata", + cfg: &UpdateConfig{}, + errMatch: "invalid", + }, + { + name: "setup fails", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + setupErr: errors.New("setup error"), + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + removedRevisions: []Revision{ + NewRevision("backup-version", 0), + }, + reloadCalls: 1, + revertCalls: 1, + setupCalls: 1, + restarted: true, + errMatch: "setup error", + }, + { + name: "agpl requires base URL", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + agpl: true, + + reloadCalls: 0, + revertCalls: 0, + setupCalls: 0, + errMatch: "AGPL", + }, + { + name: "skip version", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + Skip: toPtr(NewRevision("16.3.0", 0)), + }, + }, + inWindow: true, + }, + { + name: "pinned version", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + BaseURL: "https://example.com", + Enabled: true, + Pinned: true, + }, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + inWindow: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var requestedGroup string + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestedGroup = r.URL.Query().Get("group") + config := webclient.PingResponse{ + AutoUpdate: webclient.AutoUpdateSettings{ + AgentVersion: "16.3.0", + AgentAutoUpdate: tt.inWindow, + }, + } + config.Edition = "community" + if tt.flags&autoupdate.FlagEnterprise != 0 { + config.Edition = "ent" + } + if tt.agpl { + config.Edition = "oss" + } + config.FIPS = tt.flags&autoupdate.FlagFIPS != 0 + err := json.NewEncoder(w).Encode(config) + require.NoError(t, err) + })) + t.Cleanup(server.Close) + + dir := t.TempDir() + ns := &Namespace{ + installDir: dir, + defaultPathDir: "ignored", + } + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) + + // Create config file only if provided in test case + if tt.cfg != nil { + tt.cfg.Spec.Proxy = strings.TrimPrefix(server.URL, "https://") + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + var ( + installedRevision Revision + installedBaseURL string + linkedRevision Revision + removedRevisions []Revision + revertFuncCalls int + setupCalls int + revertSetupCalls int + reloadCalls int + ) + updater.Installer = &testInstaller{ + FuncInstall: func(_ context.Context, rev Revision, baseURL string, force bool) error { + for _, r := range tt.linkedRevisions { + if r == rev { + require.False(t, force) + } + } + installedRevision = rev + installedBaseURL = baseURL + return tt.installErr + }, + FuncLink: func(_ context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) { + linkedRevision = rev + return func(_ context.Context) bool { + revertFuncCalls++ + return true + }, nil + }, + FuncList: func(_ context.Context) (revs []Revision, err error) { + return slices.Compact([]Revision{ + installedRevision, + tt.cfg.Status.Active, + NewRevision("unknown-version", 0), + }), nil + }, + FuncRemove: func(_ context.Context, rev Revision) error { + removedRevisions = append(removedRevisions, rev) + return nil + }, + FuncIsLinked: func(ctx context.Context, rev Revision, path string) (bool, error) { + for _, r := range tt.linkedRevisions { + if r == rev { + return true, nil + } + } + return false, nil + }, + } + updater.Process = &testProcess{ + FuncReload: func(_ context.Context) error { + reloadCalls++ + return tt.reloadErr + }, + FuncIsPresent: func(ctx context.Context) (bool, error) { + return true, nil + }, + FuncIsEnabled: func(ctx context.Context) (bool, error) { + return !tt.notEnabled, nil + }, + FuncIsActive: func(ctx context.Context) (bool, error) { + return !tt.notActive, nil + }, + } + var restarted bool + updater.ReexecSetup = func(_ context.Context, path string, reload bool) error { + restarted = reload + setupCalls++ + return tt.setupErr + } + updater.SetupNamespace = func(_ context.Context, path string) error { + revertSetupCalls++ + return nil + } + + ctx := context.Background() + err = updater.Update(ctx, tt.now) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.installedRevision, installedRevision) + require.Equal(t, tt.installedBaseURL, installedBaseURL) + require.Equal(t, tt.linkedRevision, linkedRevision) + require.Equal(t, tt.removedRevisions, removedRevisions) + require.Equal(t, tt.flags, installedRevision.Flags) + require.Equal(t, tt.requestGroup, requestedGroup) + require.Equal(t, tt.reloadCalls, reloadCalls) + require.Equal(t, tt.revertCalls, revertSetupCalls) + require.Equal(t, tt.revertCalls, revertFuncCalls) + require.Equal(t, tt.setupCalls, setupCalls) + require.Equal(t, tt.restarted, restarted) + + if tt.cfg == nil { + _, err := os.Stat(cfgPath) + require.Error(t, err) + return + } + + data, err := os.ReadFile(cfgPath) + require.NoError(t, err) + data = blankTestAddr(data) + + if golden.ShouldSet() { + golden.Set(t, data) + } + require.Equal(t, string(golden.Get(t)), string(data)) + }) + } +} + +func TestUpdater_LinkPackage(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + tryLinkSystemErr error + syncErr error + notPresent bool + + syncCalls int + tryLinkSystemCalls int + errMatch string + }{ + { + name: "updates enabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: true, + }, + }, + + tryLinkSystemCalls: 0, + syncCalls: 0, + }, + { + name: "pinned", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Pinned: true, + }, + }, + + tryLinkSystemCalls: 0, + syncCalls: 0, + }, + { + name: "updates disabled", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + + tryLinkSystemCalls: 1, + syncCalls: 1, + }, + { + name: "already linked", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemErr: ErrLinked, + + tryLinkSystemCalls: 1, + syncCalls: 0, + }, + { + name: "link error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemErr: errors.New("bad"), + + tryLinkSystemCalls: 1, + syncCalls: 0, + errMatch: "bad", + }, + { + name: "no config", + tryLinkSystemCalls: 1, + syncCalls: 1, + }, + { + name: "systemd is not installed", + tryLinkSystemCalls: 1, + syncCalls: 1, + syncErr: ErrNotSupported, + }, + { + name: "systemd is not installed, already linked", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemCalls: 1, + syncCalls: 1, + syncErr: ErrNotSupported, + }, + { + name: "SELinux blocks service from being read", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Enabled: false, + }, + }, + tryLinkSystemCalls: 1, + syncCalls: 1, + notPresent: true, + errMatch: "cannot find systemd service", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + ns := &Namespace{installDir: dir} + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) + + // Create config file only if provided in test case + if tt.cfg != nil { + b, err := yaml.Marshal(tt.cfg) + require.NoError(t, err) + err = os.WriteFile(cfgPath, b, 0600) + require.NoError(t, err) + } + + var tryLinkSystemCalls int + updater.Installer = &testInstaller{ + FuncTryLinkSystem: func(_ context.Context) error { + tryLinkSystemCalls++ + return tt.tryLinkSystemErr + }, + } + var syncCalls int + updater.Process = &testProcess{ + FuncSync: func(_ context.Context) error { + syncCalls++ + return tt.syncErr + }, + FuncIsPresent: func(ctx context.Context) (bool, error) { + return !tt.notPresent, nil + }, + } + + ctx := context.Background() + err = updater.LinkPackage(ctx) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.tryLinkSystemCalls, tryLinkSystemCalls) + require.Equal(t, tt.syncCalls, syncCalls) + }) + } +} + +func TestUpdater_Remove(t *testing.T) { + t.Parallel() + + const version = "active-version" + + tests := []struct { + name string + cfg *UpdateConfig // nil -> file not present + linkSystemErr error + isEnabledErr error + syncErr error + reloadErr error + processEnabled bool + force bool + + unlinkedVersion string + teardownCalls int + syncCalls int + revertFuncCalls int + linkSystemCalls int + reloadCalls int + errMatch string + }{ + { + name: "no config", + teardownCalls: 1, + }, + { + name: "no active version", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + }, + teardownCalls: 1, + }, + { + name: "no conflicting system links, process disabled, force", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + unlinkedVersion: version, + teardownCalls: 1, + force: true, + }, + { + name: "no system links, process enabled, force", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemErr: ErrNoBinaries, + linkSystemCalls: 1, + processEnabled: true, + force: true, + errMatch: "refusing to remove", + }, + { + name: "no system links, process disabled, force", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemErr: ErrNoBinaries, + linkSystemCalls: 1, + unlinkedVersion: version, + teardownCalls: 1, + force: true, + }, + { + name: "no system links, process disabled, no force", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemErr: ErrNoBinaries, + linkSystemCalls: 1, + errMatch: "unable to remove", + }, + { + name: "no system links, process disabled, no systemd, force", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemErr: ErrNoBinaries, + linkSystemCalls: 1, + isEnabledErr: ErrNotSupported, + unlinkedVersion: version, + teardownCalls: 1, + force: true, + }, + { + name: "active version", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemCalls: 1, + syncCalls: 1, + reloadCalls: 1, + teardownCalls: 1, + }, + { + name: "active version, no systemd", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemCalls: 1, + syncCalls: 1, + reloadCalls: 1, + teardownCalls: 1, + syncErr: ErrNotSupported, + reloadErr: ErrNotSupported, + }, + { + name: "active version, no reload", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemCalls: 1, + syncCalls: 1, + reloadCalls: 1, + teardownCalls: 1, + reloadErr: ErrNotNeeded, + }, + { + name: "active version, sync error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemCalls: 1, + syncCalls: 2, + revertFuncCalls: 1, + syncErr: errors.New("sync error"), + errMatch: "configuration", + }, + { + name: "active version, reload error", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Spec: UpdateSpec{ + Path: defaultPathDir, + }, + Status: UpdateStatus{ + Active: NewRevision(version, 0), + }, + }, + linkSystemCalls: 1, + syncCalls: 2, + reloadCalls: 2, + revertFuncCalls: 1, + reloadErr: errors.New("reload error"), + errMatch: "start", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + ns := &Namespace{installDir: dir} + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -91,47 +1239,92 @@ func TestUpdater_Disable(t *testing.T) { err = os.WriteFile(cfgPath, b, 0600) require.NoError(t, err) } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - err = updater.Disable(context.Background()) - if tt.errMatch != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tt.errMatch) - return - } - require.NoError(t, err) - - data, err := os.ReadFile(cfgPath) - // If no config is present, disable should not create it - if tt.cfg == nil { - require.ErrorIs(t, err, os.ErrNotExist) - return + var ( + linkSystemCalls int + revertFuncCalls int + syncCalls int + reloadCalls int + teardownCalls int + unlinkedVersion string + ) + updater.Installer = &testInstaller{ + FuncLinkSystem: func(_ context.Context) (revert func(context.Context) bool, err error) { + linkSystemCalls++ + return func(_ context.Context) bool { + revertFuncCalls++ + return true + }, tt.linkSystemErr + }, + FuncUnlink: func(_ context.Context, rev Revision, path string) error { + unlinkedVersion = rev.Version + return nil + }, + } + updater.Process = &testProcess{ + FuncSync: func(_ context.Context) error { + syncCalls++ + return tt.syncErr + }, + FuncReload: func(_ context.Context) error { + reloadCalls++ + return tt.reloadErr + }, + FuncIsEnabled: func(_ context.Context) (bool, error) { + return tt.processEnabled, tt.isEnabledErr + }, + FuncIsActive: func(_ context.Context) (bool, error) { + return false, nil + }, + } + updater.TeardownNamespace = func(_ context.Context) error { + teardownCalls++ + return nil } - require.NoError(t, err) - if golden.ShouldSet() { - golden.Set(t, data) + ctx := context.Background() + err = updater.Remove(ctx, tt.force) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) } - require.Equal(t, string(golden.Get(t)), string(data)) + require.Equal(t, tt.syncCalls, syncCalls) + require.Equal(t, tt.reloadCalls, reloadCalls) + require.Equal(t, tt.linkSystemCalls, linkSystemCalls) + require.Equal(t, tt.revertFuncCalls, revertFuncCalls) + require.Equal(t, tt.unlinkedVersion, unlinkedVersion) + require.Equal(t, tt.teardownCalls, teardownCalls) }) } } -func TestUpdater_Enable(t *testing.T) { +func TestUpdater_Install(t *testing.T) { t.Parallel() tests := []struct { name string cfg *UpdateConfig // nil -> file not present userCfg OverrideConfig + flags autoupdate.InstallFlags + agpl bool installErr error + setupErr error + reloadErr error + notPresent bool + notEnabled bool + notActive bool - installedVersion string - installedTemplate string + removedRevision Revision + installedRevision Revision + installedBaseURL string + linkedRevision Revision + requestGroup string + reloadCalls int + revertCalls int + setupCalls int + restarted bool errMatch string }{ { @@ -140,15 +1333,22 @@ func TestUpdater_Enable(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, Spec: UpdateSpec{ - Group: "group", - URLTemplate: "https://example.com", + Enabled: true, + Group: "group", + Path: "/path", + BaseURL: "https://example.com", }, Status: UpdateStatus{ - ActiveVersion: "old-version", + Active: NewRevision("old-version", 0), }, }, - installedVersion: "16.3.0", - installedTemplate: "https://example.com", + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: "https://example.com", + linkedRevision: NewRevision("16.3.0", 0), + requestGroup: "group", + setupCalls: 1, + restarted: true, }, { name: "config from user", @@ -156,35 +1356,62 @@ func TestUpdater_Enable(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, Spec: UpdateSpec{ - Group: "old-group", - URLTemplate: "https://example.com/old", + Group: "old-group", + BaseURL: "https://example.com/old", }, Status: UpdateStatus{ - ActiveVersion: "old-version", + Active: NewRevision("old-version", 0), }, }, userCfg: OverrideConfig{ - Group: "new-group", - URLTemplate: "https://example.com/new", + UpdateSpec: UpdateSpec{ + Enabled: true, + Path: "/path", + Group: "new-group", + BaseURL: "https://example.com/new", + }, ForceVersion: "new-version", }, - installedVersion: "new-version", - installedTemplate: "https://example.com/new", + + installedRevision: NewRevision("new-version", 0), + installedBaseURL: "https://example.com/new", + linkedRevision: NewRevision("new-version", 0), + requestGroup: "new-group", + setupCalls: 1, + restarted: true, }, { - name: "already enabled", + name: "defaults", cfg: &UpdateConfig{ Version: updateConfigVersion, Kind: updateConfigKind, - Spec: UpdateSpec{ - Enabled: true, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), }, + }, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: true, + }, + { + name: "override skip", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, Status: UpdateStatus{ - ActiveVersion: "old-version", + Active: NewRevision("old-version", 0), + Skip: toPtr(NewRevision("16.3.0", 0)), }, }, - installedVersion: "16.3.0", - installedTemplate: cdnURITemplate, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: true, }, { name: "insecure URL", @@ -192,22 +1419,32 @@ func TestUpdater_Enable(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, Spec: UpdateSpec{ - URLTemplate: "http://example.com", + BaseURL: "http://example.com", }, }, - errMatch: "URL must use TLS", + + errMatch: "must use TLS", }, { name: "install error", cfg: &UpdateConfig{ Version: updateConfigVersion, Kind: updateConfigKind, - Spec: UpdateSpec{ - URLTemplate: "https://example.com", - }, }, installErr: errors.New("install error"), - errMatch: "install error", + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + errMatch: "install error", + }, + { + name: "agpl requires base URL", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + }, + agpl: true, + errMatch: "AGPL", }, { name: "version already installed", @@ -215,29 +1452,143 @@ func TestUpdater_Enable(t *testing.T) { Version: updateConfigVersion, Kind: updateConfigKind, Status: UpdateStatus{ - ActiveVersion: "16.3.0", + Active: NewRevision("16.3.0", 0), + }, + }, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: false, + }, + { + name: "backup version removed on install", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Status: UpdateStatus{ + Active: NewRevision("old-version", 0), + Backup: toPtr(NewRevision("backup-version", 0)), }, }, - installedVersion: "16.3.0", - installedTemplate: cdnURITemplate, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + removedRevision: NewRevision("backup-version", 0), + setupCalls: 1, + restarted: true, + }, + { + name: "backup version kept for validation", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Status: UpdateStatus{ + Active: NewRevision("16.3.0", 0), + Backup: toPtr(NewRevision("backup-version", 0)), + }, + }, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + }, + { + name: "config does not exist", + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: true, }, { - name: "config does not exist", - installedVersion: "16.3.0", - installedTemplate: cdnURITemplate, + name: "FIPS and Enterprise flags", + flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS, + installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS), + setupCalls: 1, + restarted: true, }, { name: "invalid metadata", cfg: &UpdateConfig{}, errMatch: "invalid", }, + { + name: "setup fails", + setupErr: errors.New("setup error"), + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + revertCalls: 1, + setupCalls: 1, + reloadCalls: 1, + restarted: true, + errMatch: "setup error", + }, + { + name: "setup fails already installed", + cfg: &UpdateConfig{ + Version: updateConfigVersion, + Kind: updateConfigKind, + Status: UpdateStatus{ + Active: NewRevision("16.3.0", 0), + }, + }, + setupErr: errors.New("setup error"), + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + revertCalls: 1, + setupCalls: 1, + errMatch: "setup error", + }, + { + name: "no need to reload", + reloadErr: ErrNotNeeded, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: true, + }, + { + name: "not started or enabled", + notEnabled: true, + notActive: true, + + installedRevision: NewRevision("16.3.0", 0), + installedBaseURL: autoupdate.DefaultBaseURL, + linkedRevision: NewRevision("16.3.0", 0), + setupCalls: 1, + restarted: true, + }, } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { dir := t.TempDir() - cfgPath := filepath.Join(dir, "update.yaml") + ns := &Namespace{ + installDir: dir, + defaultPathDir: defaultPathDir, + defaultProxyAddr: "default-proxy", + } + _, err := ns.Init() + require.NoError(t, err) + cfgPath := filepath.Join(ns.Dir(), updateConfigName) + + updater, err := NewLocalUpdater(LocalUpdaterConfig{ + InsecureSkipVerify: true, + }, ns) + require.NoError(t, err) // Create config file only if provided in test case if tt.cfg != nil { @@ -247,9 +1598,24 @@ func TestUpdater_Enable(t *testing.T) { require.NoError(t, err) } + var requestedGroup string server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // TODO(sclevine): add web API test including group verification - w.Write([]byte(`{}`)) + requestedGroup = r.URL.Query().Get("group") + config := webclient.PingResponse{ + AutoUpdate: webclient.AutoUpdateSettings{ + AgentVersion: "16.3.0", + }, + } + config.Edition = "community" + if tt.flags&autoupdate.FlagEnterprise != 0 { + config.Edition = "ent" + } + if tt.agpl { + config.Edition = "oss" + } + config.FIPS = tt.flags&autoupdate.FlagFIPS != 0 + err := json.NewEncoder(w).Encode(config) + require.NoError(t, err) })) t.Cleanup(server.Close) @@ -257,31 +1623,91 @@ func TestUpdater_Enable(t *testing.T) { tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://") } - updater, err := NewLocalUpdater(LocalUpdaterConfig{ - InsecureSkipVerify: true, - VersionsDir: dir, - }) - require.NoError(t, err) - - var installedVersion, installedTemplate string + var ( + installedRevision Revision + installedBaseURL string + linkedRevision Revision + removedRevision Revision + revertFuncCalls int + reloadCalls int + setupCalls int + revertSetupCalls int + ) updater.Installer = &testInstaller{ - FuncInstall: func(_ context.Context, version, template string, _ InstallFlags) error { - installedVersion = version - installedTemplate = template + FuncInstall: func(_ context.Context, rev Revision, baseURL string, force bool) error { + installedRevision = rev + installedBaseURL = baseURL return tt.installErr }, + FuncLink: func(_ context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) { + linkedRevision = rev + return func(_ context.Context) bool { + revertFuncCalls++ + return true + }, nil + }, + FuncList: func(_ context.Context) (revs []Revision, err error) { + return []Revision{}, nil + }, + FuncRemove: func(_ context.Context, rev Revision) error { + removedRevision = rev + return nil + }, + FuncIsLinked: func(ctx context.Context, rev Revision, path string) (bool, error) { + return false, nil + }, + } + updater.Process = &testProcess{ + FuncReload: func(_ context.Context) error { + reloadCalls++ + return tt.reloadErr + }, + FuncIsPresent: func(ctx context.Context) (bool, error) { + return !tt.notPresent, nil + }, + FuncIsEnabled: func(ctx context.Context) (bool, error) { + return !tt.notEnabled, nil + }, + FuncIsActive: func(ctx context.Context) (bool, error) { + return !tt.notActive, nil + }, + } + var restarted bool + updater.ReexecSetup = func(_ context.Context, path string, reload bool) error { + setupCalls++ + restarted = reload + return tt.setupErr + } + updater.SetupNamespace = func(_ context.Context, path string) error { + revertSetupCalls++ + return nil } ctx := context.Background() - err = updater.Enable(ctx, tt.userCfg) + err = updater.Install(ctx, tt.userCfg) if tt.errMatch != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.installedRevision, installedRevision) + require.Equal(t, tt.installedBaseURL, installedBaseURL) + require.Equal(t, tt.linkedRevision, linkedRevision) + require.Equal(t, tt.removedRevision, removedRevision) + require.Equal(t, tt.flags, installedRevision.Flags) + require.Equal(t, tt.requestGroup, requestedGroup) + require.Equal(t, tt.reloadCalls, reloadCalls) + require.Equal(t, tt.revertCalls, revertSetupCalls) + require.Equal(t, tt.revertCalls, revertFuncCalls) + require.Equal(t, tt.setupCalls, setupCalls) + require.Equal(t, tt.restarted, restarted) + + if tt.cfg == nil && err != nil { + _, err := os.Stat(cfgPath) + require.Error(t, err) return } - require.NoError(t, err) - require.Equal(t, tt.installedVersion, installedVersion) - require.Equal(t, tt.installedTemplate, installedTemplate) data, err := os.ReadFile(cfgPath) require.NoError(t, err) @@ -295,6 +1721,170 @@ func TestUpdater_Enable(t *testing.T) { } } +func TestSameProxies(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + a, b string + match bool + }{ + { + name: "protocol missing with port", + a: "https://example.com:8080", + b: "example.com:8080", + match: true, + }, + { + name: "protocol missing without port", + a: "https://example.com", + b: "example.com", + match: true, + }, + { + name: "same with port", + a: "example.com:443", + b: "example.com:443", + match: true, + }, + { + name: "does not set default teleport port", + a: "example.com", + b: "example.com:3080", + match: false, + }, + { + name: "does set default standard port", + a: "example.com", + b: "example.com:443", + match: true, + }, + { + name: "other formats if equal", + a: "@123", + b: "@123", + match: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := sameProxies(tt.a, tt.b) + require.Equal(t, tt.match, s) + }) + } +} + +func TestUpdater_Setup(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + restart bool + present bool + setupErr error + presentErr error + reloadErr error + + errMatch string + }{ + { + name: "no restart", + restart: false, + present: true, + }, + { + name: "restart", + restart: true, + present: true, + }, + { + name: "reload not needed", + restart: true, + present: true, + reloadErr: ErrNotNeeded, + }, + { + name: "not present", + restart: true, + present: false, + errMatch: "cannot find systemd", + }, + { + name: "setup error", + restart: false, + setupErr: errors.New("some error"), + errMatch: "some error", + }, + { + name: "setup error canceled", + restart: false, + setupErr: context.Canceled, + errMatch: "canceled", + }, + { + name: "present error", + restart: false, + presentErr: errors.New("some error"), + errMatch: "some error", + }, + { + name: "present error canceled", + restart: false, + presentErr: context.Canceled, + errMatch: "canceled", + }, + { + name: "preset error not supported", + restart: false, + presentErr: ErrNotSupported, + }, + { + name: "reload error canceled", + restart: true, + present: true, + reloadErr: context.Canceled, + errMatch: "canceled", + }, + { + name: "reload error", + restart: true, + present: true, + reloadErr: errors.New("some error"), + errMatch: "some error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ns := &Namespace{} + updater, err := NewLocalUpdater(LocalUpdaterConfig{}, ns) + require.NoError(t, err) + + updater.Process = &testProcess{ + FuncReload: func(_ context.Context) error { + return tt.reloadErr + }, + FuncIsPresent: func(_ context.Context) (bool, error) { + return tt.present, tt.presentErr + }, + } + updater.SetupNamespace = func(_ context.Context, path string) error { + require.Equal(t, "test", path) + return tt.setupErr + } + + ctx := context.Background() + err = updater.Setup(ctx, "test", tt.restart) + if tt.errMatch != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMatch) + } else { + require.NoError(t, err) + } + }) + } +} + var serverRegexp = regexp.MustCompile("127.0.0.1:[0-9]+") func blankTestAddr(s []byte) []byte { @@ -302,14 +1892,82 @@ func blankTestAddr(s []byte) []byte { } type testInstaller struct { - FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error - FuncRemove func(ctx context.Context, version string) error + FuncInstall func(ctx context.Context, rev Revision, baseURL string, force bool) error + FuncRemove func(ctx context.Context, rev Revision) error + FuncLink func(ctx context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) + FuncLinkSystem func(ctx context.Context) (revert func(context.Context) bool, err error) + FuncTryLink func(ctx context.Context, rev Revision, path string) error + FuncTryLinkSystem func(ctx context.Context) error + FuncUnlink func(ctx context.Context, rev Revision, path string) error + FuncUnlinkSystem func(ctx context.Context) error + FuncList func(ctx context.Context) (revs []Revision, err error) + FuncIsLinked func(ctx context.Context, rev Revision, path string) (bool, error) +} + +func (ti *testInstaller) Install(ctx context.Context, rev Revision, baseURL string, force bool) error { + return ti.FuncInstall(ctx, rev, baseURL, force) +} + +func (ti *testInstaller) Remove(ctx context.Context, rev Revision) error { + return ti.FuncRemove(ctx, rev) +} + +func (ti *testInstaller) Link(ctx context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) { + return ti.FuncLink(ctx, rev, path, force) +} + +func (ti *testInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) { + return ti.FuncLinkSystem(ctx) +} + +func (ti *testInstaller) TryLink(ctx context.Context, rev Revision, path string) error { + return ti.FuncTryLink(ctx, rev, path) +} + +func (ti *testInstaller) TryLinkSystem(ctx context.Context) error { + return ti.FuncTryLinkSystem(ctx) +} + +func (ti *testInstaller) Unlink(ctx context.Context, rev Revision, path string) error { + return ti.FuncUnlink(ctx, rev, path) +} + +func (ti *testInstaller) UnlinkSystem(ctx context.Context) error { + return ti.FuncUnlinkSystem(ctx) +} + +func (ti *testInstaller) List(ctx context.Context) (revs []Revision, err error) { + return ti.FuncList(ctx) +} + +func (ti *testInstaller) IsLinked(ctx context.Context, rev Revision, path string) (bool, error) { + return ti.FuncIsLinked(ctx, rev, path) +} + +type testProcess struct { + FuncReload func(ctx context.Context) error + FuncSync func(ctx context.Context) error + FuncIsEnabled func(ctx context.Context) (bool, error) + FuncIsActive func(ctx context.Context) (bool, error) + FuncIsPresent func(ctx context.Context) (bool, error) +} + +func (tp *testProcess) Reload(ctx context.Context) error { + return tp.FuncReload(ctx) +} + +func (tp *testProcess) Sync(ctx context.Context) error { + return tp.FuncSync(ctx) +} + +func (tp *testProcess) IsEnabled(ctx context.Context) (bool, error) { + return tp.FuncIsEnabled(ctx) } -func (ti *testInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error { - return ti.FuncInstall(ctx, version, template, flags) +func (tp *testProcess) IsActive(ctx context.Context) (bool, error) { + return tp.FuncIsActive(ctx) } -func (ti *testInstaller) Remove(ctx context.Context, version string) error { - return ti.FuncRemove(ctx, version) +func (tp *testProcess) IsPresent(ctx context.Context) (bool, error) { + return tp.FuncIsPresent(ctx) } diff --git a/lib/autoupdate/agent/validate.go b/lib/autoupdate/agent/validate.go new file mode 100644 index 0000000000000..8ce1732d3d5d1 --- /dev/null +++ b/lib/autoupdate/agent/validate.go @@ -0,0 +1,130 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "io" + "log/slog" + "os" + "path/filepath" + "unicode" + + "github.com/gravitational/trace" +) + +const ( + // fileHeaderSniffBytes is the max size to read to determine a file's MIME type + fileHeaderSniffBytes = 512 // MIME standard size + // execModeMask is the minimum required set of bits to consider a file executable. + execModeMask = 0111 +) + +// Validator validates filesystem paths. +type Validator struct { + Log *slog.Logger +} + +// IsBinary returns true for working binaries that are executable by all users. +// If the file is irregular, non-executable, or a shell script, IsBinary returns false and logs a warning. +// IsBinary errors if lstat fails, a regular file is unreadable, or an executable file does not execute. +func (v *Validator) IsBinary(ctx context.Context, path string) (bool, error) { + // The behavior of this method is intended to protect against executable files + // being adding to the Teleport tgz that should not be made available on PATH, + // and additionally, should not cause installation to fail. + // While known copies of these files (e.g., "install") are excluded during extraction, + // it is safer to assume others could be present in past or future tgzs. + + if exec, err := v.IsExecutable(ctx, path); err != nil || !exec { + return exec, trace.Wrap(err) + } + name := filepath.Base(path) + d, err := readFileLimit(path, fileHeaderSniffBytes) + if err != nil { + return false, trace.Wrap(err) + } + // Refuse to test or link shell scripts + if isTextScript(d) { + v.Log.WarnContext(ctx, "Found unexpected shell script", "name", name) + return false, nil + } + v.Log.InfoContext(ctx, "Validating binary", "name", name) + r := localExec{ + Log: v.Log, + ErrLevel: slog.LevelDebug, + OutLevel: slog.LevelInfo, // always show version + } + code, err := r.Run(ctx, path, "version") + if code < 0 { + return false, trace.Wrap(err, "error validating binary %s", name) + } + if code > 0 { + v.Log.InfoContext(ctx, "Binary does not support version command", "name", name) + } + return true, nil +} + +// IsExecutable returns true for regular, executable files. +func (v *Validator) IsExecutable(ctx context.Context, path string) (bool, error) { + name := filepath.Base(path) + fi, err := os.Lstat(path) + if err != nil { + return false, trace.Wrap(err) + } + if !fi.Mode().IsRegular() { + v.Log.WarnContext(ctx, "Found unexpected irregular file", "name", name) + return false, nil + } + if fi.Mode()&execModeMask != execModeMask { + v.Log.WarnContext(ctx, "Found unexpected non-executable file", "name", name) + return false, nil + } + return true, nil +} + +func isTextScript(data []byte) bool { + data = bytes.TrimLeftFunc(data, unicode.IsSpace) + if !bytes.HasPrefix(data, []byte("#!")) { + return false + } + // Assume presence of MIME binary data bytes means binary: + // https://mimesniff.spec.whatwg.org/#terminology + for _, b := range data { + switch { + case b <= 0x08, b == 0x0B, + 0x0E <= b && b <= 0x1A, + 0x1C <= b && b <= 0x1F: + return false + } + } + return true +} + +// readFileLimit the first n bytes of a file. +func readFileLimit(name string, n int64) ([]byte, error) { + f, err := os.Open(name) + if err != nil { + return nil, err + } + defer f.Close() + var buf bytes.Buffer + _, err = io.Copy(&buf, io.LimitReader(f, n)) + return buf.Bytes(), trace.Wrap(err) +} diff --git a/lib/autoupdate/agent/validate_test.go b/lib/autoupdate/agent/validate_test.go new file mode 100644 index 0000000000000..71b979016f8fc --- /dev/null +++ b/lib/autoupdate/agent/validate_test.go @@ -0,0 +1,102 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "log/slog" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidator_IsBinary(t *testing.T) { + for _, tt := range []struct { + name string + mode os.FileMode + contents string + + valid bool + errMatch string + logMatch string + }{ + { + name: "missing", + errMatch: "no such", + }, + { + name: "non-executable", + contents: "test", + mode: 0666, + logMatch: "non-executable", + }, + { + name: "shell script", + contents: " #!bash ", + mode: 0777, + logMatch: "unexpected shell", + }, + { + name: "unqualified shell script", + contents: " #!bash" + string([]byte{0x0B}), + mode: 0777, + errMatch: "validating binary", + }, + { + name: "exit 0", + contents: "#!/bin/sh\nexit 0\n" + string([]byte{0x0B}), + mode: 0777, + valid: true, + }, + { + name: "exit 1", + contents: "#!/bin/sh\nexit 1\n" + string([]byte{0x0B}), + mode: 0777, + valid: true, + logMatch: "version command", + }, + } { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + opts := &slog.HandlerOptions{AddSource: true} + log := slog.New(slog.NewTextHandler(&buf, opts)) + v := Validator{Log: log} + ctx := context.Background() + path := filepath.Join(t.TempDir(), "file") + if tt.contents != "" { + os.WriteFile(path, []byte(tt.contents), tt.mode) + } + val, err := v.IsBinary(ctx, path) + if tt.logMatch != "" { + require.Contains(t, buf.String(), tt.logMatch) + } + if tt.errMatch != "" { + require.Error(t, err) + require.False(t, val) + require.Contains(t, err.Error(), tt.errMatch) + return + } + require.Equal(t, tt.valid, val) + require.NoError(t, err) + }) + } +} diff --git a/lib/autoupdate/package_url.go b/lib/autoupdate/package_url.go index 9b283c3da59c2..b00eb59fea5c9 100644 --- a/lib/autoupdate/package_url.go +++ b/lib/autoupdate/package_url.go @@ -20,10 +20,12 @@ package autoupdate import ( "bytes" + "encoding/json" "runtime" "text/template" "github.com/gravitational/trace" + "gopkg.in/yaml.v3" ) // InstallFlags sets flags for the Teleport installation. @@ -54,6 +56,82 @@ const ( BaseURLEnvVar = "TELEPORT_CDN_BASE_URL" ) +// NewInstallFlagsFromStrings returns InstallFlags given a slice of human-readable strings. +func NewInstallFlagsFromStrings(s []string) InstallFlags { + var out InstallFlags + for _, f := range s { + for _, flag := range []InstallFlags{ + FlagEnterprise, + FlagFIPS, + } { + if f == flag.String() { + out |= flag + } + } + } + return out +} + +// Strings converts InstallFlags to a slice of human-readable strings. +func (i InstallFlags) Strings() []string { + var out []string + for _, flag := range []InstallFlags{ + FlagEnterprise, + FlagFIPS, + } { + if i&flag != 0 { + out = append(out, flag.String()) + } + } + return out +} + +// String returns the string representation of a single InstallFlag flag, or "Unknown". +func (i InstallFlags) String() string { + switch i { + case 0: + return "" + case FlagEnterprise: + return "Enterprise" + case FlagFIPS: + return "FIPS" + } + return "Unknown" +} + +// DirFlag returns the directory path representation of a single InstallFlag flag, or "unknown". +func (i InstallFlags) DirFlag() string { + switch i { + case 0: + return "" + case FlagEnterprise: + return "ent" + case FlagFIPS: + return "fips" + } + return "unknown" +} + +func (i InstallFlags) MarshalYAML() (any, error) { + return i.Strings(), nil +} + +func (i InstallFlags) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Strings()) +} + +func (i *InstallFlags) UnmarshalYAML(n *yaml.Node) error { + var s []string + if err := n.Decode(&s); err != nil { + return trace.Wrap(err) + } + if i == nil { + return trace.BadParameter("nil install flags while parsing YAML") + } + *i = NewInstallFlagsFromStrings(s) + return nil +} + // MakeURL constructs the package download URL from template, base URL and revision. func MakeURL(uriTmpl string, baseURL string, pkg string, version string, flags InstallFlags) (string, error) { tmpl, err := template.New("uri").Parse(uriTmpl) diff --git a/lib/autoupdate/package_url_test.go b/lib/autoupdate/package_url_test.go new file mode 100644 index 0000000000000..b3eca4be38d7e --- /dev/null +++ b/lib/autoupdate/package_url_test.go @@ -0,0 +1,95 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package autoupdate + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestInstallFlagsYAML(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + yaml string + flags InstallFlags + skipYAML bool + }{ + { + name: "both", + yaml: `["Enterprise", "FIPS"]`, + flags: FlagEnterprise | FlagFIPS, + }, + { + name: "order", + yaml: `["FIPS", "Enterprise"]`, + flags: FlagEnterprise | FlagFIPS, + skipYAML: true, + }, + { + name: "extra", + yaml: `["FIPS", "Enterprise", "bad"]`, + flags: FlagEnterprise | FlagFIPS, + skipYAML: true, + }, + { + name: "enterprise", + yaml: `["Enterprise"]`, + flags: FlagEnterprise, + }, + { + name: "fips", + yaml: `["FIPS"]`, + flags: FlagFIPS, + }, + { + name: "empty", + yaml: `[]`, + }, + { + name: "nil", + skipYAML: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + var flags InstallFlags + err := yaml.Unmarshal([]byte(tt.yaml), &flags) + require.NoError(t, err) + require.Equal(t, tt.flags, flags) + + // verify test YAML + var v any + err = yaml.Unmarshal([]byte(tt.yaml), &v) + require.NoError(t, err) + res, err := yaml.Marshal(v) + require.NoError(t, err) + + // compare verified YAML to flag output + out, err := yaml.Marshal(flags) + require.NoError(t, err) + + if !tt.skipYAML { + require.Equal(t, string(res), string(out)) + } + }) + } +} diff --git a/lib/client/debug/debug.go b/lib/client/debug/debug.go index 5553af7b8187c..368a4eadc3ab1 100644 --- a/lib/client/debug/debug.go +++ b/lib/client/debug/debug.go @@ -19,6 +19,7 @@ package debug import ( "bytes" "context" + "encoding/json" "io" "net" "net/http" @@ -55,9 +56,14 @@ func NewClient(socketPath string) *Client { clt: &http.Client{ Timeout: apidefaults.DefaultIOTimeout, Transport: &http.Transport{ - DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { - return net.Dial("unix", socketPath) + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socketPath) }, + DisableKeepAlives: true, + }, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return trace.Errorf("redirect via socket not allowed") }, }, } @@ -137,6 +143,36 @@ func (c *Client) CollectProfile(ctx context.Context, profileName string, seconds return result, nil } +// Readiness describes the readiness of the Teleport instance. +type Readiness struct { + // Ready is true if the instance is ready. + // This field is only set by clients, based on status. + Ready bool `json:"-"` + // Status provides more detail about the readiness status. + Status string `json:"status"` + // PID is the process PID + PID int `json:"pid"` +} + +// GetReadiness returns true if the Teleport service is ready. +func (c *Client) GetReadiness(ctx context.Context) (Readiness, error) { + var ready Readiness + resp, err := c.do(ctx, http.MethodGet, url.URL{Path: "/readyz"}, nil) + if err != nil { + return ready, trace.Wrap(err) + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusNotFound { + return ready, trace.NotFound("readiness endpoint not found") + } + ready.Ready = resp.StatusCode == http.StatusOK + err = json.NewDecoder(resp.Body).Decode(&ready) + if err != nil { + return ready, trace.Wrap(err) + } + return ready, nil +} + func (c *Client) do(ctx context.Context, method string, u url.URL, body []byte) (*http.Response, error) { u.Scheme = "http" u.Host = "debug" diff --git a/lib/client/debug/debug_test.go b/lib/client/debug/debug_test.go index 0c9872baa6724..fd1dbe663b31a 100644 --- a/lib/client/debug/debug_test.go +++ b/lib/client/debug/debug_test.go @@ -56,6 +56,69 @@ func TestSetLogLevel(t *testing.T) { }) } +func TestGetReadiness(t *testing.T) { + ctx := context.Background() + + t.Run("Success", func(t *testing.T) { + socketPath, _ := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`)) + clt := NewClient(socketPath) + + out, err := clt.GetReadiness(ctx) + require.NoError(t, err) + require.Equal(t, "OK", out.Status) + require.True(t, out.Ready) + require.Equal(t, 1234, out.PID) + }) + + t.Run("Failure", func(t *testing.T) { + socketPath, _ := newSocketMockService(t, http.StatusBadRequest, []byte(`{"status": "BAD", "pid": 1234}`)) + clt := NewClient(socketPath) + + out, err := clt.GetReadiness(ctx) + require.NoError(t, err) + require.Equal(t, "BAD", out.Status) + require.False(t, out.Ready) + require.Equal(t, 1234, out.PID) + }) + + t.Run("Not found", func(t *testing.T) { + socketPath, _ := newSocketMockService(t, http.StatusNotFound, []byte(`404`)) + clt := NewClient(socketPath) + + out, err := clt.GetReadiness(ctx) + require.True(t, trace.IsNotFound(err)) + require.Equal(t, "", out.Status) + require.False(t, out.Ready) + require.Equal(t, 0, out.PID) + }) + + t.Run("Closed", func(t *testing.T) { + socketPath, closeFn := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`)) + closeFn() + clt := NewClient(socketPath) + + out, err := clt.GetReadiness(ctx) + var netError net.Error + require.ErrorAs(t, err, &netError) + require.Equal(t, "", out.Status) + require.False(t, out.Ready) + require.Equal(t, 0, out.PID) + }) + + t.Run("Missing", func(t *testing.T) { + socketPath, _ := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`)) + err := os.RemoveAll(socketPath) + require.NoError(t, err) + clt := NewClient(socketPath) + + out, err := clt.GetReadiness(ctx) + require.ErrorIs(t, err, os.ErrNotExist) + require.Equal(t, "", out.Status) + require.False(t, out.Ready) + require.Equal(t, 0, out.PID) + }) +} + func TestCollectProfile(t *testing.T) { ctx := context.Background() @@ -144,6 +207,7 @@ func newSocketMockService(t *testing.T, status int, contents []byte) (string, fu t.Cleanup(func() { srv.Shutdown(context.Background()) }) return socketPath, func() []string { + srv.Shutdown(context.Background()) return requests } } diff --git a/lib/service/service.go b/lib/service/service.go index a7b76ae337c90..a1de3ec15b0dd 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -96,6 +96,7 @@ import ( "github.com/gravitational/teleport/lib/auth/storage" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/automaticupgrades" + autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent" "github.com/gravitational/teleport/lib/autoupdate/rollout" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/dynamo" @@ -1246,18 +1247,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { return nil, trace.Wrap(err) } - upgraderKind := os.Getenv(automaticupgrades.EnvUpgrader) - upgraderVersion := automaticupgrades.GetUpgraderVersion(process.GracefulExitContext()) - if upgraderVersion == "" { - upgraderKind = "" - } - - // Instances deployed using the AWS OIDC integration are automatically updated - // by the proxy. The instance heartbeat should properly reflect that. - externalUpgrader := upgraderKind - if externalUpgrader == "" && os.Getenv(types.InstallMethodAWSOIDCDeployServiceEnvVar) == "true" { - externalUpgrader = types.OriginIntegrationAWSOIDC - } + upgraderKind, externalUpgrader, upgraderVersion := process.detectUpgrader() // note: we must create the inventory handle *after* registerExpectedServices because that function determines // the list of services (instance roles) to be included in the heartbeat. @@ -1286,7 +1276,26 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { process.logger.WarnContext(process.ExitContext(), "Use of external upgraders on control-plane instances is not recommended.") } - if upgraderKind == "unit" { + switch upgraderKind { + case types.UpgraderKindTeleportUpdate: + isDefault, err := autoupdate.IsManagedAndDefault() + if err != nil { + return nil, trace.Wrap(err) + } + if !isDefault { + // Only write the nop schedule for the default updater. + // Suffixed installations of Teleport can coexist with the old upgrader system. + break + } + driver, err := uw.NewSystemdUnitDriver(uw.SystemdUnitDriverConfig{}) + if err != nil { + return nil, trace.Wrap(err) + } + if err := driver.ForceNop(process.ExitContext()); err != nil { + process.logger.WarnContext(process.ExitContext(), "Unable to disable the teleport-upgrade command provided by the deprecated teleport-ent-updater package.", "error", err) + process.logger.WarnContext(process.ExitContext(), "If the deprecated teleport-ent-updater package is installed, please ensure /etc/teleport-upgrade.d/schedule contains 'nop'.") + } + case types.UpgraderKindSystemdUnit: process.RegisterFunc("autoupdates.endpoint.export", func() error { conn, err := waitForInstanceConnector(process, process.logger) if err != nil { @@ -1314,28 +1323,14 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { process.logger.InfoContext(process.ExitContext(), "Exported autoupdates endpoint.", "addr", resolverAddr.String()) return nil }) + if err := process.configureUpgraderExporter(upgraderKind); err != nil { + return nil, trace.Wrap(err) + } + default: + if err := process.configureUpgraderExporter(upgraderKind); err != nil { + return nil, trace.Wrap(err) + } } - - driver, err := uw.NewDriver(upgraderKind) - if err != nil { - return nil, trace.Wrap(err) - } - - exporter, err := uw.NewExporter(uw.ExporterConfig[inventory.DownstreamSender]{ - Driver: driver, - ExportFunc: process.exportUpgradeWindows, - AuthConnectivitySentinel: process.inventoryHandle.Sender(), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - process.RegisterCriticalFunc("upgradeewindow.export", exporter.Run) - process.OnExit("upgradewindow.export.stop", func(_ interface{}) { - exporter.Close() - }) - - process.logger.InfoContext(process.ExitContext(), "Configured upgrade window exporter for external upgrader.", "kind", upgraderKind) } serviceStarted := false @@ -1549,6 +1544,63 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { return process, nil } +// detectUpgrader returns metadata about auto-upgraders that may be active. +// Note that kind and externalName are usually the same. +// However, some unregistered upgraders like the AWS ODIC upgrader are not valid kinds. +// For these upgraders, kind is empty and externalName is set to a non-kind value. +func (process *TeleportProcess) detectUpgrader() (kind, externalName, version string) { + // Check if the deprecated teleport-upgrader script is being used. + kind = os.Getenv(automaticupgrades.EnvUpgrader) + version = automaticupgrades.GetUpgraderVersion(process.GracefulExitContext()) + if version == "" { + kind = "" + } + + // If the installation is managed by teleport-update, it supersedes the teleport-upgrader script. + ok, err := autoupdate.IsManagedByUpdater() + if err != nil { + process.logger.WarnContext(process.ExitContext(), "Failed to determine if auto-updates are enabled.", "error", err) + } else if ok { + // If this is a teleport-update managed installation, the version + // managed by the timer will always match the installed version of teleport. + kind = types.UpgraderKindTeleportUpdate + version = "v" + teleport.Version + } + + // Instances deployed using the AWS OIDC integration are automatically updated + // by the proxy. The instance heartbeat should properly reflect that. + externalName = kind + if externalName == "" && os.Getenv(types.InstallMethodAWSOIDCDeployServiceEnvVar) == "true" { + externalName = types.OriginIntegrationAWSOIDC + } + return kind, externalName, version +} + +// configureUpgraderExporter configures the window exporter for upgraders that export windows. +func (process *TeleportProcess) configureUpgraderExporter(kind string) error { + driver, err := uw.NewDriver(kind) + if err != nil { + return trace.Wrap(err) + } + + exporter, err := uw.NewExporter(uw.ExporterConfig[inventory.DownstreamSender]{ + Driver: driver, + ExportFunc: process.exportUpgradeWindows, + AuthConnectivitySentinel: process.inventoryHandle.Sender(), + }) + if err != nil { + return trace.Wrap(err) + } + + process.RegisterCriticalFunc("upgradeewindow.export", exporter.Run) + process.OnExit("upgradewindow.export.stop", func(_ interface{}) { + exporter.Close() + }) + + process.logger.InfoContext(process.ExitContext(), "Configured upgrade window exporter for external upgrader.", "kind", kind) + return nil +} + // enterpriseServicesEnabled will return true if any enterprise services are enabled. func (process *TeleportProcess) enterpriseServicesEnabled() bool { return modules.GetModules().BuildType() == modules.BuildEnterprise && diff --git a/lib/service/state.go b/lib/service/state.go index 3c057148683c6..fab3b29a23bc4 100644 --- a/lib/service/state.go +++ b/lib/service/state.go @@ -21,6 +21,7 @@ package service import ( "fmt" "net/http" + "os" "sync" "time" @@ -29,6 +30,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/client/debug" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" ) @@ -179,22 +181,26 @@ func (f *processState) readinessHandler() http.HandlerFunc { switch f.getState() { // 503 case stateDegraded: - roundtrip.ReplyJSON(w, http.StatusServiceUnavailable, map[string]any{ - "status": "teleport is in a degraded state, check logs for details", + roundtrip.ReplyJSON(w, http.StatusServiceUnavailable, debug.Readiness{ + Status: "teleport is in a degraded state, check logs for details", + PID: os.Getpid(), }) // 400 case stateRecovering: - roundtrip.ReplyJSON(w, http.StatusBadRequest, map[string]any{ - "status": "teleport is recovering from a degraded state, check logs for details", + roundtrip.ReplyJSON(w, http.StatusBadRequest, debug.Readiness{ + Status: "teleport is recovering from a degraded state, check logs for details", + PID: os.Getpid(), }) case stateStarting: - roundtrip.ReplyJSON(w, http.StatusBadRequest, map[string]any{ - "status": "teleport is starting and hasn't joined the cluster yet", + roundtrip.ReplyJSON(w, http.StatusBadRequest, debug.Readiness{ + Status: "teleport is starting and hasn't joined the cluster yet", + PID: os.Getpid(), }) // 200 case stateOK: - roundtrip.ReplyJSON(w, http.StatusOK, map[string]any{ - "status": "ok", + roundtrip.ReplyJSON(w, http.StatusOK, debug.Readiness{ + Status: "ok", + PID: os.Getpid(), }) } } diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go index a32cff6ef76b9..14b213f08a173 100644 --- a/lib/utils/unpack.go +++ b/lib/utils/unpack.go @@ -23,6 +23,7 @@ import ( "errors" "io" "os" + "path" "path/filepath" "strings" @@ -36,7 +37,10 @@ import ( // resulting files and directories are created using the current user context. // Extract will only unarchive files into dir, and will fail if the tarball // tries to write files outside of dir. -func Extract(r io.Reader, dir string) error { +// +// If any paths are specified, only the specified paths are extracted. +// The destination specified in the first matching path is selected. +func Extract(r io.Reader, dir string, paths ...ExtractPath) error { tarball := tar.NewReader(r) for { @@ -46,32 +50,95 @@ func Extract(r io.Reader, dir string) error { } else if err != nil { return trace.Wrap(err) } - + dirMode, ok := filterHeader(header, paths) + if !ok { + continue + } err = sanitizeTarPath(header, dir) if err != nil { return trace.Wrap(err) } - if err := extractFile(tarball, header, dir); err != nil { + if err := extractFile(tarball, header, dir, dirMode); err != nil { return trace.Wrap(err) } } return nil } +// ExtractPath specifies a path to be extracted. +type ExtractPath struct { + // Src path and Dst path within the archive to extract files to. + // Directories in the Src path are not included in the extraction dir. + // For example, given foo/bar/file.txt with Src=foo/bar Dst=baz, baz/file.txt results. + // Trailing slashes are always ignored. + Src, Dst string + // Skip extracting the Src path and ignore Dst. + Skip bool + // DirMode is the file mode for implicit parent directories in Dst. + DirMode os.FileMode +} + +// filterHeader modifies the tar header by filtering it through the ExtractPaths. +// filterHeader returns false if the tar header should be skipped. +// If no paths are provided, filterHeader assumes the header should be included, and sets +// the mode for implicit parent directories to teleport.DirMaskSharedGroup. +func filterHeader(hdr *tar.Header, paths []ExtractPath) (dirMode os.FileMode, include bool) { + name := path.Clean(hdr.Name) + for _, p := range paths { + src := path.Clean(p.Src) + switch hdr.Typeflag { + case tar.TypeDir: + // If name is a directory, then + // assume src is a directory prefix, or the directory itself, + // and replace that prefix with dst. + if src != "/" { + src += "/" // ensure HasPrefix does not match partial names + } + if !strings.HasPrefix(name, src) { + continue + } + dst := path.Join(p.Dst, strings.TrimPrefix(name, src)) + if dst != "/" { + dst += "/" // tar directory headers end in / + } + hdr.Name = dst + return p.DirMode, !p.Skip + default: + // If name is a file, then + // if src is an exact match to the file name, assume src is a file and write directly to dst, + // otherwise, assume src is a directory prefix, and replace that prefix with dst. + if src == name { + hdr.Name = path.Clean(p.Dst) + return p.DirMode, !p.Skip + } + if src != "/" { + src += "/" // ensure HasPrefix does not match partial names + } + if !strings.HasPrefix(name, src) { + continue + } + hdr.Name = path.Join(p.Dst, strings.TrimPrefix(name, src)) + return p.DirMode, !p.Skip + + } + } + return teleport.DirMaskSharedGroup, len(paths) == 0 +} + // extractFile extracts a single file or directory from tarball into dir. // Uses header to determine the type of item to create // Based on https://github.com/mholt/archiver -func extractFile(tarball *tar.Reader, header *tar.Header, dir string) error { +func extractFile(tarball *tar.Reader, header *tar.Header, dir string, dirMode os.FileMode) error { switch header.Typeflag { case tar.TypeDir: - return withDir(filepath.Join(dir, header.Name), nil) + return withDir(filepath.Join(dir, header.Name), dirMode, nil) case tar.TypeBlock, tar.TypeChar, tar.TypeReg, tar.TypeFifo: - return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode()) + return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode(), dirMode) case tar.TypeLink: - return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname)) + return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname), dirMode) case tar.TypeSymlink: - return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname) + return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname, dirMode) default: log.Warnf("Unsupported type flag %v for %v.", header.Typeflag, header.Name) } @@ -106,8 +173,8 @@ func sanitizeTarPath(header *tar.Header, dir string) error { return nil } -func writeFile(path string, r io.Reader, mode os.FileMode) error { - err := withDir(path, func() error { +func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { // Create file only if it does not exist to prevent overwriting existing // files (like session recordings). out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode) @@ -120,24 +187,24 @@ func writeFile(path string, r io.Reader, mode os.FileMode) error { return trace.Wrap(err) } -func writeSymbolicLink(path string, target string) error { - err := withDir(path, func() error { +func writeSymbolicLink(path, target string, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { err := os.Symlink(target, path) return trace.ConvertSystemError(err) }) return trace.Wrap(err) } -func writeHardLink(path string, target string) error { - err := withDir(path, func() error { +func writeHardLink(path, target string, dirMode os.FileMode) error { + err := withDir(path, dirMode, func() error { err := os.Link(target, path) return trace.ConvertSystemError(err) }) return trace.Wrap(err) } -func withDir(path string, fn func() error) error { - err := os.MkdirAll(filepath.Dir(path), teleport.DirMaskSharedGroup) +func withDir(path string, mode os.FileMode, fn func() error) error { + err := os.MkdirAll(filepath.Dir(path), mode) if err != nil { return trace.ConvertSystemError(err) } diff --git a/lib/versioncontrol/upgradewindow/upgradewindow.go b/lib/versioncontrol/upgradewindow/upgradewindow.go index 7f90a9d70d41c..2a82b5f14fc14 100644 --- a/lib/versioncontrol/upgradewindow/upgradewindow.go +++ b/lib/versioncontrol/upgradewindow/upgradewindow.go @@ -44,6 +44,9 @@ const ( // unitScheduleFile is the name of the file to which the unit schedule is exported. unitScheduleFile = "schedule" + + // scheduleNop is the name of the no-op schedule. + scheduleNop = "nop" ) // ExportFunc represents the ExportUpgradeWindows rpc exposed by auth servers. @@ -313,6 +316,12 @@ type Driver interface { // called if teleport experiences prolonged loss of auth connectivity, which may be an indicator // that the control plane has been upgraded s.t. this agent is no longer compatible. Reset(ctx context.Context) error + + // ForceNop sets the NOP schedule, ensuring that updates do not happen. + // This schedule was originally only used for testing, but now it is also used by the + // teleport-update binary to protect against package updates that could interfere with + // the new update system. + ForceNop(ctx context.Context) error } // NewDriver sets up a new export driver corresponding to the specified upgrader kind. @@ -361,7 +370,15 @@ func (e *kubeDriver) Kind() string { } func (e *kubeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsResponse) error { - if rsp.KubeControllerSchedule == "" { + return trace.Wrap(e.setSchedule(ctx, rsp.KubeControllerSchedule)) +} + +func (e *kubeDriver) ForceNop(ctx context.Context) error { + return trace.Wrap(e.setSchedule(ctx, scheduleNop)) +} + +func (e *kubeDriver) setSchedule(ctx context.Context, schedule string) error { + if schedule == "" { return e.Reset(ctx) } @@ -369,7 +386,7 @@ func (e *kubeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsRes // backend.KeyFromString is intentionally used here instead of backend.NewKey // because existing backend items were persisted without the leading /. Key: backend.KeyFromString(kubeSchedKey), - Value: []byte(rsp.KubeControllerSchedule), + Value: []byte(schedule), }) return trace.Wrap(err) @@ -411,7 +428,15 @@ func (e *systemdDriver) Kind() string { } func (e *systemdDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsResponse) error { - if len(rsp.SystemdUnitSchedule) == 0 { + return trace.Wrap(e.setSchedule(ctx, rsp.SystemdUnitSchedule)) +} + +func (e *systemdDriver) ForceNop(ctx context.Context) error { + return trace.Wrap(e.setSchedule(ctx, scheduleNop)) +} + +func (e *systemdDriver) setSchedule(ctx context.Context, schedule string) error { + if len(schedule) == 0 { // treat an empty schedule value as equivalent to a reset return e.Reset(ctx) } @@ -423,7 +448,7 @@ func (e *systemdDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindows } // export schedule file. if created it is set to 644, which is reasonable for a sensitive but non-secret config value. - if err := os.WriteFile(e.scheduleFile(), []byte(rsp.SystemdUnitSchedule), defaults.FilePermissions); err != nil { + if err := os.WriteFile(e.scheduleFile(), []byte(schedule), defaults.FilePermissions); err != nil { return trace.Errorf("failed to write schedule file: %v", err) } diff --git a/lib/versioncontrol/upgradewindow/upgradewindow_test.go b/lib/versioncontrol/upgradewindow/upgradewindow_test.go index 7b724708652f4..c5c2236673ea7 100644 --- a/lib/versioncontrol/upgradewindow/upgradewindow_test.go +++ b/lib/versioncontrol/upgradewindow/upgradewindow_test.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/client/proto" @@ -182,6 +183,37 @@ func TestSystemdUnitDriver(t *testing.T) { require.Equal(t, "", string(sb)) } +// TestSystemdUnitDriverNop verifies the nop schedule behavior of the systemd unit export driver. +func TestSystemdUnitDriverNop(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // use a sub-directory of a temp dir in order to verify that + // driver creates dir when needed. + dir := filepath.Join(t.TempDir(), "config") + + driver, err := NewSystemdUnitDriver(SystemdUnitDriverConfig{ + ConfigDir: dir, + }) + require.NoError(t, err) + + err = driver.Sync(ctx, proto.ExportUpgradeWindowsResponse{ + SystemdUnitSchedule: "fake-schedule", + }) + require.NoError(t, err) + + err = driver.ForceNop(ctx) + require.NoError(t, err) + + schedPath := filepath.Join(dir, "schedule") + sb, err := os.ReadFile(schedPath) + require.NoError(t, err) + + require.Equal(t, scheduleNop, string(sb)) +} + // fakeDriver is used to inject custom behavior into a dummy Driver instance. type fakeDriver struct { mu sync.Mutex @@ -209,6 +241,10 @@ func (d *fakeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsRes return nil } +func (d *fakeDriver) ForceNop(ctx context.Context) error { + return trace.NotImplemented("force-nop not used by exporter") +} + func (d *fakeDriver) Reset(ctx context.Context) error { d.mu.Lock() defer d.mu.Unlock() diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go index 11aee2aae3906..801b62a3d8171 100644 --- a/tool/teleport-update/main.go +++ b/tool/teleport-update/main.go @@ -20,17 +20,19 @@ package main import ( "context" + "errors" + "fmt" "log/slog" "os" "os/signal" - "path/filepath" "syscall" "github.com/gravitational/trace" + "gopkg.in/yaml.v3" "github.com/gravitational/teleport" + common "github.com/gravitational/teleport/lib/autoupdate" autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent" - libdefaults "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" libutils "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" @@ -38,16 +40,13 @@ import ( const appHelp = `Teleport Updater -The Teleport Updater updates the version a Teleport agent on a Linux server -that is being used as agent to provide connectivity to Teleport resources. +The Teleport Updater applies Managed Updates to a Teleport agent installation. -The Teleport Updater supports upgrade schedules and automated rollbacks. +The Teleport Updater supports update scheduling and automated rollbacks. -Find out more at https://goteleport.com/docs/updater` +Find out more at https://goteleport.com/docs/upgrading/agent-managed-updates` const ( - // templateEnvVar allows the template for the Teleport tgz to be specified via env var. - templateEnvVar = "TELEPORT_URL_TEMPLATE" // proxyServerEnvVar allows the proxy server address to be specified via env var. proxyServerEnvVar = "TELEPORT_PROXY" // updateGroupEnvVar allows the update group to be specified via env var. @@ -56,107 +55,166 @@ const ( updateVersionEnvVar = "TELEPORT_UPDATE_VERSION" ) -const ( - // versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions. - versionsDirName = "versions" - // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution. - lockFileName = ".lock" -) - var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater) func main() { - if err := Run(os.Args[1:]); err != nil { - libutils.FatalError(err) + if code := Run(os.Args[1:]); code != 0 { + os.Exit(code) } } type cliConfig struct { autoupdate.OverrideConfig - // Debug logs enabled Debug bool // LogFormat controls the format of logging. Can be either `json` or `text`. // By default, this is `text`. LogFormat string - // DataDir for Teleport (usually /var/lib/teleport) - DataDir string + // InstallDir for Teleport (usually /opt/teleport) + InstallDir string + // InstallSuffix is the isolated suffix for the installation. + InstallSuffix string + // SelfSetup mode for using the current version of the teleport-update to setup the update service. + SelfSetup bool + // UpdateNow forces an immediate update. + UpdateNow bool + // Reload reloads Teleport. + Reload bool + // ForceUninstall allows Teleport to be completely removed. + ForceUninstall bool + // Insecure skips TLS certificate verification. + Insecure bool } -func (c *cliConfig) CheckAndSetDefaults() error { - if c.DataDir == "" { - c.DataDir = libdefaults.DataDir - } - if c.LogFormat == "" { - c.LogFormat = libutils.LogFormatText - } - return nil -} - -func Run(args []string) error { +func Run(args []string) int { var ccfg cliConfig + ctx := context.Background() ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) - app := libutils.InitCLIParser("teleport-updater", appHelp).Interspersed(false) + app := libutils.InitCLIParser(autoupdate.BinaryName, appHelp).Interspersed(false) app.Flag("debug", "Verbose logging to stdout."). Short('d').BoolVar(&ccfg.Debug) - app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited."). - Default(libdefaults.DataDir).StringVar(&ccfg.DataDir) app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`."). Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText) + app.Flag("install-suffix", "Suffix for creating an agent installation outside of the default $PATH. Note: this changes the default data directory."). + Short('i').StringVar(&ccfg.InstallSuffix) + app.Flag("install-dir", "Directory containing Teleport installations."). + Hidden().StringVar(&ccfg.InstallDir) + app.Flag("insecure", "Insecure mode disables certificate verification. Do not use in production."). + BoolVar(&ccfg.Insecure) app.HelpFlag.Short('h') - versionCmd := app.Command("version", "Print the version of your teleport-updater binary.") + versionCmd := app.Command("version", fmt.Sprintf("Print the version of your %s binary.", autoupdate.BinaryName)) - enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.") + enableCmd := app.Command("enable", "Enable agent managed updates and perform initial installation or update. This creates a systemd timer that periodically runs the update subcommand.") enableCmd.Flag("proxy", "Address of the Teleport Proxy."). Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy) enableCmd.Flag("group", "Update group for this agent installation."). Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group) - enableCmd.Flag("template", "Go template used to override Teleport download URL."). - Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate) - enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster."). - Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) - - disableCmd := app.Command("disable", "Disable agent auto-updates.") - - updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.") - updateCmd.Flag("force-version", "Use the provided version instead of querying it from the Teleport cluster."). - Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion) + enableCmd.Flag("base-url", "Base URL used to override the Teleport download URL."). + Short('b').Envar(common.BaseURLEnvVar).StringVar(&ccfg.BaseURL) + enableCmd.Flag("overwrite", "Allow existing installed Teleport binaries to be overwritten."). + Short('o').BoolVar(&ccfg.AllowOverwrite) + enableCmd.Flag("force-version", "Force the provided version instead of using the version provided by the Teleport cluster."). + Hidden().Short('f').Envar(updateVersionEnvVar).StringVar(&ccfg.ForceVersion) + enableCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates."). + Hidden().BoolVar(&ccfg.SelfSetup) + enableCmd.Flag("path", "Directory to link the active Teleport installation's binaries into."). + Hidden().StringVar(&ccfg.Path) + // TODO(sclevine): add force-fips and force-enterprise as hidden flags + + disableCmd := app.Command("disable", "Disable agent managed updates. Does not affect the active installation of Teleport.") + + pinCmd := app.Command("pin", "Install Teleport and lock the updater to the installed version.") + pinCmd.Flag("proxy", "Address of the Teleport Proxy."). + Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy) + pinCmd.Flag("group", "Update group for this agent installation."). + Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group) + pinCmd.Flag("base-url", "Base URL used to override the Teleport download URL."). + Short('b').Envar(common.BaseURLEnvVar).StringVar(&ccfg.BaseURL) + pinCmd.Flag("overwrite", "Allow existing installed Teleport binaries to be overwritten."). + Short('o').BoolVar(&ccfg.AllowOverwrite) + pinCmd.Flag("force-version", "Force the provided version instead of using the version provided by the Teleport cluster."). + Short('f').Envar(updateVersionEnvVar).StringVar(&ccfg.ForceVersion) + pinCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates."). + Hidden().BoolVar(&ccfg.SelfSetup) + pinCmd.Flag("path", "Directory to link the active Teleport installation's binaries into."). + Hidden().StringVar(&ccfg.Path) + + unpinCmd := app.Command("unpin", "Unpin the current version, allowing it to be updated.") + + updateCmd := app.Command("update", "Update the agent to the latest version, if a new version is available.") + updateCmd.Flag("now", "Force immediate update even if update window is not active."). + Short('n').BoolVar(&ccfg.UpdateNow) + updateCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates and verify the Teleport installation."). + Hidden().BoolVar(&ccfg.SelfSetup) + + linkCmd := app.Command("link-package", "Link the system installation of Teleport from the Teleport package, if managed updates is disabled.") + unlinkCmd := app.Command("unlink-package", "Unlink the system installation of Teleport from the Teleport package.") + + setupCmd := app.Command("setup", "Write configuration files that run the update subcommand on a timer and verify the Teleport installation."). + Hidden() + setupCmd.Flag("reload", "Reload the Teleport agent. If not set, Teleport is not reloaded or restarted."). + BoolVar(&ccfg.Reload) + setupCmd.Flag("path", "Directory that the active Teleport installation's binaries are linked into."). + Required().StringVar(&ccfg.Path) + + statusCmd := app.Command("status", "Show Teleport agent auto-update status.") + + uninstallCmd := app.Command("uninstall", "Uninstall the updater-managed installation of Teleport. If the Teleport package is installed, it is restored as the primary installation.") + uninstallCmd.Flag("force", "Force complete uninstallation of Teleport, even if there is no packaged version of Teleport to revert to."). + Short('f').BoolVar(&ccfg.ForceUninstall) libutils.UpdateAppUsageTemplate(app, args) command, err := app.Parse(args) if err != nil { app.Usage(args) - return trace.Wrap(err) + libutils.FatalError(err) } + // Logging must be configured as early as possible to ensure all log // message are formatted correctly. if err := setupLogger(ccfg.Debug, ccfg.LogFormat); err != nil { - return trace.Errorf("failed to set up logger") - } - - if err := ccfg.CheckAndSetDefaults(); err != nil { - return trace.Wrap(err) + plog.ErrorContext(ctx, "Failed to set up logger.", "error", err) + return 1 } switch command { case enableCmd.FullCommand(): - err = cmdEnable(ctx, &ccfg) + ccfg.Enabled = true + err = cmdInstall(ctx, &ccfg) + case pinCmd.FullCommand(): + ccfg.Pinned = true + err = cmdInstall(ctx, &ccfg) case disableCmd.FullCommand(): err = cmdDisable(ctx, &ccfg) + case unpinCmd.FullCommand(): + err = cmdUnpin(ctx, &ccfg) case updateCmd.FullCommand(): err = cmdUpdate(ctx, &ccfg) + case linkCmd.FullCommand(): + err = cmdLinkPackage(ctx, &ccfg) + case unlinkCmd.FullCommand(): + err = cmdUnlinkPackage(ctx, &ccfg) + case setupCmd.FullCommand(): + err = cmdSetup(ctx, &ccfg) + case statusCmd.FullCommand(): + err = cmdStatus(ctx, &ccfg) + case uninstallCmd.FullCommand(): + err = cmdUninstall(ctx, &ccfg) case versionCmd.FullCommand(): modules.GetModules().PrintVersion() default: // This should only happen when there's a missing switch case above. - err = trace.Errorf("command %q not configured", command) + err = trace.Errorf("command %s not configured", command) } - - return err + if err != nil { + plog.ErrorContext(ctx, "Command failed.", "error", err) + return 1 + } + return 0 } func setupLogger(debug bool, format string) error { @@ -176,46 +234,178 @@ func setupLogger(debug bool, format string) error { return nil } +func initConfig(ctx context.Context, ccfg *cliConfig) (updater *autoupdate.Updater, lockFile string, err error) { + ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir) + if err != nil { + return nil, "", trace.Wrap(err) + } + lockFile, err = ns.Init() + if err != nil { + return nil, "", trace.Wrap(err) + } + updater, err = autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + SelfSetup: ccfg.SelfSetup, + Log: plog, + LogFormat: ccfg.LogFormat, + Debug: ccfg.Debug, + InsecureSkipVerify: ccfg.Insecure, + }, ns) + return updater, lockFile, trace.Wrap(err) +} + +func statusConfig(ctx context.Context, ccfg *cliConfig) (*autoupdate.Updater, error) { + ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir) + if err != nil { + return nil, trace.Wrap(err) + } + updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ + SelfSetup: ccfg.SelfSetup, + Log: plog, + LogFormat: ccfg.LogFormat, + Debug: ccfg.Debug, + InsecureSkipVerify: ccfg.Insecure, + }, ns) + return updater, trace.Wrap(err) +} + // cmdDisable disables updates. func cmdDisable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") } - - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(lockFile) if err != nil { - return trace.Errorf("failed to grab concurrent execution lock: %w", err) + return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile) } defer func() { if err := unlock(); err != nil { plog.DebugContext(ctx, "Failed to close lock file", "error", err) } }() - updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - Log: plog, - }) + if err := updater.Disable(ctx); err != nil { + return trace.Wrap(err) + } + return nil +} + +// cmdUnpin unpins the current version. +func cmdUnpin(ctx context.Context, ccfg *cliConfig) error { + updater, lockFile, err := initConfig(ctx, ccfg) if err != nil { - return trace.Errorf("failed to setup updater: %w", err) + return trace.Wrap(err, "failed to setup updater") } - if err := updater.Disable(ctx); err != nil { + unlock, err := libutils.FSWriteLock(lockFile) + if err != nil { + return trace.Wrap(err, "failed to grab concurrent execution lock %n", lockFile) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + if err := updater.Unpin(ctx); err != nil { return trace.Wrap(err) } return nil } -// cmdEnable enables updates and triggers an initial update. -func cmdEnable(ctx context.Context, ccfg *cliConfig) error { - versionsDir := filepath.Join(ccfg.DataDir, versionsDirName) - if err := os.MkdirAll(versionsDir, 0755); err != nil { - return trace.Errorf("failed to create versions directory: %w", err) +// cmdInstall installs Teleport and sets configuration. +func cmdInstall(ctx context.Context, ccfg *cliConfig) error { + if ccfg.InstallSuffix != "" { + ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir) + if err != nil { + return trace.Wrap(err) + } + ns.LogWarning(ctx) + } + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") } // Ensure enable can't run concurrently. - unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName)) + unlock, err := libutils.FSWriteLock(lockFile) + if err != nil { + return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + if err := updater.Install(ctx, ccfg.OverrideConfig); err != nil { + return trace.Wrap(err) + } + return nil +} + +// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. +func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") + } + // Ensure update can't run concurrently. + unlock, err := libutils.FSWriteLock(lockFile) + if err != nil { + return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + + if err := updater.Update(ctx, ccfg.UpdateNow); err != nil { + return trace.Wrap(err) + } + return nil +} + +// cmdLinkPackage creates system package links if no version is linked and managed updates is disabled. +func cmdLinkPackage(ctx context.Context, ccfg *cliConfig) error { + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") + } + + // Skip operation and warn if the updater is currently running. + unlock, err := libutils.FSTryReadLock(lockFile) + if errors.Is(err, libutils.ErrUnsuccessfulLockTry) { + plog.WarnContext(ctx, "Updater is currently running. Skipping package linking.") + return nil + } + if err != nil { + return trace.Wrap(err, "failed to grab concurrent execution lock %q", lockFile) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + + if err := updater.LinkPackage(ctx); err != nil { + return trace.Wrap(err) + } + return nil +} + +// cmdUnlinkPackage remove system package links. +func cmdUnlinkPackage(ctx context.Context, ccfg *cliConfig) error { + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to setup updater") + } + + // Error if the updater is running. We could remove its links by accident. + unlock, err := libutils.FSTryWriteLock(lockFile) + if errors.Is(err, libutils.ErrUnsuccessfulLockTry) { + plog.WarnContext(ctx, "Updater is currently running. Skipping package unlinking.") + return nil + } if err != nil { - return trace.Errorf("failed to grab concurrent execution lock: %w", err) + return trace.Wrap(err, "failed to grab concurrent execution lock %q", lockFile) } defer func() { if err := unlock(); err != nil { @@ -223,20 +413,68 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error { } }() + if err := updater.UnlinkPackage(ctx); err != nil { + return trace.Wrap(err) + } + return nil +} + +// cmdSetup writes configuration files that are needed to run teleport-update update. +func cmdSetup(ctx context.Context, ccfg *cliConfig) error { + ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir) + if err != nil { + return trace.Wrap(err) + } updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{ - VersionsDir: versionsDir, - Log: plog, - }) + SelfSetup: ccfg.SelfSetup, + Log: plog, + LogFormat: ccfg.LogFormat, + Debug: ccfg.Debug, + InsecureSkipVerify: ccfg.Insecure, + }, ns) if err != nil { - return trace.Errorf("failed to setup updater: %w", err) + return trace.Wrap(err) } - if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil { + err = updater.Setup(ctx, ccfg.Path, ccfg.Reload) + if err != nil { return trace.Wrap(err) } return nil } -// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address. -func cmdUpdate(ctx context.Context, ccfg *cliConfig) error { - return trace.NotImplemented("TODO") +// cmdStatus displays auto-update status. +func cmdStatus(ctx context.Context, ccfg *cliConfig) error { + updater, err := statusConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") + } + status, err := updater.Status(ctx) + if err != nil { + return trace.Wrap(err, "failed to get status") + } + enc := yaml.NewEncoder(os.Stdout) + return trace.Wrap(enc.Encode(status)) +} + +// cmdUninstall removes the updater-managed install of Teleport and gracefully reverts back to the Teleport package. +func cmdUninstall(ctx context.Context, ccfg *cliConfig) error { + updater, lockFile, err := initConfig(ctx, ccfg) + if err != nil { + return trace.Wrap(err, "failed to initialize updater") + } + // Ensure update can't run concurrently. + unlock, err := libutils.FSWriteLock(lockFile) + if err != nil { + return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile) + } + defer func() { + if err := unlock(); err != nil { + plog.DebugContext(ctx, "Failed to close lock file", "error", err) + } + }() + + if err := updater.Remove(ctx, ccfg.ForceUninstall); err != nil { + return trace.Wrap(err) + } + return nil }