diff --git a/api/types/maintenance.go b/api/types/maintenance.go
index 31a48472e6aa8..5d98a3ec311cb 100644
--- a/api/types/maintenance.go
+++ b/api/types/maintenance.go
@@ -26,13 +26,17 @@ import (
)
const (
- // UpgraderKindKuberController is a short name used to identify the kube-controller-based
+ // UpgraderKindKubeController is a short name used to identify the kube-controller-based
// external upgrader variant.
UpgraderKindKubeController = "kube"
// UpgraderKindSystemdUnit is a short name used to identify the systemd-unit-based
// external upgrader variant.
UpgraderKindSystemdUnit = "unit"
+
+ // UpgraderKindTeleportUpdate is a short name used to identify the teleport-update
+ // external upgrader variant.
+ UpgraderKindTeleportUpdate = "binary"
)
var validWeekdays = [7]time.Weekday{
diff --git a/constants.go b/constants.go
index 30adf42be55ca..42955b6f5f138 100644
--- a/constants.go
+++ b/constants.go
@@ -286,7 +286,7 @@ const (
// ComponentProxySecureGRPC represents a secure gRPC server running on Proxy (used for Kube).
ComponentProxySecureGRPC = "proxy:secure-grpc"
- // ComponentUpdater represents the agent updater.
+ // ComponentUpdater represents the teleport-update binary.
ComponentUpdater = "updater"
// ComponentGit represents git proxy related services.
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index 0f6c85be8a5e7..96552aeae7778 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -6858,7 +6858,7 @@ func (a *Server) ExportUpgradeWindows(ctx context.Context, req proto.ExportUpgra
}
switch req.UpgraderKind {
- case "":
+ case "", types.UpgraderKindTeleportUpdate:
rsp.CanonicalSchedule = cached.CanonicalSchedule.Clone()
case types.UpgraderKindKubeController:
rsp.KubeControllerSchedule = cached.KubeControllerSchedule
diff --git a/lib/autoupdate/agent/config.go b/lib/autoupdate/agent/config.go
new file mode 100644
index 0000000000000..fe10d84808d11
--- /dev/null
+++ b/lib/autoupdate/agent/config.go
@@ -0,0 +1,251 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "errors"
+ "fmt"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
+
+ "github.com/gravitational/teleport/lib/autoupdate"
+)
+
+const (
+ // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update.
+ updateConfigName = "update.yaml"
+
+ // UpdateConfig metadata
+ updateConfigVersion = "v1"
+ updateConfigKind = "update_config"
+)
+
+// UpdateConfig describes the update.yaml file schema.
+type UpdateConfig struct {
+ // Version of the configuration file
+ Version string `yaml:"version"`
+ // Kind of configuration file (always "update_config")
+ Kind string `yaml:"kind"`
+ // Spec contains user-specified configuration.
+ Spec UpdateSpec `yaml:"spec"`
+ // Status contains state configuration.
+ Status UpdateStatus `yaml:"status"`
+}
+
+// UpdateSpec describes the spec field in update.yaml.
+type UpdateSpec struct {
+ // Proxy address
+ Proxy string `yaml:"proxy"`
+ // Path is the location the Teleport binaries are linked into.
+ Path string `yaml:"path"`
+ // Group specifies the update group identifier for the agent.
+ Group string `yaml:"group,omitempty"`
+ // BaseURL is CDN base URL used for the Teleport tgz download URL.
+ BaseURL string `yaml:"base_url,omitempty"`
+ // Enabled controls whether auto-updates are enabled.
+ Enabled bool `yaml:"enabled"`
+ // Pinned controls whether the active_version is pinned.
+ Pinned bool `yaml:"pinned"`
+}
+
+// UpdateStatus describes the status field in update.yaml.
+type UpdateStatus struct {
+ // Active is the currently active revision of Teleport.
+ Active Revision `yaml:"active"`
+ // Backup is the last working revision of Teleport.
+ Backup *Revision `yaml:"backup,omitempty"`
+ // Skip is the skipped revision of Teleport.
+ // Skipped revisions are not applied because they
+ // are known to crash.
+ Skip *Revision `yaml:"skip,omitempty"`
+}
+
+// Revision is a version and edition of Teleport.
+type Revision struct {
+ // Version is the version of Teleport.
+ Version string `yaml:"version" json:"version"`
+ // Flags describe the edition of Teleport.
+ Flags autoupdate.InstallFlags `yaml:"flags,flow,omitempty" json:"flags,omitempty"`
+}
+
+// NewRevision create a Revision.
+// If version is not set, no flags are returned.
+// This ensures that all Revisions without versions are zero-valued.
+func NewRevision(version string, flags autoupdate.InstallFlags) Revision {
+ if version != "" {
+ return Revision{
+ Version: version,
+ Flags: flags,
+ }
+ }
+ return Revision{}
+}
+
+// NewRevisionFromDir translates a directory path containing Teleport into a Revision.
+func NewRevisionFromDir(dir string) (Revision, error) {
+ parts := strings.Split(dir, "_")
+ var out Revision
+ if len(parts) == 0 {
+ return out, trace.Errorf("dir name empty")
+ }
+ out.Version = parts[0]
+ if out.Version == "" {
+ return out, trace.Errorf("version missing in dir %s", dir)
+ }
+ switch flags := parts[1:]; len(flags) {
+ case 2:
+ if flags[1] != autoupdate.FlagFIPS.DirFlag() {
+ break
+ }
+ out.Flags |= autoupdate.FlagFIPS
+ fallthrough
+ case 1:
+ if flags[0] != autoupdate.FlagEnterprise.DirFlag() {
+ break
+ }
+ out.Flags |= autoupdate.FlagEnterprise
+ fallthrough
+ case 0:
+ return out, nil
+ }
+ return out, trace.Errorf("invalid flag in %s", dir)
+}
+
+// Dir returns the directory path name of a Revision.
+func (r Revision) Dir() string {
+ // Do not change the order of these statements.
+ // Otherwise, installed versions will no longer match update.yaml.
+ var suffix string
+ if r.Flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0 {
+ suffix += "_" + autoupdate.FlagEnterprise.DirFlag()
+ }
+ if r.Flags&autoupdate.FlagFIPS != 0 {
+ suffix += "_" + autoupdate.FlagFIPS.DirFlag()
+ }
+ return r.Version + suffix
+}
+
+// String returns a human-readable description of a Teleport revision.
+func (r Revision) String() string {
+ if flags := r.Flags.Strings(); len(flags) > 0 {
+ return fmt.Sprintf("%s+%s", r.Version, strings.Join(flags, "+"))
+ }
+ return r.Version
+}
+
+// readConfig reads UpdateConfig from a file.
+func readConfig(path string) (*UpdateConfig, error) {
+ f, err := os.Open(path)
+ if errors.Is(err, fs.ErrNotExist) {
+ return &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ }, nil
+ }
+ if err != nil {
+ return nil, trace.Wrap(err, "failed to open")
+ }
+ defer f.Close()
+ var cfg UpdateConfig
+ if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
+ return nil, trace.Wrap(err, "failed to parse")
+ }
+ if k := cfg.Kind; k != updateConfigKind {
+ return nil, trace.Errorf("invalid kind %s", k)
+ }
+ if v := cfg.Version; v != updateConfigVersion {
+ return nil, trace.Errorf("invalid version %s", v)
+ }
+ return &cfg, nil
+}
+
+// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted.
+func writeConfig(filename string, cfg *UpdateConfig) error {
+ opts := []renameio.Option{
+ renameio.WithPermissions(configFileMode),
+ renameio.WithExistingPermissions(),
+ renameio.WithTempDir(filepath.Dir(filename)),
+ }
+ t, err := renameio.NewPendingFile(filename, opts...)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer t.Cleanup()
+ err = yaml.NewEncoder(t).Encode(cfg)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(t.CloseAtomicallyReplace())
+}
+
+func validateConfigSpec(spec *UpdateSpec, override OverrideConfig) error {
+ if override.Proxy != "" {
+ spec.Proxy = override.Proxy
+ }
+ if override.Path != "" {
+ spec.Path = override.Path
+ }
+ if override.Group != "" {
+ spec.Group = override.Group
+ }
+ switch override.BaseURL {
+ case "":
+ case "default":
+ spec.BaseURL = ""
+ default:
+ spec.BaseURL = override.BaseURL
+ }
+ if spec.BaseURL != "" &&
+ !strings.HasPrefix(strings.ToLower(spec.BaseURL), "https://") {
+ return trace.Errorf("Teleport download base URL %s must use TLS (https://)", spec.BaseURL)
+ }
+ if override.Enabled {
+ spec.Enabled = true
+ }
+ if override.Pinned {
+ spec.Pinned = true
+ }
+ return nil
+}
+
+// Status of the agent auto-updates system.
+type Status struct {
+ UpdateSpec `yaml:",inline"`
+ UpdateStatus `yaml:",inline"`
+ FindResp `yaml:",inline"`
+}
+
+// FindResp summarizes the auto-update status response from cluster.
+type FindResp struct {
+ // Target revision of Teleport to install
+ Target Revision `yaml:"target"`
+ // InWindow is true when the install should happen now.
+ InWindow bool `yaml:"in_window"`
+ // Jitter duration before an automated install
+ Jitter time.Duration `yaml:"jitter"`
+ // AGPL installations cannot use the official CDN.
+ AGPL bool `yaml:"agpl,omitempty"`
+}
diff --git a/lib/autoupdate/agent/config_test.go b/lib/autoupdate/agent/config_test.go
new file mode 100644
index 0000000000000..39d318cd6ee4c
--- /dev/null
+++ b/lib/autoupdate/agent/config_test.go
@@ -0,0 +1,127 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/lib/autoupdate"
+)
+
+func TestNewRevisionFromDir(t *testing.T) {
+ t.Parallel()
+
+ for _, tt := range []struct {
+ name string
+ dir string
+ rev Revision
+ errMatch string
+ }{
+ {
+ name: "version",
+ dir: "1.2.3",
+ rev: Revision{
+ Version: "1.2.3",
+ },
+ },
+ {
+ name: "full",
+ dir: "1.2.3_ent_fips",
+ rev: Revision{
+ Version: "1.2.3",
+ Flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS,
+ },
+ },
+ {
+ name: "ent",
+ dir: "1.2.3_ent",
+ rev: Revision{
+ Version: "1.2.3",
+ Flags: autoupdate.FlagEnterprise,
+ },
+ },
+ {
+ name: "empty",
+ errMatch: "missing",
+ },
+ {
+ name: "trailing",
+ dir: "1.2.3_",
+ errMatch: "invalid",
+ },
+ {
+ name: "more trailing",
+ dir: "1.2.3___",
+ errMatch: "invalid",
+ },
+ {
+ name: "no version",
+ dir: "_fips",
+ errMatch: "missing",
+ },
+ {
+ name: "fips no ent",
+ dir: "1.2.3_fips",
+ errMatch: "invalid",
+ },
+ {
+ name: "unknown start fips",
+ dir: "1.2.3_test_fips",
+ errMatch: "invalid",
+ },
+ {
+ name: "unknown start ent",
+ dir: "1.2.3_test_ent",
+ errMatch: "invalid",
+ },
+ {
+ name: "unknown end fips",
+ dir: "1.2.3_fips_test",
+ errMatch: "invalid",
+ },
+ {
+ name: "unknown end ent",
+ dir: "1.2.3_ent_test",
+ errMatch: "invalid",
+ },
+ {
+ name: "bad order",
+ dir: "1.2.3_fips_ent",
+ errMatch: "invalid",
+ },
+ {
+ name: "underscore",
+ dir: "_",
+ errMatch: "missing",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ rev, err := NewRevisionFromDir(tt.dir)
+ if tt.errMatch != "" {
+ require.ErrorContains(t, err, tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tt.rev, rev)
+ require.Equal(t, tt.dir, rev.Dir())
+ })
+ }
+}
diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go
index e31813866eacf..29afb6d9da7aa 100644
--- a/lib/autoupdate/agent/installer.go
+++ b/lib/autoupdate/agent/installer.go
@@ -29,19 +29,37 @@ import (
"log/slog"
"net/http"
"os"
+ "path"
"path/filepath"
- "runtime"
- "text/template"
+ "syscall"
"time"
+ "github.com/google/renameio/v2"
"github.com/gravitational/trace"
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/utils"
)
const (
- checksumType = "sha256"
+ // checksumType for Teleport tgzs
+ checksumType = "sha256"
+ // checksumHexLen is the length of the Teleport checksum.
checksumHexLen = sha256.Size * 2 // bytes to hex
+ // maxServiceFileSize is the maximum size allowed for a systemd service file.
+ maxServiceFileSize = 1_000_000 // 1 MB
+ // configFileMode is the mode used for new configuration files.
+ configFileMode = 0644
+ // systemDirMode is the mode used for new directories.
+ systemDirMode = 0755
+)
+
+const (
+ // serviceDir contains the relative path to the Teleport SystemD service dir.
+ serviceDir = "lib/systemd/system"
+ // serviceName contains the upstream name of the Teleport SystemD service file.
+ serviceName = "teleport.service"
)
// LocalInstaller manages the creation and removal of installations
@@ -49,6 +67,12 @@ const (
type LocalInstaller struct {
// InstallDir contains each installation, named by version.
InstallDir string
+ // TargetServiceFile contains a copy of the linked installation's systemd service.
+ TargetServiceFile string
+ // SystemBinDir contains binaries for the system (packaged) install of Teleport.
+ SystemBinDir string
+ // SystemServiceFile contains the systemd service file for the system (packaged) install of Teleport.
+ SystemServiceFile string
// HTTP is an HTTP client for downloading Teleport.
HTTP *http.Client
// Log contains a logger.
@@ -57,17 +81,31 @@ type LocalInstaller struct {
ReservedFreeTmpDisk uint64
// ReservedFreeInstallDisk is the amount of disk that must remain free in the install directory.
ReservedFreeInstallDisk uint64
+ // TransformService transforms the systemd service during copying.
+ TransformService func([]byte) []byte
+ // ValidateBinary returns true if a file is a linkable binary, or
+ // false if a file should not be linked.
+ ValidateBinary func(ctx context.Context, path string) (bool, error)
+ // Template is download URI Template of Teleport packages.
+ Template string
}
// Remove a Teleport version directory from InstallDir.
// This function is idempotent.
-func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
- versionDir := filepath.Join(li.InstallDir, version)
- sumPath := filepath.Join(versionDir, checksumType)
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Remove(ctx context.Context, rev Revision) error {
+ // os.RemoveAll is dangerous because it can remove an entire directory tree.
+ // We must validate the version to ensure that we remove only a single path
+ // element under the InstallDir, and not InstallDir or its parents.
+ // revisionDir performs these validations.
+ versionDir, err := li.revisionDir(rev)
+ if err != nil {
+ return trace.Wrap(err)
+ }
// invalidate checksum first, to protect against partially-removed
// directory with valid checksum.
- err := os.Remove(sumPath)
+ err = os.Remove(filepath.Join(versionDir, checksumType))
if err != nil && !errors.Is(err, os.ErrNotExist) {
return trace.Wrap(err)
}
@@ -79,12 +117,16 @@ func (li *LocalInstaller) Remove(ctx context.Context, version string) error {
// Install a Teleport version directory in InstallDir.
// This function is idempotent.
-func (li *LocalInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
- versionDir := filepath.Join(li.InstallDir, version)
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Install(ctx context.Context, rev Revision, baseURL string, force bool) (err error) {
+ versionDir, err := li.revisionDir(rev)
+ if err != nil {
+ return trace.Wrap(err)
+ }
sumPath := filepath.Join(versionDir, checksumType)
- // generate download URI from template
- uri, err := makeURL(template, version, flags)
+ // generate download URI from Template
+ uri, err := autoupdate.MakeURL(li.Template, baseURL, autoupdate.DefaultPackage, rev.Version, rev.Flags)
if err != nil {
return trace.Wrap(err)
}
@@ -94,33 +136,37 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
checksumURI := uri + "." + checksumType
newSum, err := li.getChecksum(ctx, checksumURI)
if err != nil {
- return trace.Errorf("failed to download checksum from %s: %w", checksumURI, err)
+ return trace.Wrap(err, "failed to download checksum from %s", checksumURI)
}
oldSum, err := readChecksum(sumPath)
- if err == nil {
- if bytes.Equal(oldSum, newSum) {
- li.Log.InfoContext(ctx, "Version already present.", "version", version)
- return nil
- }
- li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", version)
- if err := li.Remove(ctx, version); err != nil {
- return trace.Wrap(err)
- }
+ versionPresent := err == nil
+ if versionPresent && bytes.Equal(oldSum, newSum) {
+ li.Log.InfoContext(ctx, "Version already present.", "version", rev)
+ return nil
+ }
+ if versionPresent {
+ li.Log.WarnContext(ctx, "Removing version that does not match checksum.", "version", rev)
} else if !errors.Is(err, os.ErrNotExist) {
- li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", version, "error", err)
- if err := li.Remove(ctx, version); err != nil {
- return trace.Wrap(err)
+ li.Log.WarnContext(ctx, "Removing version with unreadable checksum.", "version", rev, "error", err)
+ }
+ if versionPresent || !errors.Is(err, os.ErrNotExist) {
+ if force {
+ if err := li.Remove(ctx, rev); err != nil {
+ return trace.Wrap(err)
+ }
+ } else {
+ return trace.Errorf("refusing to remove linked installation of Teleport")
}
}
// Verify that we have enough free temp space, then download tgz
freeTmp, err := utils.FreeDiskWithReserve(os.TempDir(), li.ReservedFreeTmpDisk)
if err != nil {
- return trace.Errorf("failed to calculate free disk: %w", err)
+ return trace.Wrap(err, "failed to calculate free disk")
}
f, err := os.CreateTemp("", "teleport-update-")
if err != nil {
- return trace.Errorf("failed to create temporary file: %w", err)
+ return trace.Wrap(err, "failed to create temporary file")
}
defer func() {
_ = f.Close() // data never read after close
@@ -130,13 +176,19 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
}()
pathSum, err := li.download(ctx, f, int64(freeTmp), uri)
if err != nil {
- return trace.Errorf("failed to download teleport: %w", err)
+ return trace.Wrap(err, "failed to download teleport")
}
-
// Seek to the start of the tgz file after writing
if _, err := f.Seek(0, io.SeekStart); err != nil {
- return trace.Errorf("failed seek to start of download: %w", err)
+ return trace.Wrap(err, "failed seek to start of download")
}
+
+ // If interrupted, close the file immediately to stop extracting.
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ context.AfterFunc(ctx, func() {
+ _ = f.Close() // safe to close file multiple times
+ })
// Check integrity before decompression
if !bytes.Equal(newSum, pathSum) {
return trace.Errorf("mismatched checksum, download possibly corrupt")
@@ -144,47 +196,33 @@ func (li *LocalInstaller) Install(ctx context.Context, version, template string,
// Get uncompressed size of the tgz
n, err := uncompressedSize(f)
if err != nil {
- return trace.Errorf("failed to determine uncompressed size: %w", err)
+ return trace.Wrap(err, "failed to determine uncompressed size")
}
// Seek to start of tgz after reading size
if _, err := f.Seek(0, io.SeekStart); err != nil {
- return trace.Errorf("failed seek to start: %w", err)
+ return trace.Wrap(err, "failed seek to start")
}
- if err := li.extract(ctx, versionDir, f, n); err != nil {
- return trace.Errorf("failed to extract teleport: %w", err)
+
+ // If there's an error after we start extracting, delete the version dir.
+ defer func() {
+ if err != nil {
+ if err := os.RemoveAll(versionDir); err != nil {
+ li.Log.WarnContext(ctx, "Failed to cleanup broken version extraction.", "error", err, "dir", versionDir)
+ }
+ }
+ }()
+
+ // Extract tgz into version directory.
+ if err := li.extract(ctx, versionDir, f, n, rev.Flags); err != nil {
+ return trace.Wrap(err, "failed to extract teleport")
}
// Write the checksum last. This marks the version directory as valid.
- err = os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), 0755)
- if err != nil {
- return trace.Errorf("failed to write checksum: %w", err)
+ if err := os.WriteFile(sumPath, []byte(hex.EncodeToString(newSum)), configFileMode); err != nil {
+ return trace.Wrap(err, "failed to write checksum")
}
return nil
}
-// makeURL to download the Teleport tgz.
-func makeURL(uriTmpl, version string, flags InstallFlags) (string, error) {
- tmpl, err := template.New("uri").Parse(uriTmpl)
- if err != nil {
- return "", trace.Wrap(err)
- }
- var uriBuf bytes.Buffer
- params := struct {
- OS, Version, Arch string
- FIPS, Enterprise bool
- }{
- OS: runtime.GOOS,
- Version: version,
- Arch: runtime.GOARCH,
- FIPS: flags&FlagFIPS != 0,
- Enterprise: flags&(FlagEnterprise|FlagFIPS) != 0,
- }
- err = tmpl.Execute(&uriBuf, params)
- if err != nil {
- return "", trace.Wrap(err)
- }
- return uriBuf.String(), nil
-}
-
// readChecksum from the version directory.
func readChecksum(path string) ([]byte, error) {
f, err := os.Open(path)
@@ -208,6 +246,7 @@ func readChecksum(path string) ([]byte, error) {
func (li *LocalInstaller) getChecksum(ctx context.Context, url string) ([]byte, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, trace.Wrap(err)
@@ -241,6 +280,7 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64,
if err != nil {
return nil, trace.Wrap(err)
}
+ startTime := time.Now()
resp, err := li.HTTP.Do(req)
if err != nil {
return nil, trace.Wrap(err)
@@ -264,42 +304,71 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64,
}
// Calculate checksum concurrently with download.
shaReader := sha256.New()
- n, err := io.CopyN(w, io.TeeReader(resp.Body, shaReader), size)
+ tee := io.TeeReader(resp.Body, shaReader)
+ tee = io.TeeReader(tee, &progressLogger{
+ ctx: ctx,
+ log: li.Log,
+ level: slog.LevelInfo,
+ name: path.Base(resp.Request.URL.Path),
+ max: int(resp.ContentLength),
+ lines: 5,
+ })
+ n, err := io.CopyN(w, tee, size)
if err != nil {
return nil, trace.Wrap(err)
}
if resp.ContentLength >= 0 && n != resp.ContentLength {
return nil, trace.Errorf("mismatch in Teleport download size")
}
+ li.Log.InfoContext(ctx, "Download complete.", "duration", time.Since(startTime), "size", n)
return shaReader.Sum(nil), nil
}
-func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64) error {
- if err := os.MkdirAll(dstDir, 0755); err != nil {
+func (li *LocalInstaller) extract(ctx context.Context, dstDir string, src io.Reader, max int64, flags autoupdate.InstallFlags) error {
+ if err := os.MkdirAll(dstDir, systemDirMode); err != nil {
return trace.Wrap(err)
}
free, err := utils.FreeDiskWithReserve(dstDir, li.ReservedFreeInstallDisk)
if err != nil {
- return trace.Errorf("failed to calculate free disk in %q: %w", dstDir, err)
+ return trace.Wrap(err, "failed to calculate free disk in %s", dstDir)
}
// Bail if there's not enough free disk space at the target
if d := int64(free) - max; d < 0 {
- return trace.Errorf("%q needs %d additional bytes of disk space for decompression", dstDir, -d)
+ return trace.Errorf("%s needs %d additional bytes of disk space for decompression", dstDir, -d)
}
zr, err := gzip.NewReader(src)
if err != nil {
- return trace.Errorf("requires gzip-compressed body: %v", err)
+ return trace.Wrap(err, "requires gzip-compressed body")
}
li.Log.InfoContext(ctx, "Extracting Teleport tarball.", "path", dstDir, "size", max)
- // TODO(sclevine): add variadic arg to Extract to extract teleport/ subdir into bin/.
- err = utils.Extract(zr, dstDir)
+ err = utils.Extract(zr, dstDir, tgzExtractPaths(flags&(autoupdate.FlagEnterprise|autoupdate.FlagFIPS) != 0)...)
if err != nil {
return trace.Wrap(err)
}
return nil
}
+// tgzExtractPaths describes how to extract the Teleport tgz.
+// See utils.Extract for more details on how this list is parsed.
+// Paths must use tarball-style / separators (not filepath).
+func tgzExtractPaths(ent bool) []utils.ExtractPath {
+ prefix := "teleport"
+ if ent {
+ prefix += "-ent"
+ }
+ return []utils.ExtractPath{
+ {Src: path.Join(prefix, "examples/systemd/teleport.service"), Dst: filepath.Join(serviceDir, serviceName), DirMode: systemDirMode},
+ {Src: path.Join(prefix, "examples"), Skip: true, DirMode: systemDirMode},
+ {Src: path.Join(prefix, "install"), Skip: true, DirMode: systemDirMode},
+ {Src: path.Join(prefix, "README.md"), Dst: "share/README.md", DirMode: systemDirMode},
+ {Src: path.Join(prefix, "CHANGELOG.md"), Dst: "share/CHANGELOG.md", DirMode: systemDirMode},
+ {Src: path.Join(prefix, "VERSION"), Dst: "share/VERSION", DirMode: systemDirMode},
+ {Src: path.Join(prefix, "LICENSE-community"), Dst: "share/LICENSE-community", DirMode: systemDirMode},
+ {Src: prefix, Dst: "bin", DirMode: systemDirMode},
+ }
+}
+
func uncompressedSize(f io.Reader) (int64, error) {
// NOTE: The gzip length trailer is very unreliable,
// but we could optimize this in the future if
@@ -315,3 +384,500 @@ func uncompressedSize(f io.Reader) (int64, error) {
}
return n, nil
}
+
+// List installed versions of Teleport.
+func (li *LocalInstaller) List(ctx context.Context) (revs []Revision, err error) {
+ entries, err := os.ReadDir(li.InstallDir)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ for _, entry := range entries {
+ if !entry.IsDir() {
+ continue
+ }
+ rev, err := NewRevisionFromDir(entry.Name())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ revs = append(revs, rev)
+ }
+ return revs, nil
+}
+
+// Link the specified version into pathDir and TargetServiceFile.
+// The revert function restores the previous linking.
+// If force is true, Link will overwrite files that are not symlinks.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Link(ctx context.Context, rev Revision, pathDir string, force bool) (revert func(context.Context) bool, err error) {
+ revert = func(context.Context) bool { return true }
+ versionDir, err := li.revisionDir(rev)
+ if err != nil {
+ return revert, trace.Wrap(err)
+ }
+ revert, err = li.forceLinks(ctx,
+ filepath.Join(versionDir, "bin"),
+ filepath.Join(versionDir, serviceDir, serviceName),
+ pathDir,
+ force,
+ )
+ if err != nil {
+ return revert, trace.Wrap(err)
+ }
+ return revert, nil
+}
+
+// LinkSystem links the system (package) version into defaultPathDir and TargetServiceFile.
+// This prevents namespaced installations in /opt/teleport from linking to the system package.
+// The revert function restores the previous linking.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) {
+ revert, err = li.forceLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir, false)
+ return revert, trace.Wrap(err)
+}
+
+// TryLink links the specified version into pathDir, but only in the case that
+// no installation of Teleport is already linked or partially linked.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) TryLink(ctx context.Context, revision Revision, pathDir string) error {
+ versionDir, err := li.revisionDir(revision)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(li.tryLinks(ctx,
+ filepath.Join(versionDir, "bin"),
+ filepath.Join(versionDir, serviceDir, serviceName),
+ pathDir,
+ ))
+}
+
+// TryLinkSystem links the system installation to defaultPathDir, but only in the case that
+// no installation of Teleport is already linked or partially linked.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) TryLinkSystem(ctx context.Context) error {
+ return trace.Wrap(li.tryLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir))
+}
+
+// Unlink unlinks a version from pathDir and TargetServiceFile.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) Unlink(ctx context.Context, rev Revision, pathDir string) error {
+ versionDir, err := li.revisionDir(rev)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(li.removeLinks(ctx,
+ filepath.Join(versionDir, "bin"),
+ filepath.Join(versionDir, serviceDir, serviceName),
+ pathDir,
+ ))
+}
+
+// UnlinkSystem unlinks the system (package) version from defaultPathDir and TargetServiceFile.
+// See Installer interface for additional specs.
+func (li *LocalInstaller) UnlinkSystem(ctx context.Context) error {
+ return trace.Wrap(li.removeLinks(ctx, li.SystemBinDir, li.SystemServiceFile, defaultPathDir))
+}
+
+// symlink from oldname to newname
+type symlink struct {
+ oldname, newname string
+}
+
+// smallFile is a file small enough to be stored in memory.
+type smallFile struct {
+ name string
+ data []byte
+ mode os.FileMode
+}
+
+// forceLinks replaces binary links and service files using files in binDir and svcDir.
+// Existing links and files are replaced, but mismatched links and files will result in error.
+// forceLinks will revert any overridden links or files if it hits an error.
+// If successful, forceLinks may also be reverted after it returns by calling revert.
+// The revert function returns true if reverting succeeds.
+// If force is true, non-link files will be overwritten.
+func (li *LocalInstaller) forceLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string, force bool) (revert func(context.Context) bool, err error) {
+ // setup revert function
+ var (
+ revertLinks []symlink
+ revertFiles []smallFile
+ )
+ revert = func(ctx context.Context) bool {
+ // This function is safe to call repeatedly.
+ // Returns true only when all changes are successfully reverted.
+ var (
+ keepLinks []symlink
+ keepFiles []smallFile
+ )
+ for _, l := range revertLinks {
+ err := renameio.Symlink(l.oldname, l.newname)
+ if err != nil {
+ keepLinks = append(keepLinks, l)
+ li.Log.ErrorContext(ctx, "Failed to revert symlink", "oldname", l.oldname, "newname", l.newname, errorKey, err)
+ }
+ }
+ for _, f := range revertFiles {
+ err := writeFileAtomicWithinDir(f.name, f.data, f.mode)
+ if err != nil {
+ keepFiles = append(keepFiles, f)
+ li.Log.ErrorContext(ctx, "Failed to revert files", "name", f.name, errorKey, err)
+ }
+ }
+ revertLinks = keepLinks
+ revertFiles = keepFiles
+ return len(revertLinks) == 0 && len(revertFiles) == 0
+ }
+ // revert immediately on error, so caller can ignore revert arg
+ defer func() {
+ if err != nil {
+ revert(ctx)
+ }
+ }()
+
+ // ensure source directory exists
+ entries, err := os.ReadDir(srcBinDir)
+ if errors.Is(err, os.ErrNotExist) {
+ return revert, trace.Wrap(ErrNoBinaries)
+ }
+ if err != nil {
+ return revert, trace.Wrap(err, "failed to read Teleport binary directory")
+ }
+
+ // ensure target directories exist before trying to create links
+ err = os.MkdirAll(dstBinDir, systemDirMode)
+ if err != nil {
+ return revert, trace.Wrap(err)
+ }
+ err = os.MkdirAll(filepath.Dir(li.TargetServiceFile), systemDirMode)
+ if err != nil {
+ return revert, trace.Wrap(err)
+ }
+
+ // create binary links
+ var linked int
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ oldname := filepath.Join(srcBinDir, entry.Name())
+ newname := filepath.Join(dstBinDir, entry.Name())
+ exec, err := li.ValidateBinary(ctx, oldname)
+ if err != nil {
+ return revert, trace.Wrap(err)
+ }
+ if !exec {
+ continue
+ }
+ orig, err := forceLink(oldname, newname, force)
+ if err != nil && !errors.Is(err, os.ErrExist) {
+ return revert, trace.Wrap(err, "failed to create symlink for %s", entry.Name())
+ }
+ if orig != "" {
+ revertLinks = append(revertLinks, symlink{
+ oldname: orig,
+ newname: newname,
+ })
+ }
+ linked++
+ }
+ if linked == 0 {
+ return revert, trace.Wrap(ErrNoBinaries)
+ }
+
+ // create systemd service file
+
+ orig, err := li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize)
+ if err != nil && !errors.Is(err, os.ErrExist) {
+ return revert, trace.Wrap(err, "failed to copy service")
+ }
+ if orig != nil {
+ revertFiles = append(revertFiles, *orig)
+ }
+ return revert, nil
+}
+
+// forceCopyService uses forceCopy to copy a systemd service file from src to dst.
+// The contents of both src and dst must be smaller than n.
+// See forceCopy for more details.
+func (li *LocalInstaller) forceCopyService(dst, src string, n int64) (orig *smallFile, err error) {
+ srcData, err := readFileAtMost(src, n)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return forceCopy(dst, li.TransformService(srcData), n)
+}
+
+// forceLink attempts to create a symlink, atomically replacing an existing link if already present.
+// If a non-symlink file or directory exists in newname already and force is false, forceLink errors with ErrFilePresent.
+// If the link is already present with the desired oldname, forceLink returns os.ErrExist.
+func forceLink(oldname, newname string, force bool) (orig string, err error) {
+ orig, err = os.Readlink(newname)
+ if errors.Is(err, os.ErrInvalid) ||
+ errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper
+ if force {
+ return "", trace.Wrap(renameio.Symlink(oldname, newname))
+ }
+ // important: do not attempt to replace a non-linked install of Teleport without force
+ return "", trace.Wrap(ErrFilePresent, "refusing to replace file at %s", newname)
+ } else if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return "", trace.Wrap(err)
+ }
+ if orig == oldname {
+ return "", trace.Wrap(os.ErrExist)
+ }
+ err = renameio.Symlink(oldname, newname)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ return orig, nil
+}
+
+// forceCopy atomically copies a file from srcData to dst, replacing an existing file at dst if needed.
+// The contents of dst must be smaller than n.
+// forceCopy returns the original file path, mode, and contents as orig.
+// If an irregular file, too large file, or directory exists in dst already, forceCopy errors.
+// If the file is already present with the desired contents, forceCopy returns os.ErrExist.
+func forceCopy(dst string, srcData []byte, n int64) (orig *smallFile, err error) {
+ fi, err := os.Lstat(dst)
+ if err != nil && !errors.Is(err, os.ErrNotExist) {
+ return nil, trace.Wrap(err)
+ }
+ if err == nil {
+ orig = &smallFile{
+ name: dst,
+ mode: fi.Mode(),
+ }
+ if !orig.mode.IsRegular() {
+ return nil, trace.Errorf("refusing to replace irregular file at %s", dst)
+ }
+ orig.data, err = readFileAtMost(dst, n)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if bytes.Equal(srcData, orig.data) {
+ return nil, trace.Wrap(os.ErrExist)
+ }
+ }
+ err = writeFileAtomicWithinDir(dst, srcData, configFileMode)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return orig, nil
+}
+
+// writeFileAtomicWithinDir atomically creates a new file with renameio, while ensuring that temporary
+// files use the same directory as the target file (with format: .[base][randints]).
+// This ensures that SELinux contexts for important files are set correctly.
+func writeFileAtomicWithinDir(filename string, data []byte, perm os.FileMode) error {
+ dir := filepath.Dir(filename)
+ err := renameio.WriteFile(filename, data, perm, renameio.WithTempDir(dir))
+ return trace.Wrap(err)
+}
+
+// readFileAtMost reads a file up to n, or errors if it is too large.
+func readFileAtMost(name string, n int64) ([]byte, error) {
+ f, err := os.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ data, err := utils.ReadAtMost(f, n)
+ return data, trace.Wrap(err)
+}
+
+func (li *LocalInstaller) removeLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error {
+ removeService := false
+ entries, err := os.ReadDir(srcBinDir)
+ if err != nil {
+ return trace.Wrap(err, "failed to find Teleport binary directory")
+ }
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ oldname := filepath.Join(srcBinDir, entry.Name())
+ newname := filepath.Join(dstBinDir, entry.Name())
+ v, err := os.Readlink(newname)
+ if errors.Is(err, os.ErrNotExist) ||
+ errors.Is(err, os.ErrInvalid) ||
+ errors.Is(err, syscall.EINVAL) {
+ li.Log.DebugContext(ctx, "Link not present.", "oldname", oldname, "newname", newname)
+ continue
+ }
+ if err != nil {
+ return trace.Wrap(err, "error reading link for %s", filepath.Base(newname))
+ }
+ if v != oldname {
+ li.Log.DebugContext(ctx, "Skipping link to different binary.", "oldname", oldname, "newname", newname)
+ continue
+ }
+ if err := os.Remove(newname); err != nil {
+ li.Log.ErrorContext(ctx, "Unable to remove link.", "oldname", oldname, "newname", newname, errorKey, err)
+ continue
+ }
+ if filepath.Base(newname) == teleport.ComponentTeleport {
+ removeService = true
+ }
+ }
+ // only remove service if teleport was removed
+ if !removeService {
+ li.Log.DebugContext(ctx, "Teleport binary not unlinked. Skipping removal of teleport.service.")
+ return nil
+ }
+ srcBytes, err := readFileAtMost(srcSvcFile, maxServiceFileSize)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ dstBytes, err := readFileAtMost(li.TargetServiceFile, maxServiceFileSize)
+ if errors.Is(err, os.ErrNotExist) {
+ li.Log.DebugContext(ctx, "Service not present.", "path", li.TargetServiceFile)
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if !bytes.Equal(li.TransformService(srcBytes), dstBytes) {
+ li.Log.WarnContext(ctx, "Removed teleport binary link, but skipping removal of custom teleport.service: the service file does not match the reference file for this version. The file might have been manually edited.")
+ return nil
+ }
+ if err := os.Remove(li.TargetServiceFile); err != nil {
+ return trace.Wrap(err, "error removing copy of %s", filepath.Base(li.TargetServiceFile))
+ }
+ return nil
+}
+
+// tryLinks create binary and service links for files in binDir and svcDir if links are not already present.
+// Existing links that point to files outside binDir or svcDir, as well as existing non-link files, will error.
+// tryLinks will not attempt to create any links if linking could result in an error.
+// However, concurrent changes to links may result in an error with partially-complete linking.
+func (li *LocalInstaller) tryLinks(ctx context.Context, srcBinDir, srcSvcFile, dstBinDir string) error {
+ // ensure source directory exists
+ entries, err := os.ReadDir(srcBinDir)
+ if errors.Is(err, os.ErrNotExist) {
+ return trace.Wrap(ErrNoBinaries)
+ }
+ if err != nil {
+ return trace.Wrap(err, "failed to read Teleport binary directory")
+ }
+
+ // ensure target directories exist before trying to create links
+ err = os.MkdirAll(dstBinDir, systemDirMode)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ err = os.MkdirAll(filepath.Dir(li.TargetServiceFile), systemDirMode)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // validate that we can link all system binaries before attempting linking
+ var links []symlink
+ var linked int
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ oldname := filepath.Join(srcBinDir, entry.Name())
+ newname := filepath.Join(dstBinDir, entry.Name())
+ exec, err := li.ValidateBinary(ctx, oldname)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if !exec {
+ continue
+ }
+ ok, err := needsLink(oldname, newname)
+ if err != nil {
+ return trace.Wrap(err, "error evaluating link for %s", filepath.Base(oldname))
+ }
+ if ok {
+ links = append(links, symlink{oldname, newname})
+ }
+ linked++
+ }
+ // bail if no binaries can be linked
+ if linked == 0 {
+ return trace.Wrap(ErrNoBinaries)
+ }
+
+ // link binaries that are missing links
+ for _, link := range links {
+ if err := os.Symlink(link.oldname, link.newname); err != nil {
+ return trace.Wrap(err, "failed to create symlink for %s", filepath.Base(link.oldname))
+ }
+ }
+
+ // if any binaries are linked from srcBinDir, always link the service from svcDir
+ _, err = li.forceCopyService(li.TargetServiceFile, srcSvcFile, maxServiceFileSize)
+ if err != nil && !errors.Is(err, os.ErrExist) {
+ return trace.Wrap(err, "failed to copy service")
+ }
+
+ return nil
+}
+
+// needsLink returns true when a symlink from oldname to newname needs to be created, or false if it exists.
+// If a non-symlink file or directory exists at newname, needsLink errors with ErrFilePresent.
+// If a symlink to a different location exists, needsLink errors with ErrLinked.
+func needsLink(oldname, newname string) (ok bool, err error) {
+ orig, err := os.Readlink(newname)
+ if errors.Is(err, os.ErrInvalid) ||
+ errors.Is(err, syscall.EINVAL) { // workaround missing ErrInvalid wrapper
+ // important: do not attempt to replace a non-linked install of Teleport
+ return false, trace.Wrap(ErrFilePresent, "refusing to replace file at %s", newname)
+ }
+ if errors.Is(err, os.ErrNotExist) {
+ return true, nil
+ }
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if orig != oldname {
+ return false, trace.Wrap(ErrLinked, "refusing to replace link at %s", newname)
+ }
+ return false, nil
+}
+
+// revisionDir returns the storage directory for a Teleport revision.
+// revisionDir will fail if the revision cannot be used to construct the directory name.
+// For example, it ensures that ".." cannot be provided to return a system directory.
+func (li *LocalInstaller) revisionDir(rev Revision) (string, error) {
+ installDir, err := filepath.Abs(li.InstallDir)
+ if err != nil {
+ return "", trace.Wrap(err)
+ }
+ versionDir := filepath.Join(installDir, rev.Dir())
+ if filepath.Dir(versionDir) != filepath.Clean(installDir) {
+ return "", trace.Errorf("refusing to link directory outside of version directory")
+ }
+ return versionDir, nil
+}
+
+// IsLinked returns true if any binaries for Revision rev are linked to pathDir.
+// Returns os.ErrNotExist error if the revision does not exist.
+func (li *LocalInstaller) IsLinked(ctx context.Context, rev Revision, pathDir string) (bool, error) {
+ versionDir, err := li.revisionDir(rev)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ binDir := filepath.Join(versionDir, "bin")
+ entries, err := os.ReadDir(binDir)
+ if errors.Is(err, os.ErrNotExist) {
+ return false, nil
+ }
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ v, err := os.Readlink(filepath.Join(pathDir, entry.Name()))
+ if err != nil {
+ continue
+ }
+ if filepath.Clean(v) == filepath.Join(binDir, entry.Name()) {
+ return true, nil
+ }
+ }
+ return false, nil
+}
diff --git a/lib/autoupdate/agent/installer_test.go b/lib/autoupdate/agent/installer_test.go
index be778f7bcf16a..7fbee63c8aa73 100644
--- a/lib/autoupdate/agent/installer_test.go
+++ b/lib/autoupdate/agent/installer_test.go
@@ -32,15 +32,18 @@ import (
"os"
"path/filepath"
"runtime"
+ "slices"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/lib/autoupdate"
)
-func TestTeleportInstaller_Install(t *testing.T) {
+func TestLocalInstaller_Install(t *testing.T) {
t.Parallel()
const version = "new-version"
@@ -51,7 +54,8 @@ func TestTeleportInstaller_Install(t *testing.T) {
reservedTmp uint64
reservedInstall uint64
existingSum string
- flags InstallFlags
+ flags autoupdate.InstallFlags
+ force bool
errMatch string
}{
@@ -65,10 +69,18 @@ func TestTeleportInstaller_Install(t *testing.T) {
{
name: "mismatched checksum",
existingSum: hex.EncodeToString(sha256.New().Sum(nil)),
+ force: true,
},
{
name: "unreadable checksum",
existingSum: "bad",
+ force: true,
+ },
+ {
+ name: "unreadable checksum, force false",
+ existingSum: "bad",
+ force: false,
+ errMatch: "refusing",
},
{
name: "out of space in /tmp",
@@ -84,12 +96,13 @@ func TestTeleportInstaller_Install(t *testing.T) {
}
for _, tt := range tests {
- tt := tt
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
+ err := os.MkdirAll(filepath.Join(dir, version), os.ModePerm)
+ require.NoError(t, err)
if tt.existingSum != "" {
- err := os.WriteFile(filepath.Join(dir, checksumType), []byte(tt.existingSum), os.ModePerm)
+ err := os.WriteFile(filepath.Join(dir, version, checksumType), []byte(tt.existingSum), os.ModePerm)
require.NoError(t, err)
}
@@ -122,9 +135,10 @@ func TestTeleportInstaller_Install(t *testing.T) {
Log: slog.Default(),
ReservedFreeTmpDisk: tt.reservedTmp,
ReservedFreeInstallDisk: tt.reservedInstall,
+ Template: "{{.BaseURL}}/{{.Package}}-{{.OS}}/{{.Arch}}/{{.Version}}",
}
ctx := context.Background()
- err := installer.Install(ctx, version, server.URL+"/{{.OS}}/{{.Arch}}/{{.Version}}", tt.flags)
+ err = installer.Install(ctx, NewRevision(version, tt.flags), server.URL, tt.force)
if tt.errMatch != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMatch)
@@ -132,17 +146,24 @@ func TestTeleportInstaller_Install(t *testing.T) {
}
require.NoError(t, err)
- const expectedPath = "/" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version
- require.Equal(t, expectedPath, dlPath)
+ const expectedPath = "/teleport-" + runtime.GOOS + "/" + runtime.GOARCH + "/" + version
require.Equal(t, expectedPath+"."+checksumType, shaPath)
- teleportVersion, err := os.ReadFile(filepath.Join(dir, version, "teleport"))
- require.NoError(t, err)
- require.Equal(t, version, string(teleportVersion))
+ if tt.existingSum == testSum {
+ return
+ }
- tshVersion, err := os.ReadFile(filepath.Join(dir, version, "tsh"))
- require.NoError(t, err)
- require.Equal(t, version, string(tshVersion))
+ require.Equal(t, expectedPath, dlPath)
+
+ for _, p := range []string{
+ filepath.Join(dir, version, "lib", "systemd", "system", "teleport.service"),
+ filepath.Join(dir, version, "bin", "teleport"),
+ filepath.Join(dir, version, "bin", "tsh"),
+ } {
+ v, err := os.ReadFile(p)
+ require.NoError(t, err)
+ require.Equal(t, version, string(v))
+ }
sum, err := os.ReadFile(filepath.Join(dir, version, checksumType))
require.NoError(t, err)
@@ -163,8 +184,9 @@ func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) {
var files = []struct {
Name, Body string
}{
- {"teleport", version},
- {"tsh", version},
+ {"teleport/examples/systemd/teleport.service", version},
+ {"teleport/teleport", version},
+ {"teleport/tsh", version},
}
for _, file := range files {
hdr := &tar.Header{
@@ -187,3 +209,955 @@ func testTGZ(t *testing.T, version string) (tgz *bytes.Buffer, shasum string) {
}
return &buf, hex.EncodeToString(sha.Sum(nil))
}
+
+func TestLocalInstaller_Link(t *testing.T) {
+ t.Parallel()
+ const version = "new-version"
+ servicePath := filepath.Join(serviceDir, serviceName)
+
+ tests := []struct {
+ name string
+ installDirs []string
+ installFiles []string
+ installFileMode os.FileMode
+ existingLinks []string
+ existingFiles []string
+ force bool
+
+ resultLinks []string
+ resultServices []string
+ errMatch string
+ }{
+ {
+ name: "present with new links",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+
+ resultLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ resultServices: []string{
+ "lib/systemd/system/teleport.service",
+ },
+ },
+ {
+ name: "present with non-executable files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: 0644,
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ {
+ name: "present with existing links",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ existingFiles: []string{
+ "lib/systemd/system/teleport.service",
+ },
+
+ resultLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ resultServices: []string{
+ "lib/systemd/system/teleport.service",
+ },
+ },
+ {
+ name: "conflicting systemd files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ "lib/systemd/system/teleport.service",
+ },
+
+ errMatch: "refusing",
+ },
+ {
+ name: "conflicting bin files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tbot",
+ },
+ existingFiles: []string{
+ "lib/systemd/system/teleport.service",
+ "bin/tsh",
+ },
+
+ errMatch: ErrFilePresent.Error(),
+ },
+ {
+ name: "overwriting bin files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tbot",
+ },
+ existingFiles: []string{
+ "lib/systemd/system/teleport.service",
+ "bin/tsh",
+ },
+ force: true,
+
+ resultLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ resultServices: []string{
+ "lib/systemd/system/teleport.service",
+ },
+ },
+ {
+ name: "no links",
+ installFiles: []string{"README"},
+ installDirs: []string{"bin"},
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ {
+ name: "no bin directory",
+ installFiles: []string{"README"},
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ versionsDir := t.TempDir()
+ versionDir := filepath.Join(versionsDir, version)
+ err := os.MkdirAll(versionDir, 0o755)
+ require.NoError(t, err)
+
+ // setup files in version directory
+ for _, d := range tt.installDirs {
+ err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm)
+ require.NoError(t, err)
+ }
+ for _, n := range tt.installFiles {
+ err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode)
+ require.NoError(t, err)
+ }
+
+ // setup files in system links directory
+ linkDir := t.TempDir()
+ for _, n := range tt.existingLinks {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ }
+ for _, n := range tt.existingFiles {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
+
+ validator := Validator{Log: slog.Default()}
+ installer := &LocalInstaller{
+ InstallDir: versionsDir,
+ TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
+ Log: slog.Default(),
+ TransformService: func(b []byte) []byte {
+ return []byte("[transform]" + string(b))
+ },
+ ValidateBinary: validator.IsExecutable,
+ Template: autoupdate.DefaultCDNURITemplate,
+ }
+ ctx := context.Background()
+ revert, err := installer.Link(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"), tt.force)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+
+ // verify automatic revert
+ for _, link := range tt.existingLinks {
+ v, err := os.Readlink(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link)+".old", v)
+ }
+ for _, n := range tt.existingFiles {
+ v, err := os.ReadFile(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(n), string(v))
+ }
+
+ // ensure revert still succeeds
+ ok := revert(ctx)
+ require.True(t, ok)
+ return
+ }
+ require.NoError(t, err)
+
+ // verify links
+ for _, link := range tt.resultLinks {
+ v, err := os.ReadFile(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link), string(v))
+ }
+ for _, svc := range tt.resultServices {
+ v, err := os.ReadFile(filepath.Join(linkDir, svc))
+ require.NoError(t, err)
+ require.Equal(t, "[transform]"+filepath.Base(svc), string(v))
+ }
+
+ // verify manual revert
+ ok := revert(ctx)
+ require.True(t, ok)
+ for _, link := range tt.existingLinks {
+ v, err := os.Readlink(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link)+".old", v)
+ }
+ for _, n := range tt.existingFiles {
+ v, err := os.ReadFile(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(n), string(v))
+ }
+ })
+ }
+}
+
+func TestLocalInstaller_TryLink(t *testing.T) {
+ t.Parallel()
+ const version = "new-version"
+ servicePath := filepath.Join(serviceDir, serviceName)
+
+ tests := []struct {
+ name string
+ installDirs []string
+ installFiles []string
+ installFileMode os.FileMode
+ existingLinks []string
+ existingFiles []string
+
+ resultLinks []string
+ resultServices []string
+ errMatch string
+ }{
+ {
+ name: "present with new links",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+
+ resultLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ resultServices: []string{
+ "lib/systemd/system/teleport.service",
+ },
+ },
+ {
+ name: "present with non-executable files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: 0644,
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ {
+ name: "present with existing links",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ },
+ existingFiles: []string{
+ "lib/systemd/system/teleport.service",
+ },
+
+ errMatch: "refusing",
+ },
+ {
+ name: "conflicting systemd files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingLinks: []string{
+ "lib/systemd/system/teleport.service",
+ },
+
+ errMatch: "replace irregular file",
+ },
+ {
+ name: "conflicting bin files",
+ installDirs: []string{
+ "bin",
+ "bin/somedir",
+ "lib",
+ "lib/systemd",
+ "lib/systemd/system",
+ "somedir",
+ },
+ installFiles: []string{
+ "bin/teleport",
+ "bin/tsh",
+ "bin/tbot",
+ servicePath,
+ "README",
+ },
+ installFileMode: os.ModePerm,
+ existingFiles: []string{
+ "bin/tsh",
+ },
+
+ errMatch: ErrFilePresent.Error(),
+ },
+ {
+ name: "no links",
+ installFiles: []string{"README"},
+ installDirs: []string{"bin"},
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ {
+ name: "no bin directory",
+ installFiles: []string{"README"},
+
+ errMatch: ErrNoBinaries.Error(),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ versionsDir := t.TempDir()
+ versionDir := filepath.Join(versionsDir, version)
+ err := os.MkdirAll(versionDir, 0o755)
+ require.NoError(t, err)
+
+ // setup files in version directory
+ for _, d := range tt.installDirs {
+ err := os.Mkdir(filepath.Join(versionDir, d), os.ModePerm)
+ require.NoError(t, err)
+ }
+ for _, n := range tt.installFiles {
+ err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), tt.installFileMode)
+ require.NoError(t, err)
+ }
+
+ // setup files in system links directory
+ linkDir := t.TempDir()
+ for _, n := range tt.existingLinks {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.Symlink(filepath.Base(n)+".old", filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ }
+ for _, n := range tt.existingFiles {
+ err := os.MkdirAll(filepath.Dir(filepath.Join(linkDir, n)), os.ModePerm)
+ require.NoError(t, err)
+ err = os.WriteFile(filepath.Join(linkDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
+
+ validator := Validator{Log: slog.Default()}
+ installer := &LocalInstaller{
+ InstallDir: versionsDir,
+ TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
+ Log: slog.Default(),
+ TransformService: func(b []byte) []byte {
+ return []byte("[transform]" + string(b))
+ },
+ ValidateBinary: validator.IsExecutable,
+ }
+ ctx := context.Background()
+ err = installer.TryLink(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"))
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+
+ // verify no changes
+ for _, link := range tt.existingLinks {
+ v, err := os.Readlink(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link)+".old", v)
+ }
+ for _, n := range tt.existingFiles {
+ v, err := os.ReadFile(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(n), string(v))
+ }
+ return
+ }
+ require.NoError(t, err)
+
+ // verify links
+ for _, link := range tt.resultLinks {
+ v, err := os.ReadFile(filepath.Join(linkDir, link))
+ require.NoError(t, err)
+ require.Equal(t, filepath.Base(link), string(v))
+ }
+ for _, svc := range tt.resultServices {
+ v, err := os.ReadFile(filepath.Join(linkDir, svc))
+ require.NoError(t, err)
+ require.Equal(t, "[transform]"+filepath.Base(svc), string(v))
+ }
+
+ })
+ }
+}
+
+func TestLocalInstaller_Remove(t *testing.T) {
+ t.Parallel()
+ const version = "existing-version"
+
+ tests := []struct {
+ name string
+ dirs []string
+ files []string
+ createVersion string
+ removeVersion string
+
+ errMatch string
+ }{
+ {
+ name: "present",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: version,
+ },
+ {
+ name: "present missing checksum",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{"bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: version,
+ },
+ {
+ name: "not present",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: "missing-version",
+ },
+ {
+ name: "version empty",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: "",
+
+ errMatch: "outside",
+ },
+ {
+ name: "version has path",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: "one/two",
+
+ errMatch: "outside",
+ },
+ {
+ name: "version is ..",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: "..",
+
+ errMatch: "outside",
+ },
+ {
+ name: "version is .",
+ dirs: []string{"bin/somedir", "somedir"},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README"},
+ createVersion: version,
+ removeVersion: ".",
+
+ errMatch: "outside",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ versionsDir := t.TempDir()
+ versionDir := filepath.Join(versionsDir, tt.createVersion)
+ err := os.MkdirAll(versionDir, 0o755)
+ require.NoError(t, err)
+
+ for _, d := range tt.dirs {
+ err := os.MkdirAll(filepath.Join(versionDir, d), os.ModePerm)
+ require.NoError(t, err)
+ }
+ for _, n := range tt.files {
+ err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
+
+ linkDir := t.TempDir()
+
+ validator := Validator{Log: slog.Default()}
+ installer := &LocalInstaller{
+ InstallDir: versionsDir,
+ TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
+ Log: slog.Default(),
+ TransformService: func(b []byte) []byte {
+ return []byte("[transform]" + string(b))
+ },
+ ValidateBinary: validator.IsExecutable,
+ }
+ ctx := context.Background()
+ err = installer.Remove(ctx, NewRevision(tt.removeVersion, 0))
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+ _, err = os.Stat(filepath.Join(versionDir, "bin", tt.removeVersion))
+ require.ErrorIs(t, err, os.ErrNotExist)
+ })
+ }
+}
+
+func TestLocalInstaller_IsLinked(t *testing.T) {
+ t.Parallel()
+ const version = "existing-version"
+ servicePath := filepath.Join(serviceDir, serviceName)
+
+ tests := []struct {
+ name string
+ dirs []string
+ files []string
+ createVersion string
+ linkVersion string
+ checkVersion string
+
+ result bool
+ errMatch string
+ }{
+ {
+ name: "linked",
+ dirs: []string{"bin/somedir", "somedir", serviceDir},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README", servicePath},
+ createVersion: version,
+ linkVersion: version,
+ checkVersion: version,
+ result: true,
+ },
+ {
+ name: "other linked",
+ dirs: []string{"bin/somedir", "somedir", serviceDir},
+ files: []string{checksumType, "bin/teleport", "bin/tsh", "bin/tbot", "README", servicePath},
+ createVersion: version,
+ linkVersion: version,
+ checkVersion: "other",
+ result: false,
+ },
+ {
+ name: "not linked",
+ checkVersion: version,
+ result: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ versionsDir := t.TempDir()
+ linkDir := t.TempDir()
+
+ validator := Validator{Log: slog.Default()}
+ installer := &LocalInstaller{
+ InstallDir: versionsDir,
+ TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
+ Log: slog.Default(),
+ TransformService: func(b []byte) []byte {
+ return []byte("[transform]" + string(b))
+ },
+ ValidateBinary: validator.IsExecutable,
+ }
+ ctx := context.Background()
+ if tt.createVersion != "" {
+ versionDir := filepath.Join(versionsDir, tt.createVersion)
+ err := os.MkdirAll(versionDir, 0o755)
+ require.NoError(t, err)
+
+ for _, d := range tt.dirs {
+ err := os.MkdirAll(filepath.Join(versionDir, d), os.ModePerm)
+ require.NoError(t, err)
+ }
+ for _, n := range tt.files {
+ err := os.WriteFile(filepath.Join(versionDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
+ }
+ if tt.linkVersion != "" {
+ _, err := installer.Link(ctx, NewRevision(tt.linkVersion, 0), linkDir, false)
+ require.NoError(t, err)
+ }
+ result, err := installer.IsLinked(ctx, NewRevision(tt.checkVersion, 0), linkDir)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+ require.Equal(t, tt.result, result)
+ })
+ }
+}
+
+func TestLocalInstaller_Unlink(t *testing.T) {
+ t.Parallel()
+ const version = "existing-version"
+ servicePath := filepath.Join(serviceDir, serviceName)
+
+ tests := []struct {
+ name string
+ bins []string
+ svcOrig []byte
+
+ links []symlink
+ svcCopy []byte
+
+ remaining []string
+ errMatch string
+ }{
+ {
+ name: "normal",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("[transform]orig"),
+ },
+ {
+ name: "different services",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("custom"),
+ remaining: []string{servicePath},
+ },
+ {
+ name: "missing target service",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ },
+ {
+ name: "missing source service",
+ bins: []string{"teleport", "tsh"},
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("custom"),
+ remaining: []string{servicePath},
+ errMatch: "no such",
+ },
+ {
+ name: "missing teleport link",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("[transform]orig"),
+ remaining: []string{servicePath},
+ },
+ {
+ name: "missing other link",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ },
+ svcCopy: []byte("[transform]orig"),
+ },
+ {
+ name: "wrong teleport link",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "other", newname: "bin/teleport"},
+ {oldname: "bin/tsh", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("[transform]orig"),
+ remaining: []string{servicePath, "bin/teleport"},
+ },
+ {
+ name: "wrong other link",
+ bins: []string{"teleport", "tsh"},
+ svcOrig: []byte("orig"),
+ links: []symlink{
+ {oldname: "bin/teleport", newname: "bin/teleport"},
+ {oldname: "wrong", newname: "bin/tsh"},
+ },
+ svcCopy: []byte("[transform]orig"),
+ remaining: []string{"bin/tsh"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ versionsDir := t.TempDir()
+ versionDir := filepath.Join(versionsDir, version)
+ err := os.MkdirAll(versionDir, 0o755)
+ require.NoError(t, err)
+ linkDir := t.TempDir()
+
+ var files []smallFile
+ for _, n := range tt.bins {
+ files = append(files, smallFile{
+ name: filepath.Join(versionDir, "bin", n),
+ data: []byte("binary"),
+ mode: os.ModePerm,
+ })
+ }
+ if tt.svcOrig != nil {
+ files = append(files, smallFile{
+ name: filepath.Join(versionDir, servicePath),
+ data: tt.svcOrig,
+ mode: os.ModePerm,
+ })
+ }
+ if tt.svcCopy != nil {
+ files = append(files, smallFile{
+ name: filepath.Join(linkDir, servicePath),
+ data: tt.svcCopy,
+ mode: os.ModePerm,
+ })
+ }
+
+ for _, n := range files {
+ err = os.MkdirAll(filepath.Dir(n.name), os.ModePerm)
+ require.NoError(t, err)
+ err = os.WriteFile(n.name, n.data, n.mode)
+ require.NoError(t, err)
+ }
+ for _, n := range tt.links {
+ newname := filepath.Join(linkDir, n.newname)
+ oldname := filepath.Join(versionDir, n.oldname)
+ err = os.MkdirAll(filepath.Dir(newname), os.ModePerm)
+ require.NoError(t, err)
+ err = os.Symlink(oldname, newname)
+ require.NoError(t, err)
+ }
+
+ installer := &LocalInstaller{
+ InstallDir: versionsDir,
+ TargetServiceFile: filepath.Join(linkDir, serviceDir, serviceName),
+ Log: slog.Default(),
+ TransformService: func(b []byte) []byte {
+ return []byte("[transform]" + string(b))
+ },
+ }
+ ctx := context.Background()
+ err = installer.Unlink(ctx, NewRevision(version, 0), filepath.Join(linkDir, "bin"))
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
+ }
+ for _, n := range tt.remaining {
+ _, err = os.Lstat(filepath.Join(linkDir, n))
+ require.NoError(t, err)
+ }
+ for _, n := range tt.links {
+ if slices.Contains(tt.remaining, n.newname) {
+ continue
+ }
+ _, err = os.Lstat(filepath.Join(linkDir, n.newname))
+ require.ErrorIs(t, err, os.ErrNotExist)
+ }
+ if !slices.Contains(tt.remaining, servicePath) {
+ _, err = os.Lstat(filepath.Join(linkDir, servicePath))
+ require.ErrorIs(t, err, os.ErrNotExist)
+ }
+ })
+ }
+}
+
+func TestLocalInstaller_List(t *testing.T) {
+ installDir := t.TempDir()
+ versions := []string{"v1", "v2"}
+
+ for _, d := range versions {
+ err := os.Mkdir(filepath.Join(installDir, d), os.ModePerm)
+ require.NoError(t, err)
+ }
+ for _, n := range []string{"file1", "file2"} {
+ err := os.WriteFile(filepath.Join(installDir, n), []byte(filepath.Base(n)), os.ModePerm)
+ require.NoError(t, err)
+ }
+ installer := &LocalInstaller{
+ InstallDir: installDir,
+ Log: slog.Default(),
+ }
+ ctx := context.Background()
+ revisions, err := installer.List(ctx)
+ require.NoError(t, err)
+ require.Equal(t, []Revision{
+ NewRevision("v1", 0),
+ NewRevision("v2", 0),
+ }, revisions)
+}
diff --git a/lib/autoupdate/agent/logger.go b/lib/autoupdate/agent/logger.go
new file mode 100644
index 0000000000000..bd50ee50859a4
--- /dev/null
+++ b/lib/autoupdate/agent/logger.go
@@ -0,0 +1,108 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "log/slog"
+
+ "github.com/gravitational/trace"
+)
+
+// progressLogger logs progress of any data written as it approaches max.
+// max(lines, call_count(Write)) lines are written for each multiple of max.
+// progressLogger uses the variability of chunk size as a proxy for speed, and avoids
+// logging extraneous lines that do not improve UX for waiting humans.
+type progressLogger struct {
+ ctx context.Context
+ log *slog.Logger
+ level slog.Level
+ name string
+ max int
+ lines int
+
+ l int
+ n int
+}
+
+func (w *progressLogger) Write(p []byte) (n int, err error) {
+ w.n += len(p)
+ if w.n >= w.max*(w.l+1)/w.lines {
+ w.log.Log(w.ctx, w.level, "Downloading",
+ "file", w.name,
+ "progress", fmt.Sprintf("%d%%", w.n*100/w.max),
+ )
+ w.l++
+ }
+ return len(p), nil
+}
+
+// lineLogger logs each line written to it.
+type lineLogger struct {
+ ctx context.Context
+ log *slog.Logger
+ level slog.Level
+ prefix string
+
+ last bytes.Buffer
+}
+
+func (w *lineLogger) out(s string) {
+ w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant
+}
+
+func (w *lineLogger) Write(p []byte) (n int, err error) {
+ lines := bytes.Split(p, []byte("\n"))
+ // Finish writing line
+ if len(lines) > 0 {
+ n, err = w.last.Write(lines[0])
+ lines = lines[1:]
+ }
+ // Quit if no newline
+ if len(lines) == 0 || err != nil {
+ return n, trace.Wrap(err)
+ }
+
+ // Newline found, log line
+ w.out(w.last.String())
+ n += 1
+ w.last.Reset()
+
+ // Log lines that are already newline-terminated
+ for _, line := range lines[:len(lines)-1] {
+ w.out(string(line))
+ n += len(line) + 1
+ }
+
+ // Store remaining line non-newline-terminated line.
+ n2, err := w.last.Write(lines[len(lines)-1])
+ n += n2
+ return n, trace.Wrap(err)
+}
+
+// Flush logs any trailing bytes that were never terminated with a newline.
+func (w *lineLogger) Flush() {
+ if w.last.Len() == 0 {
+ return
+ }
+ w.out(w.last.String())
+ w.last.Reset()
+}
diff --git a/lib/autoupdate/agent/logger_test.go b/lib/autoupdate/agent/logger_test.go
new file mode 100644
index 0000000000000..2a8430ef8cf44
--- /dev/null
+++ b/lib/autoupdate/agent/logger_test.go
@@ -0,0 +1,186 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestLineLogger(t *testing.T) {
+ t.Parallel()
+
+ out := &bytes.Buffer{}
+ ll := lineLogger{
+ ctx: context.Background(),
+ log: slog.New(slog.NewTextHandler(out,
+ &slog.HandlerOptions{ReplaceAttr: msgOnly},
+ )),
+ }
+
+ for _, e := range []struct {
+ v string
+ n int
+ }{
+ {v: "", n: 0},
+ {v: "a", n: 1},
+ {v: "b\n", n: 2},
+ {v: "c\nd", n: 3},
+ {v: "e\nf\ng", n: 5},
+ {v: "h", n: 1},
+ {v: "", n: 0},
+ {v: "\n", n: 1},
+ {v: "i\n", n: 2},
+ {v: "j", n: 1},
+ } {
+ n, err := ll.Write([]byte(e.v))
+ require.NoError(t, err)
+ require.Equal(t, e.n, n)
+ }
+ require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String())
+ ll.Flush()
+ require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String())
+}
+
+func msgOnly(_ []string, a slog.Attr) slog.Attr {
+ switch a.Key {
+ case "time", "level":
+ return slog.Attr{}
+ }
+ return slog.Attr{Key: a.Key, Value: a.Value}
+}
+
+func TestProgressLogger(t *testing.T) {
+ t.Parallel()
+
+ type write struct {
+ n int
+ out string
+ }
+ for _, tt := range []struct {
+ name string
+ max, lines int
+ writes []write
+ }{
+ {
+ name: "even",
+ max: 100,
+ lines: 5,
+ writes: []write{
+ {n: 10},
+ {n: 10, out: "20%"},
+ {n: 10},
+ {n: 10, out: "40%"},
+ {n: 10},
+ {n: 10, out: "60%"},
+ {n: 10},
+ {n: 10, out: "80%"},
+ {n: 10},
+ {n: 10, out: "100%"},
+ {n: 10},
+ {n: 10, out: "120%"},
+ },
+ },
+ {
+ name: "fast",
+ max: 100,
+ lines: 5,
+ writes: []write{
+ {n: 100, out: "100%"},
+ {n: 100, out: "200%"},
+ },
+ },
+ {
+ name: "over fast",
+ max: 100,
+ lines: 5,
+ writes: []write{
+ {n: 200, out: "200%"},
+ },
+ },
+ {
+ name: "slow down when uneven",
+ max: 100,
+ lines: 5,
+ writes: []write{
+ {n: 50, out: "50%"},
+ {n: 10, out: "60%"},
+ {n: 10, out: "70%"},
+ {n: 10, out: "80%"},
+ {n: 10},
+ {n: 10, out: "100%"},
+ {n: 10},
+ {n: 10, out: "120%"},
+ },
+ },
+ {
+ name: "slow down when very uneven",
+ max: 100,
+ lines: 5,
+ writes: []write{
+ {n: 50, out: "50%"},
+ {n: 1, out: "51%"},
+ {n: 1},
+ {n: 20, out: "72%"},
+ {n: 10, out: "82%"},
+ {n: 10},
+ {n: 10, out: "102%"},
+ },
+ },
+ {
+ name: "close",
+ max: 1000,
+ lines: 5,
+ writes: []write{
+ {n: 999, out: "99%"},
+ {n: 1, out: "100%"},
+ },
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ out := &bytes.Buffer{}
+ ll := progressLogger{
+ ctx: context.Background(),
+ log: slog.New(slog.NewTextHandler(out,
+ &slog.HandlerOptions{ReplaceAttr: msgOnly},
+ )),
+ name: "test",
+ max: tt.max,
+ lines: tt.lines,
+ }
+ for _, e := range tt.writes {
+ n, err := ll.Write(make([]byte, e.n))
+ require.NoError(t, err)
+ require.Equal(t, e.n, n)
+ v, err := io.ReadAll(out)
+ require.NoError(t, err)
+ if len(v) > 0 {
+ e.out = fmt.Sprintf(`msg=Downloading file=test progress=%s`+"\n", e.out)
+ }
+ require.Equal(t, e.out, string(v))
+ }
+ })
+ }
+}
diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go
new file mode 100644
index 0000000000000..a0b937c6b0b6e
--- /dev/null
+++ b/lib/autoupdate/agent/process.go
@@ -0,0 +1,516 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "log/slog"
+ "net"
+ "os"
+ "os/exec"
+ "strconv"
+ "syscall"
+ "time"
+
+ "github.com/gravitational/trace"
+ "golang.org/x/sync/errgroup"
+
+ "github.com/gravitational/teleport/lib/client/debug"
+)
+
+// process monitoring consts
+const (
+ // monitorTimeout is the timeout for determining whether the process has started.
+ monitorTimeout = 1 * time.Minute
+ // monitorInterval is the polling interval for determining whether the process has started.
+ monitorInterval = 2 * time.Second
+ // minRunningIntervalsBeforeStable is the number of consecutive intervals with the same running PID detected
+ // before the service is determined stable.
+ minRunningIntervalsBeforeStable = 6
+ // maxCrashesBeforeFailure is the number of total crashes detected before the service is marked as crash-looping.
+ maxCrashesBeforeFailure = 2
+)
+
+// log keys
+const (
+ unitKey = "unit"
+)
+
+// SystemdService manages a systemd service (e.g., teleport or teleport-update).
+type SystemdService struct {
+ // ServiceName specifies the systemd service name.
+ ServiceName string
+ // PIDFile is a path to a file containing the service's PID.
+ PIDFile string
+ // Ready is a readiness checker.
+ Ready ReadyChecker
+ // Log contains a logger.
+ Log *slog.Logger
+}
+
+// ReadyChecker returns the systemd service readiness status.
+type ReadyChecker interface {
+ GetReadiness(ctx context.Context) (debug.Readiness, error)
+}
+
+// Reload the systemd service.
+// Attempts a graceful reload before a hard restart.
+// See Process interface for more details.
+func (s SystemdService) Reload(ctx context.Context) error {
+ // TODO(sclevine): allow server to force restart instead of reload
+
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Command error codes < 0 indicate that we are unable to run the command.
+ // Errors from s.systemctl are logged along with stderr and stdout (debug only).
+
+ // If the service is not running, return ErrNotNeeded.
+ // Note systemctl reload returns an error if the unit is not active, and
+ // try-reload-or-restart is too recent of an addition for centos7.
+ code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName)
+ switch {
+ case code < 0:
+ return trace.Errorf("unable to determine if systemd service is active")
+ case code > 0:
+ s.Log.WarnContext(ctx, "Systemd service not running.", unitKey, s.ServiceName)
+ return trace.Wrap(ErrNotNeeded)
+ }
+
+ // Get initial PID for crash monitoring.
+
+ initPID, err := readInt(s.PIDFile)
+ if errors.Is(err, os.ErrNotExist) {
+ s.Log.InfoContext(ctx, "No existing process detected. Skipping crash monitoring.", unitKey, s.ServiceName)
+ } else if err != nil {
+ s.Log.ErrorContext(ctx, "Error reading initial PID value. Skipping crash monitoring.", unitKey, s.ServiceName, errorKey, err)
+ }
+
+ // Attempt graceful reload of running service.
+ code = s.systemctl(ctx, slog.LevelError, "reload", s.ServiceName)
+ switch {
+ case code < 0:
+ return trace.Errorf("unable to reload systemd service")
+ case code > 0:
+ // Graceful reload fails, try hard restart.
+ code = s.systemctl(ctx, slog.LevelError, "try-restart", s.ServiceName)
+ if code != 0 {
+ return trace.Errorf("hard restart of systemd service failed")
+ }
+ s.Log.WarnContext(ctx, "Service ungracefully restarted. Connections potentially dropped.", unitKey, s.ServiceName)
+ default:
+ s.Log.InfoContext(ctx, "Gracefully reloaded.", unitKey, s.ServiceName)
+ }
+ // monitor logs all relevant errors, so we filter for a few outcomes
+ err = s.monitor(ctx, initPID)
+ if errors.Is(err, context.DeadlineExceeded) ||
+ errors.Is(err, context.Canceled) {
+ return trace.Wrap(err)
+ }
+ if err != nil {
+ return trace.Errorf("failed to monitor process")
+ }
+ return nil
+}
+
+// monitor for a started, healthy process.
+// monitor logs all errors that should be displayed to the user.
+func (s SystemdService) monitor(ctx context.Context, initPID int) error {
+ ctx, cancel := context.WithTimeout(ctx, monitorTimeout)
+ defer cancel()
+ ticker := time.NewTicker(monitorInterval)
+ defer ticker.Stop()
+
+ newPID := 0
+ if initPID != 0 {
+ s.Log.InfoContext(ctx, "Monitoring PID file to detect crashes.", unitKey, s.ServiceName)
+ var err error
+ newPID, err = s.monitorPID(ctx, initPID, ticker.C)
+ if errors.Is(err, context.DeadlineExceeded) {
+ s.Log.ErrorContext(ctx, "Timed out monitoring for crashing PID.", unitKey, s.ServiceName)
+ return trace.Wrap(err)
+ }
+ if err != nil {
+ s.Log.ErrorContext(ctx, "Error monitoring for crashing PID.", errorKey, err, unitKey, s.ServiceName)
+ return trace.Wrap(err)
+ }
+ }
+
+ s.Log.InfoContext(ctx, "Monitoring diagnostic socket to detect readiness.", unitKey, s.ServiceName)
+ ticker = time.NewTicker(monitorInterval)
+ defer ticker.Stop()
+ err := s.waitForReady(ctx, newPID, ticker.C)
+ if errors.Is(err, context.DeadlineExceeded) {
+ s.Log.ErrorContext(ctx, "Timed out monitoring for process readiness.", unitKey, s.ServiceName)
+ return trace.Wrap(err)
+ }
+ if err != nil {
+ s.Log.ErrorContext(ctx, "Error monitoring for process readiness.", errorKey, err, unitKey, s.ServiceName)
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// monitorPID for the started process to ensure it's running by polling PIDFile.
+// This function detects several types of crashes while minimizing its own runtime during updates.
+// For example, the process may crash by failing to fork (non-running PID), or looping (repeatedly changing PID),
+// or getting stuck on quit (no change in PID).
+// initPID is the PID before the restart operation has been issued.
+// The final PID is returned.
+func (s SystemdService) monitorPID(ctx context.Context, initPID int, tickC <-chan time.Time) (int, error) {
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ pidC := make(chan int)
+ var g errgroup.Group
+ g.Go(func() error {
+ return tickFile(ctx, s.PIDFile, pidC, tickC)
+ })
+ stablePID, err := s.waitForStablePID(ctx, minRunningIntervalsBeforeStable, maxCrashesBeforeFailure,
+ initPID, pidC, func(pid int) error {
+ p, err := os.FindProcess(pid)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(p.Signal(syscall.Signal(0)))
+ })
+ cancel()
+ if err := g.Wait(); err != nil {
+ s.Log.ErrorContext(ctx, "Error monitoring for crashing process.", errorKey, err, unitKey, s.ServiceName)
+ }
+ return stablePID, trace.Wrap(err)
+}
+
+// waitForStablePID monitors a service's PID via pidC and determines whether the service is crashing.
+// verifyPID must be passed so that waitForStablePID can determine whether the process is running.
+// verifyPID must return os.ErrProcessDone in the case that the PID cannot be found, or nil otherwise.
+// baselinePID is the initial PID before any operation that might cause the process to start crashing.
+// minStable is the number of times pidC must return the same running PID before waitForStablePID returns nil.
+// minCrashes is the number of times pidC conveys a process crash or bad state before waitForStablePID returns an error.
+// The last reported PID is returned.
+func (s SystemdService) waitForStablePID(ctx context.Context, minStable, maxCrashes, baselinePID int, pidC <-chan int, verifyPID func(pid int) error) (int, error) {
+ pid := baselinePID
+ var last, stale int
+ var crashes int
+ for stable := 0; stable < minStable; stable++ {
+ select {
+ case <-ctx.Done():
+ return pid, ctx.Err()
+ case p := <-pidC:
+ last = pid
+ pid = p
+ }
+ // A "crash" is defined as a transition away from a new (non-baseline) PID, or
+ // an interval where the current PID remains non-running (stale) since the last check.
+ if (last != 0 && pid != last && last != baselinePID) ||
+ (stale != 0 && pid == stale && last == stale) {
+ crashes++
+ }
+ if crashes > maxCrashes {
+ return pid, trace.Errorf("detected crashing process")
+ }
+
+ // PID can only be stable if it is a real PID that is not new,
+ // has changed at least once, and hasn't been observed as missing.
+ if pid == 0 ||
+ pid == baselinePID ||
+ pid == stale ||
+ pid != last {
+ stable = -1
+ continue
+ }
+ err := verifyPID(pid)
+ // A stale PID most likely indicates that the process forked and crashed without systemd noticing.
+ // There is a small chance that we read the PID file before systemd removed it.
+ // Note: we only perform this check on PIDs that survive one iteration.
+ if errors.Is(err, os.ErrProcessDone) ||
+ errors.Is(err, syscall.ESRCH) {
+ if pid != stale &&
+ pid != baselinePID {
+ stale = pid
+ s.Log.WarnContext(ctx, "Detected stale PID.", unitKey, s.ServiceName, "pid", stale)
+ }
+ stable = -1
+ continue
+ }
+ if err != nil {
+ return pid, trace.Wrap(err)
+ }
+ }
+ return pid, nil
+}
+
+// readInt reads an integer from a file.
+func readInt(path string) (int, error) {
+ p, err := readFileAtMost(path, 32)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ i, err := strconv.ParseInt(string(bytes.TrimSpace(p)), 10, 64)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ return int(i), nil
+}
+
+// tickFile reads the current time on tickC, and outputs the last read int from path on ch for each received tick.
+// If the path cannot be read, tickFile sends 0 on ch.
+// Any error from the last attempt to read path is returned when ctx is canceled, unless the error is os.ErrNotExist.
+func tickFile(ctx context.Context, path string, ch chan<- int, tickC <-chan time.Time) error {
+ var err error
+ for {
+ // two select statements -> never skip reads
+ select {
+ case <-tickC:
+ case <-ctx.Done():
+ return err
+ }
+ var t int
+ t, err = readInt(path)
+ if errors.Is(err, os.ErrNotExist) {
+ err = nil
+ }
+ select {
+ case ch <- t:
+ case <-ctx.Done():
+ return err
+ }
+ }
+}
+
+// waitForReady polls the SocketPath unix domain socket with HTTP requests.
+// If one request returns 200 before the timeout, the service is considered ready.
+func (s SystemdService) waitForReady(ctx context.Context, pid int, tickC <-chan time.Time) error {
+ var lastErr error
+ var readiness debug.Readiness
+ for {
+ resp, err := s.Ready.GetReadiness(ctx)
+ if err == nil &&
+ resp.Ready &&
+ equalOrZero(resp.PID, pid) {
+ return nil
+ }
+ // If the Readiness check fails to due to intervention, we must not interpret
+ // the error as a disabled socket, which results in a passing check.
+ if !errors.Is(err, context.Canceled) &&
+ !errors.Is(err, context.DeadlineExceeded) {
+ lastErr = err
+ readiness = resp
+ }
+ select {
+ case <-ctx.Done():
+ if errors.Is(lastErr, os.ErrNotExist) ||
+ errors.Is(lastErr, syscall.EINVAL) ||
+ errors.Is(lastErr, os.ErrInvalid) ||
+ errors.As(lastErr, new(net.Error)) {
+ s.Log.WarnContext(ctx, "Socket appears to be disabled. Proceeding without check.", unitKey, s.ServiceName)
+ s.Log.DebugContext(ctx, "Found error after timeout polling socket.", unitKey, s.ServiceName, errorKey, lastErr)
+ return nil
+ }
+ if lastErr != nil {
+ s.Log.WarnContext(ctx, "Unexpected error after timeout polling socket. Proceeding without check.", unitKey, s.ServiceName, errorKey, lastErr)
+ return nil
+ }
+ if readiness.Status != "" {
+ s.Log.ErrorContext(ctx, "Process not ready by deadline.", unitKey, s.ServiceName, "status", readiness.Status)
+ }
+ if !equalOrZero(readiness.PID, pid) {
+ s.Log.ErrorContext(ctx, "Readiness PID response does not match PID file.", unitKey, s.ServiceName, "file_pid", pid, "ready_pid", readiness.PID)
+ } else {
+ s.Log.DebugContext(ctx, "PIDs are not mismatched.", unitKey, s.ServiceName, "file_pid", pid, "ready_pid", readiness.PID)
+ }
+ return ctx.Err()
+ case <-tickC:
+ }
+ }
+}
+
+// equalOrZero returns true if a and b are equal, or if either has the zero-value.
+func equalOrZero[T comparable](a, b T) bool {
+ var empty T
+ if a == empty || b == empty {
+ return true
+ }
+ return a == b
+}
+
+// Sync systemd service configuration by running systemctl daemon-reload.
+// See Process interface for more details.
+func (s SystemdService) Sync(ctx context.Context) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ code := s.systemctl(ctx, slog.LevelError, "daemon-reload")
+ if code != 0 {
+ return trace.Errorf("unable to reload systemd configuration")
+ }
+ s.Log.InfoContext(ctx, "Systemd configuration synced.", unitKey, s.ServiceName)
+ return nil
+}
+
+// Enable the systemd service.
+func (s SystemdService) Enable(ctx context.Context, now bool) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ args := []string{"enable", s.ServiceName}
+ if now {
+ args = append(args, "--now")
+ }
+ code := s.systemctl(ctx, slog.LevelInfo, args...)
+ if code != 0 {
+ return trace.Errorf("unable to enable systemd service")
+ }
+ s.Log.InfoContext(ctx, "Service enabled.", unitKey, s.ServiceName)
+ return nil
+}
+
+// Disable the systemd service.
+func (s SystemdService) Disable(ctx context.Context, now bool) error {
+ if err := s.checkSystem(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ args := []string{"disable", s.ServiceName}
+ if now {
+ args = append(args, "--now")
+ }
+ code := s.systemctl(ctx, slog.LevelInfo, args...)
+ if code != 0 {
+ return trace.Errorf("unable to disable systemd service")
+ }
+ s.Log.InfoContext(ctx, "Systemd service disabled.", unitKey, s.ServiceName)
+ return nil
+}
+
+// IsEnabled returns true if the service is enabled.
+func (s SystemdService) IsEnabled(ctx context.Context) (bool, error) {
+ if err := s.checkSystem(ctx); err != nil {
+ return false, trace.Wrap(err)
+ }
+ code := s.systemctl(ctx, slog.LevelDebug, "is-enabled", "--quiet", s.ServiceName)
+ switch {
+ case code < 0:
+ return false, trace.Errorf("unable to determine if systemd service %s is enabled", s.ServiceName)
+ case code == 0:
+ return true, nil
+ }
+ return false, nil
+}
+
+// IsActive returns true if the service is active.
+func (s SystemdService) IsActive(ctx context.Context) (bool, error) {
+ if err := s.checkSystem(ctx); err != nil {
+ return false, trace.Wrap(err)
+ }
+ code := s.systemctl(ctx, slog.LevelDebug, "is-active", "--quiet", s.ServiceName)
+ switch {
+ case code < 0:
+ return false, trace.Errorf("unable to determine if systemd service %s is active", s.ServiceName)
+ case code == 0:
+ return true, nil
+ }
+ return false, nil
+}
+
+// IsPresent returns true if the service exists.
+func (s SystemdService) IsPresent(ctx context.Context) (bool, error) {
+ if err := s.checkSystem(ctx); err != nil {
+ return false, trace.Wrap(err)
+ }
+ code := s.systemctl(ctx, slog.LevelDebug, "list-unit-files", "--quiet", s.ServiceName)
+ if code < 0 {
+ return false, trace.Errorf("unable to determine if systemd service %s is present", s.ServiceName)
+ }
+ return code == 0, nil
+}
+
+// checkSystem returns an error if the system is not compatible with this process manager.
+func (s SystemdService) checkSystem(ctx context.Context) error {
+ present, err := hasSystemD()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ if !present {
+ return trace.Wrap(ErrNotSupported)
+ }
+ return nil
+}
+
+// hasSystemD returns true if the system uses the SystemD process manager.
+func hasSystemD() (bool, error) {
+ _, err := os.Stat("/run/systemd/system")
+ if errors.Is(err, os.ErrNotExist) {
+ return false, nil
+ }
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ return true, nil
+}
+
+// systemctl returns a systemctl subcommand, converting the output to logs.
+// Output sent to stdout is logged at debug level.
+// Output sent to stderr is logged at the level specified by errLevel.
+func (s SystemdService) systemctl(ctx context.Context, errLevel slog.Level, args ...string) int {
+ cmd := &localExec{
+ Log: s.Log,
+ ErrLevel: errLevel,
+ OutLevel: slog.LevelDebug,
+ }
+ code, err := cmd.Run(ctx, "systemctl", args...)
+ if err == nil {
+ return code
+ }
+ if code >= 0 {
+ s.Log.Log(ctx, errLevel, "Non-zero exit code or error running systemctl.",
+ "args", args, "code", code)
+ return code
+ }
+ s.Log.Log(ctx, errLevel, "Unable to run systemctl.",
+ "args", args, "code", code, errorKey, err)
+ return code
+}
+
+// localExec runs a command locally, logging any output.
+type localExec struct {
+ // Log contains a slog logger.
+ // Defaults to slog.Default() if nil.
+ Log *slog.Logger
+ // ErrLevel is the log level for stderr.
+ ErrLevel slog.Level
+ // OutLevel is the log level for stdout.
+ OutLevel slog.Level
+}
+
+// Run the command. Same arguments as exec.CommandContext.
+// Outputs the status code, or -1 if out-of-range or unstarted.
+func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, error) {
+ cmd := exec.CommandContext(ctx, name, args...)
+ stderr := &lineLogger{ctx: ctx, log: c.Log, level: c.ErrLevel, prefix: "[stderr] "}
+ stdout := &lineLogger{ctx: ctx, log: c.Log, level: c.OutLevel, prefix: "[stdout] "}
+ cmd.Stderr = stderr
+ cmd.Stdout = stdout
+ err := cmd.Run()
+ stderr.Flush()
+ stdout.Flush()
+ code := cmd.ProcessState.ExitCode()
+ return code, trace.Wrap(err)
+}
diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go
new file mode 100644
index 0000000000000..c6db70944be5d
--- /dev/null
+++ b/lib/autoupdate/agent/process_test.go
@@ -0,0 +1,313 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestWaitForStablePID(t *testing.T) {
+ t.Parallel()
+
+ svc := &SystemdService{
+ Log: slog.Default(),
+ }
+
+ for _, tt := range []struct {
+ name string
+ ticks []int
+ baseline int
+ minStable int
+ maxCrashes int
+ findErrs map[int]error
+
+ finalPID int
+ errored bool
+ canceled bool
+ }{
+ {
+ name: "immediate restart",
+ ticks: []int{2, 2},
+ baseline: 1,
+ minStable: 1,
+ maxCrashes: 1,
+ finalPID: 2,
+ },
+ {
+ name: "zero stable",
+ },
+ {
+ name: "immediate crash",
+ ticks: []int{2, 3},
+ baseline: 1,
+ minStable: 1,
+ maxCrashes: 0,
+ errored: true,
+ finalPID: 3,
+ },
+ {
+ name: "no changes times out",
+ ticks: []int{1, 1, 1, 1},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ canceled: true,
+ finalPID: 1,
+ },
+ {
+ name: "baseline restart",
+ ticks: []int{2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ finalPID: 2,
+ },
+ {
+ name: "one restart then stable",
+ ticks: []int{1, 2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ finalPID: 2,
+ },
+ {
+ name: "two restarts then stable",
+ ticks: []int{1, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ finalPID: 3,
+ },
+ {
+ name: "three restarts then stable",
+ ticks: []int{1, 2, 3, 4, 4, 4, 4},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ finalPID: 4,
+ },
+ {
+ name: "too many restarts excluding baseline",
+ ticks: []int{1, 2, 3, 4, 5},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ finalPID: 5,
+ },
+ {
+ name: "too many restarts including baseline",
+ ticks: []int{1, 2, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ finalPID: 4,
+ },
+ {
+ name: "too many restarts slow",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ finalPID: 4,
+ },
+ {
+ name: "too many restarts after stable",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ finalPID: 3,
+ },
+ {
+ name: "stable after too many restarts",
+ ticks: []int{1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ errored: true,
+ finalPID: 4,
+ },
+ {
+ name: "cancel",
+ ticks: []int{1, 1, 1},
+ baseline: 0,
+ minStable: 3,
+ maxCrashes: 2,
+ canceled: true,
+ finalPID: 1,
+ },
+ {
+ name: "stale PID crash",
+ ticks: []int{2, 2, 2, 2, 2},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: os.ErrProcessDone,
+ },
+ errored: true,
+ finalPID: 2,
+ },
+ {
+ name: "stale PID but fixed",
+ ticks: []int{2, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: os.ErrProcessDone,
+ },
+ finalPID: 3,
+ },
+ {
+ name: "error PID",
+ ticks: []int{2, 2, 3, 3, 3, 3},
+ baseline: 1,
+ minStable: 3,
+ maxCrashes: 2,
+ findErrs: map[int]error{
+ 2: errors.New("bad"),
+ },
+ errored: true,
+ finalPID: 2,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ ch := make(chan int)
+ go func() {
+ defer cancel() // always quit after last tick
+ for _, tick := range tt.ticks {
+ ch <- tick
+ }
+ }()
+ pid, err := svc.waitForStablePID(ctx, tt.minStable, tt.maxCrashes,
+ tt.baseline, ch, func(pid int) error {
+ return tt.findErrs[pid]
+ })
+ require.Equal(t, tt.finalPID, pid)
+ require.Equal(t, tt.canceled, errors.Is(err, context.Canceled))
+ if !tt.canceled {
+ require.Equal(t, tt.errored, err != nil)
+ }
+ })
+ }
+}
+
+func TestTickFile(t *testing.T) {
+ t.Parallel()
+
+ for _, tt := range []struct {
+ name string
+ ticks []int
+ errored bool
+ }{
+ {
+ name: "consistent",
+ ticks: []int{1, 1, 1},
+ errored: false,
+ },
+ {
+ name: "divergent",
+ ticks: []int{1, 2, 3},
+ errored: false,
+ },
+ {
+ name: "start error",
+ ticks: []int{-1, 1, 1},
+ errored: false,
+ },
+ {
+ name: "ephemeral error",
+ ticks: []int{1, -1, 1},
+ errored: false,
+ },
+ {
+ name: "end error",
+ ticks: []int{1, 1, -1},
+ errored: true,
+ },
+ {
+ name: "start missing",
+ ticks: []int{0, 1, 1},
+ errored: false,
+ },
+ {
+ name: "ephemeral missing",
+ ticks: []int{1, 0, 1},
+ errored: false,
+ },
+ {
+ name: "end missing",
+ ticks: []int{1, 1, 0},
+ errored: false,
+ },
+ {
+ name: "cancel-only",
+ errored: false,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ filePath := filepath.Join(t.TempDir(), "file")
+
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ tickC := make(chan time.Time)
+ ch := make(chan int)
+
+ go func() {
+ defer cancel() // always quit after last tick or fail
+ for _, tick := range tt.ticks {
+ _ = os.RemoveAll(filePath)
+ switch {
+ case tick > 0:
+ err := os.WriteFile(filePath, []byte(fmt.Sprintln(tick)), os.ModePerm)
+ require.NoError(t, err)
+ case tick < 0:
+ err := os.Mkdir(filePath, os.ModePerm)
+ require.NoError(t, err)
+ }
+ tickC <- time.Now()
+ res := <-ch
+ if tick < 0 {
+ tick = 0
+ }
+ require.Equal(t, tick, res)
+ }
+ }()
+ err := tickFile(ctx, filePath, ch, tickC)
+ require.Equal(t, tt.errored, err != nil)
+ })
+ }
+}
diff --git a/lib/autoupdate/agent/setup.go b/lib/autoupdate/agent/setup.go
new file mode 100644
index 0000000000000..8c4806251fcd5
--- /dev/null
+++ b/lib/autoupdate/agent/setup.go
@@ -0,0 +1,480 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io/fs"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+ "text/template"
+
+ "github.com/google/renameio/v2"
+ "github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
+
+ "github.com/gravitational/teleport/lib/defaults"
+ libdefaults "github.com/gravitational/teleport/lib/defaults"
+ libutils "github.com/gravitational/teleport/lib/utils"
+)
+
+// Base paths for constructing namespaced directories.
+const (
+ defaultInstallDir = "/opt/teleport"
+ defaultPathDir = "/usr/local/bin"
+ systemdAdminDir = "/etc/systemd/system"
+ systemdPIDDir = "/run"
+ needrestartConfDir = "/etc/needrestart/conf.d"
+ versionsDirName = "versions"
+ lockFileName = "update.lock"
+ defaultNamespace = "default"
+ systemNamespace = "system"
+)
+
+const (
+ // deprecatedTimerName is the timer for the deprecated upgrader should be disabled on setup.
+ deprecatedTimerName = "teleport-upgrade.timer"
+)
+
+const (
+ updateServiceTemplate = `# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update service
+
+[Service]
+Type=oneshot
+ExecStart={{.UpdaterBinary}} --install-suffix={{.InstallSuffix}} --install-dir="{{escape .InstallDir}}" update
+`
+ updateTimerTemplate = `# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update timer unit
+
+[Timer]
+OnActiveSec=1m
+OnUnitActiveSec=5m
+RandomizedDelaySec=1m
+
+[Install]
+WantedBy={{.TeleportService}}
+`
+ teleportDropInTemplate = `# teleport-update
+# DO NOT EDIT THIS FILE
+[Service]
+Environment="TELEPORT_UPDATE_CONFIG_FILE={{escape .UpdaterConfigFile}}"
+Environment="TELEPORT_UPDATE_INSTALL_DIR={{escape .InstallDir}}"
+`
+ // This configuration sets the default value for needrestart-trigger automatic restarts for teleport.service to disabled.
+ // Users may still choose to enable needrestart for teleport.service when installing packaging interactively (or via dpkg config),
+ // but doing so will result in a hard restart that disconnects the agent whenever any dependent libraries are updated.
+ // Other network services, like openvpn, follow this pattern.
+ // It is possible to configure needrestart to trigger a soft restart (via restart.d script), but given that Teleport subprocesses
+ // can use a wide variety of installed binaries (when executed by the user), this could trigger many unexpected reloads.
+ needrestartConfTemplate = `$nrconf{override_rc}{qr(^{{replace .TeleportService "." "\\."}})} = 0;
+`
+)
+
+type confParams struct {
+ TeleportService string
+ UpdaterBinary string
+ InstallSuffix string
+ InstallDir string
+ Path string
+ UpdaterConfigFile string
+}
+
+// Namespace represents a namespace within various system paths for a isolated installation of Teleport.
+type Namespace struct {
+ log *slog.Logger
+ // name of namespace
+ name string
+ // installDir for Teleport namespaces (/opt/teleport)
+ installDir string
+ // defaultPathDir for Teleport binaries (ns: /opt/teleport/myns/bin)
+ defaultPathDir string
+ // dataDir parsed from teleport.yaml, if present
+ dataDir string
+ // defaultProxyAddr parsed from teleport.yaml, if present
+ defaultProxyAddr string
+ // serviceFile for the Teleport systemd service (ns: /etc/systemd/system/teleport_myns.service)
+ serviceFile string
+ // configFile for Teleport config (ns: /etc/teleport_myns.yaml)
+ configFile string
+ // pidFile for Teleport (ns: /run/teleport_myns.pid)
+ pidFile string
+ // updaterServiceFile is the systemd service path for the updater
+ updaterServiceFile string
+ // updaterTimerFile is the systemd timer path for the updater
+ updaterTimerFile string
+ // dropInFile is the Teleport systemd drop-in path extending Teleport
+ dropInFile string
+ // needrestartConfFile is the path to needrestart configuration for Teleport
+ needrestartConfFile string
+}
+
+var alphanum = regexp.MustCompile("^[a-zA-Z0-9-]*$")
+
+// NewNamespace validates and returns a Namespace.
+// Namespaces must be alphanumeric + `-`.
+// defaultPathDir overrides the destination directory for namespace setup (i.e., /usr/local)
+func NewNamespace(ctx context.Context, log *slog.Logger, name, installDir string) (ns *Namespace, err error) {
+ defer ns.overrideFromConfig(ctx)
+
+ if name == defaultNamespace ||
+ name == systemNamespace {
+ return nil, trace.Errorf("namespace %s is reserved", name)
+ }
+ if !alphanum.MatchString(name) {
+ return nil, trace.Errorf("invalid namespace name %s, must be alphanumeric", name)
+ }
+ if installDir == "" {
+ installDir = defaultInstallDir
+ }
+ if name == "" {
+ linkDir := defaultPathDir
+ return &Namespace{
+ log: log,
+ name: name,
+ installDir: installDir,
+ defaultPathDir: linkDir,
+ dataDir: defaults.DataDir,
+ serviceFile: filepath.Join("/", serviceDir, serviceName),
+ configFile: defaults.ConfigFilePath,
+ pidFile: filepath.Join(systemdPIDDir, "teleport.pid"),
+ updaterServiceFile: filepath.Join(systemdAdminDir, BinaryName+".service"),
+ updaterTimerFile: filepath.Join(systemdAdminDir, BinaryName+".timer"),
+ dropInFile: filepath.Join(systemdAdminDir, "teleport.service.d", BinaryName+".conf"),
+ needrestartConfFile: filepath.Join(needrestartConfDir, BinaryName+".conf"),
+ }, nil
+ }
+
+ prefix := "teleport_" + name
+ linkDir := filepath.Join(installDir, name, "bin")
+ return &Namespace{
+ log: log,
+ name: name,
+ installDir: installDir,
+ defaultPathDir: linkDir,
+ dataDir: filepath.Join(filepath.Dir(defaults.DataDir), prefix),
+ serviceFile: filepath.Join(systemdAdminDir, prefix+".service"),
+ configFile: filepath.Join(filepath.Dir(defaults.ConfigFilePath), prefix+".yaml"),
+ pidFile: filepath.Join(systemdPIDDir, prefix+".pid"),
+ updaterServiceFile: filepath.Join(systemdAdminDir, BinaryName+"_"+name+".service"),
+ updaterTimerFile: filepath.Join(systemdAdminDir, BinaryName+"_"+name+".timer"),
+ dropInFile: filepath.Join(systemdAdminDir, prefix+".service.d", BinaryName+"_"+name+".conf"),
+ needrestartConfFile: filepath.Join(needrestartConfDir, BinaryName+"_"+name+".conf"),
+ }, nil
+}
+
+func (ns *Namespace) Dir() string {
+ name := ns.name
+ if name == "" {
+ name = defaultNamespace
+ }
+ return filepath.Join(ns.installDir, name)
+}
+
+// Init create the initial directory structure and returns the lockfile for a Namespace.
+func (ns *Namespace) Init() (lockFile string, err error) {
+ if err := os.MkdirAll(filepath.Join(ns.Dir(), versionsDirName), systemDirMode); err != nil {
+ return "", trace.Wrap(err)
+ }
+ return filepath.Join(ns.Dir(), lockFileName), nil
+}
+
+// Setup installs service and timer files for the teleport-update binary.
+// Afterwords, Setup reloads systemd and enables the timer with --now.
+func (ns *Namespace) Setup(ctx context.Context, path string) error {
+ if ok, err := hasSystemD(); err == nil && !ok {
+ ns.log.WarnContext(ctx, "Systemd is not running, skipping updater installation.")
+ return nil
+ }
+
+ err := ns.writeConfigFiles(ctx, path)
+ if err != nil {
+ return trace.Wrap(err, "failed to write teleport-update systemd config files")
+ }
+ timer := &SystemdService{
+ ServiceName: filepath.Base(ns.updaterTimerFile),
+ Log: ns.log,
+ }
+ if err := timer.Sync(ctx); err != nil {
+ return trace.Wrap(err, "failed to sync systemd config")
+ }
+ if err := timer.Enable(ctx, true); err != nil {
+ return trace.Wrap(err, "failed to enable teleport-update systemd timer")
+ }
+ if ns.name == "" {
+ oldTimer := &SystemdService{
+ ServiceName: deprecatedTimerName,
+ Log: ns.log,
+ }
+ // If the old teleport-upgrade script is detected, disable it to ensure they do not interfere.
+ // Note that the schedule is also set to nop by the Teleport agent -- this just prevents restarts.
+ enabled, err := isActiveOrEnabled(ctx, oldTimer)
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if deprecated teleport-upgrade systemd timer is enabled")
+ }
+ if enabled {
+ if err := oldTimer.Disable(ctx, true); err != nil {
+ ns.log.ErrorContext(ctx, "The deprecated teleport-ent-updater package is installed on this server, and it cannot be disabled due to an error. You must remove the teleport-ent-updater package after verifying that teleport-update is working.", errorKey, err)
+ } else {
+ ns.log.WarnContext(ctx, "The deprecated teleport-ent-updater package is installed on this server. This package has been disabled to prevent conflicts. Please remove the teleport-ent-updater package after verifying that teleport-update is working.")
+ }
+ }
+ }
+ return nil
+}
+
+// Teardown removes all traces of the auto-updater, including its configuration.
+func (ns *Namespace) Teardown(ctx context.Context) error {
+ if ok, err := hasSystemD(); err == nil && !ok {
+ ns.log.WarnContext(ctx, "Systemd is not running, skipping updater removal.")
+ if err := os.RemoveAll(ns.Dir()); err != nil {
+ return trace.Wrap(err, "failed to remove versions directory")
+ }
+ return nil
+ }
+
+ svc := &SystemdService{
+ ServiceName: filepath.Base(ns.updaterTimerFile),
+ Log: ns.log,
+ }
+ if err := svc.Disable(ctx, true); err != nil {
+ ns.log.WarnContext(ctx, "Unable to disable teleport-update systemd timer before removing.", errorKey, err)
+ }
+ for _, p := range []string{
+ ns.updaterServiceFile,
+ ns.updaterTimerFile,
+ ns.dropInFile,
+ ns.needrestartConfFile,
+ } {
+ if err := os.Remove(p); err != nil && !errors.Is(err, fs.ErrNotExist) {
+ return trace.Wrap(err, "failed to remove %s", filepath.Base(p))
+ }
+ }
+ if err := svc.Sync(ctx); err != nil {
+ return trace.Wrap(err, "failed to sync systemd config")
+ }
+ if err := os.RemoveAll(ns.Dir()); err != nil {
+ return trace.Wrap(err, "failed to remove versions directory")
+ }
+ if ns.name == "" {
+ oldTimer := &SystemdService{
+ ServiceName: deprecatedTimerName,
+ Log: ns.log,
+ }
+ // If the old upgrader exists, attempt to re-enable it automatically
+ present, err := oldTimer.IsPresent(ctx)
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if deprecated teleport-upgrade systemd timer is present")
+ }
+ if present {
+ if err := oldTimer.Enable(ctx, true); err != nil {
+ ns.log.ErrorContext(ctx, "The deprecated teleport-ent-updater package is installed on this server, and it cannot be re-enabled due to an error. Please fix the teleport-ent-updater package if you intend to use the deprecated updater.", errorKey, err)
+ } else {
+ ns.log.WarnContext(ctx, "The deprecated teleport-ent-updater package is installed on this server. This package has been re-enabled to ensure continued updates. To disable automatic updates entirely, please remove the teleport-ent-updater package.")
+ }
+ }
+ }
+ return nil
+}
+
+func (ns *Namespace) writeConfigFiles(ctx context.Context, path string) error {
+ teleportService := filepath.Base(ns.serviceFile)
+ params := confParams{
+ TeleportService: teleportService,
+ UpdaterBinary: filepath.Join(path, BinaryName),
+ InstallSuffix: ns.name,
+ InstallDir: ns.installDir,
+ Path: path,
+ UpdaterConfigFile: filepath.Join(ns.Dir(), updateConfigName),
+ }
+ err := writeSystemTemplate(ns.updaterServiceFile, updateServiceTemplate, params)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ err = writeSystemTemplate(ns.updaterTimerFile, updateTimerTemplate, params)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ err = writeSystemTemplate(ns.dropInFile, teleportDropInTemplate, params)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ // Needrestart config is non-critical for updater functionality.
+ _, err = os.Stat(filepath.Dir(ns.needrestartConfFile))
+ if os.IsNotExist(err) {
+ return nil // needrestart is not present
+ }
+ if err != nil {
+ ns.log.ErrorContext(ctx, "Unable to disable needrestart.", errorKey, err)
+ return nil
+ }
+ ns.log.InfoContext(ctx, "Disabling needrestart.", unitKey, teleportService)
+ err = writeSystemTemplate(ns.needrestartConfFile, needrestartConfTemplate, params)
+ if err != nil {
+ ns.log.ErrorContext(ctx, "Unable to disable needrestart.", errorKey, err)
+ return nil
+ }
+ return nil
+}
+
+// writeSystemTemplate atomically writes a template to a system file, creating any needed directories.
+// Temporarily files are stored in the target path to ensure the file has needed SELinux contexts.
+func writeSystemTemplate(path, t string, values any) error {
+ dir, file := filepath.Split(path)
+ if err := os.MkdirAll(dir, systemDirMode); err != nil {
+ return trace.Wrap(err)
+ }
+ opts := []renameio.Option{
+ renameio.WithPermissions(configFileMode),
+ renameio.WithExistingPermissions(),
+ renameio.WithTempDir(dir),
+ }
+ f, err := renameio.NewPendingFile(path, opts...)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer f.Cleanup()
+
+ tmpl, err := template.New(file).Funcs(template.FuncMap{
+ "replace": func(s, old, new string) string {
+ return strings.ReplaceAll(s, old, new)
+ },
+ // escape is a best-effort function for escaping quotes in systemd service templates.
+ // Paths that are escaped with this method should not be advertised to the user as
+ // configurable until a more robust escaping mechanism is shipped.
+ // See: https://www.freedesktop.org/software/systemd/man/latest/systemd.syntax.html
+ "escape": func(s string) string {
+ replacer := strings.NewReplacer(
+ `"`, `\"`,
+ `\`, `\\`,
+ )
+ return replacer.Replace(s)
+ },
+ }).Parse(t)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ err = tmpl.Execute(f, values)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ return trace.Wrap(f.CloseAtomicallyReplace())
+}
+
+// replaceTeleportService replaces the default paths in the Teleport service config with namespaced paths.
+func (ns *Namespace) replaceTeleportService(cfg []byte) []byte {
+ for _, rep := range []struct {
+ old, new string
+ }{
+ {
+ old: "/usr/local/bin/",
+ new: ns.defaultPathDir + "/",
+ },
+ {
+ old: "/etc/teleport.yaml",
+ new: ns.configFile,
+ },
+ {
+ old: "/run/teleport.pid",
+ new: ns.pidFile,
+ },
+ } {
+ cfg = bytes.ReplaceAll(cfg, []byte(rep.old), []byte(rep.new))
+ }
+ return cfg
+}
+
+func (ns *Namespace) LogWarning(ctx context.Context) {
+ ns.log.WarnContext(ctx, "Custom install suffix specified. Teleport data_dir must be configured in the config file.",
+ "data_dir", ns.dataDir,
+ "path", ns.defaultPathDir,
+ "config", ns.configFile,
+ "service", filepath.Base(ns.serviceFile),
+ "pid", ns.pidFile,
+ )
+}
+
+// unversionedConfig is used to read all versions of teleport.yaml, including
+// versions that may now be unsupported.
+type unversionedConfig struct {
+ Teleport unversionedTeleport `yaml:"teleport"`
+}
+
+type unversionedTeleport struct {
+ AuthServers []string `yaml:"auth_servers"`
+ AuthServer string `yaml:"auth_server"`
+ ProxyServer string `yaml:"proxy_server"`
+ DataDir string `yaml:"data_dir"`
+}
+
+// overrideFromConfig loads fields from teleport.yaml into the namespace, overriding any defaults.
+func (ns *Namespace) overrideFromConfig(ctx context.Context) {
+ if ns == nil || ns.configFile == "" {
+ return
+ }
+ path := ns.configFile
+ f, err := libutils.OpenFileAllowingUnsafeLinks(path)
+ if err != nil {
+ ns.log.DebugContext(ctx, "Unable to open Teleport config to read proxy or data dir", "config", path, errorKey, err)
+ return
+ }
+ defer f.Close()
+ var cfg unversionedConfig
+ if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
+ ns.log.DebugContext(ctx, "Unable to parse Teleport config to read proxy or data dir", "config", path, errorKey, err)
+ return
+ }
+ if cfg.Teleport.DataDir != "" {
+ ns.dataDir = cfg.Teleport.DataDir
+ }
+
+ // Any implicitly defaulted port in teleport.yaml is explicitly defaulted (to 3080).
+
+ var addr string
+ var port int
+ switch t := cfg.Teleport; {
+ case t.ProxyServer != "":
+ addr = t.ProxyServer
+ port = libdefaults.HTTPListenPort
+ case t.AuthServer != "":
+ addr = t.AuthServer
+ port = libdefaults.AuthListenPort
+ case len(t.AuthServers) > 0:
+ addr = t.AuthServers[0]
+ port = libdefaults.AuthListenPort
+ default:
+ ns.log.DebugContext(ctx, "Unable to find proxy in Teleport config", "config", path, errorKey, err)
+ return
+ }
+ netaddr, err := libutils.ParseHostPortAddr(addr, port)
+ if err != nil {
+ ns.log.DebugContext(ctx, "Unable to parse proxy in Teleport config", "config", path, "proxy_addr", addr, "proxy_port", port, errorKey, err)
+ return
+ }
+ ns.defaultProxyAddr = netaddr.String()
+}
diff --git a/lib/autoupdate/agent/setup_test.go b/lib/autoupdate/agent/setup_test.go
new file mode 100644
index 0000000000000..bebda789cadb7
--- /dev/null
+++ b/lib/autoupdate/agent/setup_test.go
@@ -0,0 +1,365 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gopkg.in/yaml.v3"
+
+ "github.com/gravitational/teleport/lib/config"
+ "github.com/gravitational/teleport/lib/utils/testutils/golden"
+)
+
+func TestNewNamespace(t *testing.T) {
+ for _, p := range []struct {
+ name string
+ namespace string
+ installDir string
+ errMatch string
+ ns *Namespace
+ }{
+ {
+ name: "no namespace",
+ ns: &Namespace{
+ dataDir: "/var/lib/teleport",
+ installDir: "/opt/teleport",
+ defaultPathDir: "/usr/local/bin",
+ serviceFile: "/lib/systemd/system/teleport.service",
+ configFile: "/etc/teleport.yaml",
+ pidFile: "/run/teleport.pid",
+ updaterServiceFile: "/etc/systemd/system/teleport-update.service",
+ updaterTimerFile: "/etc/systemd/system/teleport-update.timer",
+ dropInFile: "/etc/systemd/system/teleport.service.d/teleport-update.conf",
+ needrestartConfFile: "/etc/needrestart/conf.d/teleport-update.conf",
+ },
+ },
+ {
+ name: "no namespace with dirs",
+ installDir: "/install",
+ ns: &Namespace{
+ dataDir: "/var/lib/teleport",
+ installDir: "/install",
+ defaultPathDir: "/usr/local/bin",
+ serviceFile: "/lib/systemd/system/teleport.service",
+ configFile: "/etc/teleport.yaml",
+ pidFile: "/run/teleport.pid",
+ updaterServiceFile: "/etc/systemd/system/teleport-update.service",
+ updaterTimerFile: "/etc/systemd/system/teleport-update.timer",
+ dropInFile: "/etc/systemd/system/teleport.service.d/teleport-update.conf",
+ needrestartConfFile: "/etc/needrestart/conf.d/teleport-update.conf",
+ },
+ },
+ {
+ name: "test namespace",
+ namespace: "test",
+ ns: &Namespace{
+ name: "test",
+ dataDir: "/var/lib/teleport_test",
+ installDir: "/opt/teleport",
+ defaultPathDir: "/opt/teleport/test/bin",
+ serviceFile: "/etc/systemd/system/teleport_test.service",
+ configFile: "/etc/teleport_test.yaml",
+ pidFile: "/run/teleport_test.pid",
+ updaterServiceFile: "/etc/systemd/system/teleport-update_test.service",
+ updaterTimerFile: "/etc/systemd/system/teleport-update_test.timer",
+ dropInFile: "/etc/systemd/system/teleport_test.service.d/teleport-update_test.conf",
+ needrestartConfFile: "/etc/needrestart/conf.d/teleport-update_test.conf",
+ },
+ },
+ {
+ name: "test namespace with dirs",
+ namespace: "test",
+ installDir: "/install",
+ ns: &Namespace{
+ name: "test",
+ dataDir: "/var/lib/teleport_test",
+ installDir: "/install",
+ defaultPathDir: "/install/test/bin",
+ configFile: "/etc/teleport_test.yaml",
+ pidFile: "/run/teleport_test.pid",
+ serviceFile: "/etc/systemd/system/teleport_test.service",
+ updaterServiceFile: "/etc/systemd/system/teleport-update_test.service",
+ updaterTimerFile: "/etc/systemd/system/teleport-update_test.timer",
+ dropInFile: "/etc/systemd/system/teleport_test.service.d/teleport-update_test.conf",
+ needrestartConfFile: "/etc/needrestart/conf.d/teleport-update_test.conf",
+ },
+ },
+ {
+ name: "reserved default",
+ namespace: defaultNamespace,
+ errMatch: "reserved",
+ },
+ {
+ name: "reserved system",
+ namespace: systemNamespace,
+ errMatch: "reserved",
+ },
+ } {
+ t.Run(p.name, func(t *testing.T) {
+ log := slog.Default()
+ ctx := context.Background()
+ ns, err := NewNamespace(ctx, log, p.namespace, p.installDir)
+ if p.errMatch != "" {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), p.errMatch)
+ return
+ }
+ require.NoError(t, err)
+ ns.log = nil
+ require.Equal(t, p.ns, ns)
+ })
+ }
+}
+
+func TestWriteConfigFiles(t *testing.T) {
+ for _, p := range []struct {
+ name string
+ namespace string
+ }{
+ {
+ name: "no namespace",
+ },
+ {
+ name: "test namespace",
+ namespace: "test",
+ },
+ } {
+ t.Run(p.name, func(t *testing.T) {
+ log := slog.Default()
+ linkDir := t.TempDir()
+ ctx := context.Background()
+ ns, err := NewNamespace(ctx, log, p.namespace, "")
+ require.NoError(t, err)
+ ns.updaterServiceFile = filepath.Join(linkDir, serviceDir, filepath.Base(ns.updaterServiceFile))
+ ns.updaterTimerFile = filepath.Join(linkDir, serviceDir, filepath.Base(ns.updaterTimerFile))
+ ns.dropInFile = filepath.Join(linkDir, serviceDir, filepath.Base(filepath.Dir(ns.dropInFile)), filepath.Base(ns.dropInFile))
+ ns.needrestartConfFile = filepath.Join(linkDir, filepath.Base(ns.dropInFile))
+ err = ns.writeConfigFiles(ctx, linkDir)
+ require.NoError(t, err)
+
+ for _, tt := range []struct {
+ name string
+ path string
+ }{
+ {name: "service", path: ns.updaterServiceFile},
+ {name: "timer", path: ns.updaterTimerFile},
+ {name: "dropin", path: ns.dropInFile},
+ {name: "needrestart", path: ns.needrestartConfFile},
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ data, err := os.ReadFile(tt.path)
+ require.NoError(t, err)
+ data = replaceValues(data, map[string]string{
+ defaultPathDir: linkDir,
+ })
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+ })
+ }
+}
+
+func replaceValues(data []byte, m map[string]string) []byte {
+ for k, v := range m {
+ data = bytes.ReplaceAll(data, []byte(v), []byte(k))
+ }
+ return data
+}
+
+func TestNamespace_overrideFromConfig(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *unversionedTeleport
+ want Namespace
+ }{
+ {
+ name: "default",
+ cfg: &unversionedTeleport{
+ ProxyServer: "example.com",
+ DataDir: "/data",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:3080",
+ dataDir: "/data",
+ },
+ },
+ {
+ name: "empty",
+ cfg: &unversionedTeleport{},
+ want: Namespace{
+ defaultProxyAddr: "default.example.com",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "full proxy",
+ cfg: &unversionedTeleport{
+ ProxyServer: "https://example.com:8080",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:8080",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "protocol and host",
+ cfg: &unversionedTeleport{
+ ProxyServer: "https://example.com",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:3080",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "host and port",
+ cfg: &unversionedTeleport{
+ ProxyServer: "example.com:443",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:443",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "host",
+ cfg: &unversionedTeleport{
+ ProxyServer: "example.com",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:3080",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "auth server (v3)",
+ cfg: &unversionedTeleport{
+ AuthServer: "example.com",
+ },
+ want: Namespace{
+ defaultProxyAddr: "example.com:3025",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "auth server (v1/2)",
+ cfg: &unversionedTeleport{
+ AuthServers: []string{
+ "one.example.com",
+ "two.example.com",
+ },
+ },
+ want: Namespace{
+ defaultProxyAddr: "one.example.com:3025",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "proxy priority",
+ cfg: &unversionedTeleport{
+ ProxyServer: "one.example.com",
+ AuthServer: "two.example.com",
+ AuthServers: []string{"three.example.com"},
+ },
+ want: Namespace{
+ defaultProxyAddr: "one.example.com:3080",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "auth priority",
+ cfg: &unversionedTeleport{
+ AuthServer: "two.example.com",
+ AuthServers: []string{"three.example.com"},
+ },
+ want: Namespace{
+ defaultProxyAddr: "two.example.com:3025",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ {
+ name: "missing",
+ want: Namespace{
+ defaultProxyAddr: "default.example.com",
+ dataDir: "/var/lib/teleport",
+ },
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ns := &Namespace{
+ log: slog.Default(),
+ configFile: filepath.Join(t.TempDir(), "teleport.yaml"),
+ defaultProxyAddr: "default.example.com",
+ dataDir: "/var/lib/teleport",
+ }
+ if tt.cfg != nil {
+ out, err := yaml.Marshal(unversionedConfig{Teleport: *tt.cfg})
+ require.NoError(t, err)
+ err = os.WriteFile(ns.configFile, out, os.ModePerm)
+ require.NoError(t, err)
+ }
+ ctx := context.Background()
+ ns.overrideFromConfig(ctx)
+ ns.configFile = ""
+ ns.log = nil
+ require.Equal(t, &tt.want, ns)
+ })
+ }
+}
+
+// In the future, the latest version of the updater may need to read a version of teleport.yaml that has
+// an unsupported version which is supported by the updater-managed version of Teleport.
+// This test will break if Teleport removes a field that the updater reads.
+func TestUnversionedTeleportConfig(t *testing.T) {
+ in := unversionedConfig{
+ Teleport: unversionedTeleport{
+ ProxyServer: "proxy.example.com",
+ AuthServer: "auth.example.com",
+ AuthServers: []string{"auth1.example.com", "auth2.example.com"},
+ DataDir: "example_dir",
+ },
+ }
+ var inB bytes.Buffer
+ err := yaml.NewEncoder(&inB).Encode(in)
+ require.NoError(t, err)
+ fc, err := config.ReadConfig(&inB)
+ require.NoError(t, err)
+
+ var outB bytes.Buffer
+ err = yaml.NewEncoder(&outB).Encode(fc)
+ require.NoError(t, err)
+
+ var out unversionedConfig
+ err = yaml.NewDecoder(&outB).Decode(&out)
+ require.NoError(t, err)
+ require.Equal(t, in, out)
+}
diff --git a/lib/autoupdate/agent/telemetry.go b/lib/autoupdate/agent/telemetry.go
new file mode 100644
index 0000000000000..bd6bb887cad40
--- /dev/null
+++ b/lib/autoupdate/agent/telemetry.go
@@ -0,0 +1,105 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/gravitational/trace"
+)
+
+const installDirEnvVar = "TELEPORT_UPDATE_INSTALL_DIR"
+
+// IsManagedByUpdater returns true if the local Teleport binary is managed by teleport-update.
+// Note that true may be returned even if auto-updates is disabled or the version is pinned.
+// The binary is considered managed if it lives under /opt/teleport, but not within the package
+// path at /opt/teleport/system.
+func IsManagedByUpdater() (bool, error) {
+ systemd, err := hasSystemD()
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if !systemd {
+ return false, nil
+ }
+ teleportPath, err := os.Readlink("/proc/self/exe")
+ if err != nil {
+ return false, trace.Wrap(err, "cannot find Teleport binary")
+ }
+ installDir := os.Getenv(installDirEnvVar)
+ if installDir == "" {
+ installDir = defaultInstallDir
+ }
+ // Check if current binary is under the updater-managed path.
+ managed, err := hasParentDir(teleportPath, installDir)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if !managed {
+ return false, nil
+ }
+ // Return false if the binary is under the updater-managed path, but in the system prefix reserved for the package.
+ system, err := hasParentDir(teleportPath, packageSystemDir)
+ return !system, trace.Wrap(err)
+}
+
+// IsManagedAndDefault returns true if the local Teleport binary is both managed by teleport-update
+// and the default installation (with teleport.service as the unit file name).
+// The binary is considered managed and default if it lives within /opt/teleport/default.
+func IsManagedAndDefault() (bool, error) {
+ systemd, err := hasSystemD()
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if !systemd {
+ return false, nil
+ }
+ teleportPath, err := os.Readlink("/proc/self/exe")
+ if err != nil {
+ return false, trace.Wrap(err, "cannot find Teleport binary")
+ }
+ installDir := os.Getenv(installDirEnvVar)
+ if installDir == "" {
+ installDir = defaultInstallDir
+ }
+ isDefault, err := hasParentDir(teleportPath, filepath.Join(installDir, defaultNamespace))
+ return isDefault, trace.Wrap(err)
+}
+
+// hasParentDir returns true if dir is any parent directory of parent.
+// hasParentDir does not resolve symlinks, and requires that files be represented the same way in dir and parent.
+func hasParentDir(dir, parent string) (bool, error) {
+ // Note that os.Stat + os.SameFile would be more reliable,
+ // but does not work well for arbitrarily nested subdirectories.
+ absDir, err := filepath.Abs(dir)
+ if err != nil {
+ return false, trace.Wrap(err, "cannot get absolute path for directory %s", dir)
+ }
+ absParent, err := filepath.Abs(parent)
+ if err != nil {
+ return false, trace.Wrap(err, "cannot get absolute path for parent directory %s", dir)
+ }
+ sep := string(filepath.Separator)
+ if !strings.HasSuffix(absParent, sep) {
+ absParent += sep
+ }
+ return strings.HasPrefix(absDir, absParent), nil
+}
diff --git a/lib/autoupdate/agent/telemetry_test.go b/lib/autoupdate/agent/telemetry_test.go
new file mode 100644
index 0000000000000..8332657785c04
--- /dev/null
+++ b/lib/autoupdate/agent/telemetry_test.go
@@ -0,0 +1,109 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestHasParentDir(t *testing.T) {
+ tests := []struct {
+ name string
+ path string
+ parent string
+ wantResult bool
+ }{
+ {
+ name: "Has valid parent directory",
+ path: "/opt/teleport/dir/test",
+ parent: "/opt/teleport",
+ wantResult: true,
+ },
+ {
+ name: "Has valid parent directory with slash",
+ path: "/opt/teleport/dir/test",
+ parent: "/opt/teleport/",
+ wantResult: true,
+ },
+ {
+ name: "Parent directory is root",
+ path: "/opt/teleport/dir",
+ parent: "/",
+ wantResult: true,
+ },
+ {
+ name: "Parent is the same as the path",
+ path: "/opt/teleport/dir",
+ parent: "/opt/teleport/dir",
+ wantResult: false,
+ },
+ {
+ name: "Parent the same as the path but without slash",
+ path: "/opt/teleport/dir/",
+ parent: "/opt/teleport/dir",
+ wantResult: false,
+ },
+ {
+ name: "Parent the same as the path but with slash",
+ path: "/opt/teleport/dir",
+ parent: "/opt/teleport/dir/",
+ wantResult: false,
+ },
+ {
+ name: "Parent is substring of the path",
+ path: "/opt/teleport/dir-place",
+ parent: "/opt/teleport/dir",
+ wantResult: false,
+ },
+ {
+ name: "Parent is in path",
+ path: "/opt/teleport",
+ parent: "/opt/teleport/dir",
+ wantResult: false,
+ },
+ {
+ name: "Empty parent",
+ path: "/opt/teleport/dir",
+ parent: "",
+ wantResult: false,
+ },
+ {
+ name: "Empty path",
+ path: "",
+ parent: "/opt/teleport",
+ wantResult: false,
+ },
+ {
+ name: "Both empty",
+ path: "",
+ parent: "",
+ wantResult: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := hasParentDir(tt.path, tt.parent)
+ require.NoError(t, err)
+ require.Equal(t, tt.wantResult, result)
+ })
+ }
+}
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden
index e8773a6d88b7f..6e104086250e3 100644
--- a/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Disable/already_disabled.golden
@@ -2,8 +2,9 @@ version: v1
kind: update_config
spec:
proxy: ""
- group: ""
- url_template: ""
+ path: ""
enabled: false
+ pinned: false
status:
- active_version: ""
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden
index e8773a6d88b7f..6e104086250e3 100644
--- a/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Disable/enabled.golden
@@ -2,8 +2,9 @@ version: v1
kind: update_config
spec:
proxy: ""
- group: ""
- url_template: ""
+ path: ""
enabled: false
+ pinned: false
status:
- active_version: ""
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden
deleted file mode 100644
index e03f369eb1017..0000000000000
--- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_does_not_exist.golden
+++ /dev/null
@@ -1,9 +0,0 @@
-version: v1
-kind: update_config
-spec:
- proxy: localhost
- group: ""
- url_template: ""
- enabled: true
-status:
- active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden
deleted file mode 100644
index b172d858bc55a..0000000000000
--- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_file.golden
+++ /dev/null
@@ -1,9 +0,0 @@
-version: v1
-kind: update_config
-spec:
- proxy: localhost
- group: group
- url_template: https://example.com
- enabled: true
-status:
- active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden
deleted file mode 100644
index bb9ce8b9d8fa8..0000000000000
--- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/config_from_user.golden
+++ /dev/null
@@ -1,9 +0,0 @@
-version: v1
-kind: update_config
-spec:
- proxy: localhost
- group: new-group
- url_template: https://example.com/new
- enabled: true
-status:
- active_version: new-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden
deleted file mode 100644
index e03f369eb1017..0000000000000
--- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/version_already_installed.golden
+++ /dev/null
@@ -1,9 +0,0 @@
-version: v1
-kind: update_config
-spec:
- proxy: localhost
- group: ""
- url_template: ""
- enabled: true
-status:
- active_version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden
new file mode 100644
index 0000000000000..04c32df19ee6f
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/FIPS_and_Enterprise_flags.golden
@@ -0,0 +1,11 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ flags: [Enterprise, FIPS]
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden
new file mode 100644
index 0000000000000..6e104086250e3
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/agpl_requires_base_URL.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden
new file mode 100644
index 0000000000000..07eedcd8cadbb
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_kept_for_validation.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: backup-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden
new file mode 100644
index 0000000000000..fb70091fe8d78
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/backup_version_removed_on_install.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden
new file mode 100644
index 0000000000000..36c71be81e379
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_does_not_exist.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden
new file mode 100644
index 0000000000000..1918a2a3434d1
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_file.golden
@@ -0,0 +1,14 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /path
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden
new file mode 100644
index 0000000000000..25cd918d5af79
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/config_from_user.golden
@@ -0,0 +1,14 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /path
+ group: new-group
+ base_url: https://example.com/new
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: new-version
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden
new file mode 100644
index 0000000000000..fb70091fe8d78
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/defaults.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden
new file mode 100644
index 0000000000000..69b08b9c83c9d
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/insecure_URL.golden
@@ -0,0 +1,11 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ base_url: http://example.com
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden
new file mode 100644
index 0000000000000..6e104086250e3
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/install_error.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden
new file mode 100644
index 0000000000000..0c3dcaac8edbd
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/invalid_metadata.golden
@@ -0,0 +1,10 @@
+version: ""
+kind: ""
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden
new file mode 100644
index 0000000000000..36c71be81e379
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/no_need_to_reload.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden
new file mode 100644
index 0000000000000..36c71be81e379
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/not_started_or_enabled.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden
new file mode 100644
index 0000000000000..fb70091fe8d78
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/override_skip.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden
new file mode 100644
index 0000000000000..067fcf60bd527
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/setup_fails_already_installed.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden b/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden
new file mode 100644
index 0000000000000..36c71be81e379
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Install/version_already_installed.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden
new file mode 100644
index 0000000000000..6e104086250e3
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/not_pinned.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden
new file mode 100644
index 0000000000000..6e104086250e3
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Unpin/pinned.golden
@@ -0,0 +1,10 @@
+version: v1
+kind: update_config
+spec:
+ proxy: ""
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden
new file mode 100644
index 0000000000000..f13285e6d27cb
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/FIPS_and_Enterprise_flags.golden
@@ -0,0 +1,15 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ flags: [Enterprise, FIPS]
+ backup:
+ version: old-version
+ flags: [Enterprise, FIPS]
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden
new file mode 100644
index 0000000000000..d50417064563b
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/agpl_requires_base_URL.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: old-version
+ backup:
+ version: backup-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden
new file mode 100644
index 0000000000000..a153751d1854a
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_is_linked.golden
@@ -0,0 +1,13 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden
new file mode 100644
index 0000000000000..d257cd6c30282
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_kept_when_no_change.golden
@@ -0,0 +1,13 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: backup-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden
new file mode 100644
index 0000000000000..a153751d1854a
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/backup_version_removed_on_install.golden
@@ -0,0 +1,13 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden
new file mode 100644
index 0000000000000..297b00ce4ecf8
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/insecure_URL.golden
@@ -0,0 +1,11 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: http://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden
similarity index 53%
rename from lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden
rename to lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden
index e03f369eb1017..c68ebdf570843 100644
--- a/lib/autoupdate/agent/testdata/TestUpdater_Enable/already_enabled.golden
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/install_error.golden
@@ -2,8 +2,9 @@ version: v1
kind: update_config
spec:
proxy: localhost
- group: ""
- url_template: ""
+ path: /usr/local/bin
enabled: true
+ pinned: false
status:
- active_version: 16.3.0
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden
new file mode 100644
index 0000000000000..a771da1fc9e25
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/invalid_metadata.golden
@@ -0,0 +1,10 @@
+version: ""
+kind: ""
+spec:
+ proxy: localhost
+ path: ""
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: ""
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden
new file mode 100644
index 0000000000000..f29735ca0230b
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/missing_path_during_window.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: ""
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden
new file mode 100644
index 0000000000000..501506cf96c78
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/pinned_version.golden
@@ -0,0 +1,13 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: true
+status:
+ active:
+ version: old-version
+ backup:
+ version: backup-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden
new file mode 100644
index 0000000000000..10b36430ffed1
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/setup_fails.golden
@@ -0,0 +1,15 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: old-version
+ backup:
+ version: backup-version
+ skip:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden
new file mode 100644
index 0000000000000..10b36430ffed1
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/skip_version.golden
@@ -0,0 +1,15 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: old-version
+ backup:
+ version: backup-version
+ skip:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden
new file mode 100644
index 0000000000000..b6b43595a5903
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_during_window.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden
new file mode 100644
index 0000000000000..b6b43595a5903
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_disabled_outside_of_window.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: false
+ pinned: false
+status:
+ active:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden
new file mode 100644
index 0000000000000..d7fb9f94d1354
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_during_window.golden
@@ -0,0 +1,14 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden
new file mode 100644
index 0000000000000..d7fb9f94d1354
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now,_not_started_or_enabled.golden
@@ -0,0 +1,14 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden
new file mode 100644
index 0000000000000..d7fb9f94d1354
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_now.golden
@@ -0,0 +1,14 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden
new file mode 100644
index 0000000000000..a4cac37b8733c
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/updates_enabled_outside_of_window.golden
@@ -0,0 +1,12 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ group: group
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden
new file mode 100644
index 0000000000000..926667a2e7fc0
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_in_window.golden
@@ -0,0 +1,11 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden
new file mode 100644
index 0000000000000..926667a2e7fc0
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_already_installed_outside_of_window.golden
@@ -0,0 +1,11 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
diff --git a/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden
new file mode 100644
index 0000000000000..a153751d1854a
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestUpdater_Update/version_detects_as_linked.golden
@@ -0,0 +1,13 @@
+version: v1
+kind: update_config
+spec:
+ proxy: localhost
+ path: /usr/local/bin
+ base_url: https://example.com
+ enabled: true
+ pinned: false
+status:
+ active:
+ version: 16.3.0
+ backup:
+ version: old-version
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden
new file mode 100644
index 0000000000000..cb09143fe9fdf
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/dropin.golden
@@ -0,0 +1,5 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Service]
+Environment="TELEPORT_UPDATE_CONFIG_FILE=/opt/teleport/default/update.yaml"
+Environment="TELEPORT_UPDATE_INSTALL_DIR=/opt/teleport"
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden
new file mode 100644
index 0000000000000..b5d6a74435cb2
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/needrestart.golden
@@ -0,0 +1 @@
+$nrconf{override_rc}{qr(^teleport\.service)} = 0;
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden
new file mode 100644
index 0000000000000..b71ce08a9c4bb
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/service.golden
@@ -0,0 +1,8 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update service
+
+[Service]
+Type=oneshot
+ExecStart=/usr/local/bin/teleport-update --install-suffix= --install-dir="/opt/teleport" update
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden
new file mode 100644
index 0000000000000..d14a43d679e53
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/no_namespace/timer.golden
@@ -0,0 +1,12 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update timer unit
+
+[Timer]
+OnActiveSec=1m
+OnUnitActiveSec=5m
+RandomizedDelaySec=1m
+
+[Install]
+WantedBy=teleport.service
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden
new file mode 100644
index 0000000000000..dc6445dc6e7f9
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/dropin.golden
@@ -0,0 +1,5 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Service]
+Environment="TELEPORT_UPDATE_CONFIG_FILE=/opt/teleport/test/update.yaml"
+Environment="TELEPORT_UPDATE_INSTALL_DIR=/opt/teleport"
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden
new file mode 100644
index 0000000000000..ad6bd606a74cb
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/needrestart.golden
@@ -0,0 +1 @@
+$nrconf{override_rc}{qr(^teleport_test\.service)} = 0;
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden
new file mode 100644
index 0000000000000..832658a8d5a61
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/service.golden
@@ -0,0 +1,8 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update service
+
+[Service]
+Type=oneshot
+ExecStart=/usr/local/bin/teleport-update --install-suffix=test --install-dir="/opt/teleport" update
diff --git a/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden
new file mode 100644
index 0000000000000..f57a3c08055bc
--- /dev/null
+++ b/lib/autoupdate/agent/testdata/TestWriteConfigFiles/test_namespace/timer.golden
@@ -0,0 +1,12 @@
+# teleport-update
+# DO NOT EDIT THIS FILE
+[Unit]
+Description=Teleport auto-update timer unit
+
+[Timer]
+OnActiveSec=1m
+OnUnitActiveSec=5m
+RandomizedDelaySec=1m
+
+[Install]
+WantedBy=teleport_test.service
diff --git a/lib/autoupdate/agent/updater.go b/lib/autoupdate/agent/updater.go
index 59df5f0b3ba85..5114ef7a223c8 100644
--- a/lib/autoupdate/agent/updater.go
+++ b/lib/autoupdate/agent/updater.go
@@ -23,76 +23,55 @@ import (
"crypto/tls"
"crypto/x509"
"errors"
- "io/fs"
"log/slog"
"net/http"
"os"
+ "os/exec"
"path/filepath"
- "strings"
+ "runtime"
+ "slices"
"time"
- "github.com/google/renameio/v2"
"github.com/gravitational/trace"
- "gopkg.in/yaml.v3"
"github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/api/constants"
+ "github.com/gravitational/teleport/lib/autoupdate"
+ "github.com/gravitational/teleport/lib/client/debug"
libdefaults "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/modules"
libutils "github.com/gravitational/teleport/lib/utils"
)
const (
- // cdnURITemplate is the default template for the Teleport tgz download.
- cdnURITemplate = "https://cdn.teleport.dev/teleport{{if .Enterprise}}-ent{{end}}-v{{.Version}}-{{.OS}}-{{.Arch}}{{if .FIPS}}-fips{{end}}-bin.tar.gz"
+ // BinaryName specifies the name of the updater binary.
+ BinaryName = "teleport-update"
+)
+
+const (
+ // packageSystemDir is the location where packaged Teleport binaries and services are installed.
+ packageSystemDir = "/opt/teleport/system"
// reservedFreeDisk is the minimum required free space left on disk during downloads.
// TODO(sclevine): This value is arbitrary and could be replaced by, e.g., min(1%, 200mb) in the future
// to account for a range of disk sizes.
- reservedFreeDisk = 10_000_000 // 10 MB
+ reservedFreeDisk = 10_000_000
+ // debugSocketFileName is the name of Teleport's debug socket in the data dir.
+ debugSocketFileName = "debug.sock" // 10 MB
)
+// Log keys
const (
- // updateConfigName specifies the name of the file inside versionsDirName containing configuration for the teleport update.
- updateConfigName = "update.yaml"
-
- // UpdateConfig metadata
- updateConfigVersion = "v1"
- updateConfigKind = "update_config"
+ targetKey = "target_version"
+ activeKey = "active_version"
+ backupKey = "backup_version"
+ errorKey = "error"
)
-// UpdateConfig describes the update.yaml file schema.
-type UpdateConfig struct {
- // Version of the configuration file
- Version string `yaml:"version"`
- // Kind of configuration file (always "update_config")
- Kind string `yaml:"kind"`
- // Spec contains user-specified configuration.
- Spec UpdateSpec `yaml:"spec"`
- // Status contains state configuration.
- Status UpdateStatus `yaml:"status"`
-}
-
-// UpdateSpec describes the spec field in update.yaml.
-type UpdateSpec struct {
- // Proxy address
- Proxy string `yaml:"proxy"`
- // Group specifies the update group identifier for the agent.
- Group string `yaml:"group"`
- // URLTemplate for the Teleport tgz download URL.
- URLTemplate string `yaml:"url_template"`
- // Enabled controls whether auto-updates are enabled.
- Enabled bool `yaml:"enabled"`
-}
-
-// UpdateStatus describes the status field in update.yaml.
-type UpdateStatus struct {
- // ActiveVersion is the currently active Teleport version.
- ActiveVersion string `yaml:"active_version"`
-}
-
// NewLocalUpdater returns a new Updater that auto-updates local
// installations of the Teleport agent.
// The AutoUpdater uses an HTTP client with sane defaults for downloads, and
// will not fill disk to within 10 MB of available capacity.
-func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
+func NewLocalUpdater(cfg LocalUpdaterConfig, ns *Namespace) (*Updater, error) {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, trace.Wrap(err)
@@ -112,19 +91,64 @@ func NewLocalUpdater(cfg LocalUpdaterConfig) (*Updater, error) {
if cfg.Log == nil {
cfg.Log = slog.Default()
}
+ if cfg.SystemDir == "" {
+ cfg.SystemDir = packageSystemDir
+ }
+ validator := Validator{Log: cfg.Log}
+ debugClient := debug.NewClient(filepath.Join(ns.dataDir, debugSocketFileName))
return &Updater{
Log: cfg.Log,
Pool: certPool,
InsecureSkipVerify: cfg.InsecureSkipVerify,
- ConfigPath: filepath.Join(cfg.VersionsDir, updateConfigName),
+ UpdateConfigPath: filepath.Join(ns.Dir(), updateConfigName),
+ TeleportConfigPath: ns.configFile,
+ DefaultProxyAddr: ns.defaultProxyAddr,
+ DefaultPathDir: ns.defaultPathDir,
Installer: &LocalInstaller{
- InstallDir: cfg.VersionsDir,
- HTTP: client,
- Log: cfg.Log,
-
+ InstallDir: filepath.Join(ns.Dir(), versionsDirName),
+ TargetServiceFile: ns.serviceFile,
+ SystemBinDir: filepath.Join(cfg.SystemDir, "bin"),
+ SystemServiceFile: filepath.Join(cfg.SystemDir, serviceDir, serviceName),
+ HTTP: client,
+ Log: cfg.Log,
ReservedFreeTmpDisk: reservedFreeDisk,
ReservedFreeInstallDisk: reservedFreeDisk,
+ TransformService: ns.replaceTeleportService,
+ ValidateBinary: validator.IsBinary,
+ Template: autoupdate.DefaultCDNURITemplate,
+ },
+ Process: &SystemdService{
+ ServiceName: filepath.Base(ns.serviceFile),
+ PIDFile: ns.pidFile,
+ Ready: debugClient,
+ Log: cfg.Log,
},
+ ReexecSetup: func(ctx context.Context, pathDir string, reload bool) error {
+ name := filepath.Join(pathDir, BinaryName)
+ if cfg.SelfSetup && runtime.GOOS == constants.LinuxOS {
+ name = "/proc/self/exe"
+ }
+ args := []string{
+ "--install-dir", ns.installDir,
+ "--install-suffix", ns.name,
+ "--log-format", cfg.LogFormat,
+ }
+ if cfg.Debug {
+ args = append(args, "--debug")
+ }
+ args = append(args, "setup", "--path", pathDir)
+ if reload {
+ args = append(args, "--reload")
+ }
+ cmd := exec.CommandContext(ctx, name, args...)
+ cmd.Stderr = os.Stderr
+ cmd.Stdout = os.Stdout
+ cfg.Log.InfoContext(ctx, "Executing new teleport-update binary to update configuration.")
+ defer cfg.Log.InfoContext(ctx, "Finished executing new teleport-update binary.")
+ return trace.Wrap(cmd.Run())
+ },
+ SetupNamespace: ns.Setup,
+ TeardownNamespace: ns.Teardown,
}, nil
}
@@ -138,8 +162,14 @@ type LocalUpdaterConfig struct {
// DownloadTimeout is a timeout for file download requests.
// Defaults to no timeout.
DownloadTimeout time.Duration
- // VersionsDir for installing Teleport (usually /var/lib/teleport/versions).
- VersionsDir string
+ // SystemDir for package-installed Teleport installations (usually /opt/teleport/system).
+ SystemDir string
+ // SelfSetup mode for using the current version of the teleport-update to setup the update service.
+ SelfSetup bool
+ // Debug logs enabled.
+ Debug bool
+ // LogFormat controls the format of logging. Can be either `json` or `text`.
+ LogFormat string
}
// Updater implements the agent-local logic for Teleport agent auto-updates.
@@ -150,192 +180,859 @@ type Updater struct {
Pool *x509.CertPool
// InsecureSkipVerify skips TLS verification.
InsecureSkipVerify bool
- // ConfigPath contains the path to the agent auto-updates configuration.
- ConfigPath string
+ // UpdateConfigPath contains the path to the agent auto-updates configuration.
+ UpdateConfigPath string
+ // TeleportConfig contains the path to Teleport's configuration.
+ TeleportConfigPath string
+ // DefaultProxyAddr contains Teleport's proxy address. This may differ from the updater's.
+ DefaultProxyAddr string
+ // DefaultPathDir contains the default path that Teleport binaries should be installed into.
+ DefaultPathDir string
// Installer manages installations of the Teleport agent.
Installer Installer
+ // Process manages a running instance of Teleport.
+ Process Process
+ // ReexecSetup re-execs teleport-update with the setup command.
+ // This configures the updater service, verifies the installation, and optionally reloads Teleport.
+ ReexecSetup func(ctx context.Context, path string, reload bool) error
+ // SetupNamespace configures the Teleport updater service for the current Namespace.
+ SetupNamespace func(ctx context.Context, path string) error
+ // TeardownNamespace removes all traces of the updater service in the current Namespace, including Teleport.
+ TeardownNamespace func(ctx context.Context) error
}
// Installer provides an API for installing Teleport agents.
type Installer interface {
- // Install the Teleport agent at version from the download template.
- // This function must be idempotent.
- Install(ctx context.Context, version, template string, flags InstallFlags) error
- // Remove the Teleport agent at version.
- // This function must be idempotent.
- Remove(ctx context.Context, version string) error
+ // Install the Teleport agent at revision from the download Template.
+ // If force is true, Install will remove broken revisions.
+ // Install must be idempotent.
+ Install(ctx context.Context, rev Revision, baseURL string, force bool) error
+ // Link the Teleport agent at the specified revision of Teleport into path.
+ // The revert function must restore the previous linking, returning false on any failure.
+ // If force is true, Link will overwrite non-symlinks.
+ // Link must be idempotent. Link's revert function must be idempotent.
+ Link(ctx context.Context, rev Revision, pathDir string, force bool) (revert func(context.Context) bool, err error)
+ // LinkSystem links the system installation of Teleport into the system linking location.
+ // The revert function must restore the previous linking, returning false on any failure.
+ // LinkSystem must be idempotent. LinkSystem's revert function must be idempotent.
+ LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error)
+ // TryLink links the specified revision of Teleport into path.
+ // Unlike Link, TryLink will fail if existing links to other locations are present.
+ // TryLink must be idempotent.
+ TryLink(ctx context.Context, rev Revision, pathDir string) error
+ // TryLinkSystem links the system (package) installation of Teleport into the system linking location.
+ // Unlike LinkSystem, TryLinkSystem will fail if existing links to other locations are present.
+ // TryLinkSystem must be idempotent.
+ TryLinkSystem(ctx context.Context) error
+ // Unlink unlinks the specified revision of Teleport from path.
+ // Unlink must be idempotent.
+ Unlink(ctx context.Context, rev Revision, pathDir string) error
+ // UnlinkSystem unlinks the system (package) installation of Teleport from the system linking location.
+ // UnlinkSystem must be idempotent.
+ UnlinkSystem(ctx context.Context) error
+ // List the installed revisions of Teleport.
+ List(ctx context.Context) (revisions []Revision, err error)
+ // Remove the Teleport agent at revision.
+ // Remove must be idempotent.
+ Remove(ctx context.Context, rev Revision) error
+ // IsLinked returns true if the revision is linked to path.
+ IsLinked(ctx context.Context, rev Revision, pathDir string) (bool, error)
}
-// InstallFlags sets flags for the Teleport installation
-type InstallFlags int
-
-const (
- // FlagEnterprise installs enterprise Teleport
- FlagEnterprise InstallFlags = 1 << iota
- // FlagFIPS installs FIPS Teleport
- FlagFIPS
+var (
+ // ErrLinked is returned when a linked version cannot be operated on.
+ ErrLinked = errors.New("version is linked")
+ // ErrNotNeeded is returned when the operation is not needed.
+ ErrNotNeeded = errors.New("not needed")
+ // ErrNotSupported is returned when the operation is not supported on the platform.
+ ErrNotSupported = errors.New("not supported on this platform")
+ // ErrNoBinaries is returned when no binaries are available to be linked.
+ ErrNoBinaries = errors.New("no binaries available to link")
+ // ErrFilePresent is returned when a file is present.
+ ErrFilePresent = errors.New("file present")
)
+// Process provides an API for interacting with a running Teleport process.
+type Process interface {
+ // Reload must reload the Teleport process as gracefully as possible.
+ // If the process is not healthy after reloading, Reload must return an error.
+ // If the process did not require reloading, Reload must return ErrNotNeeded.
+ // E.g., if the process is not enabled, or it was already reloaded after the last Sync.
+ // If the type implementing Process does not support the system process manager,
+ // Reload must return ErrNotSupported.
+ Reload(ctx context.Context) error
+ // Sync must validate and synchronize process configuration.
+ // After the linked Teleport installation is changed, failure to call Sync without
+ // error before Reload may result in undefined behavior.
+ // If the type implementing Process does not support the system process manager,
+ // Sync must return ErrNotSupported.
+ Sync(ctx context.Context) error
+ // IsEnabled must return true if the Process is configured to run on system boot.
+ // If the type implementing Process does not support the system process manager,
+ // IsEnabled must return ErrNotSupported.
+ IsEnabled(ctx context.Context) (bool, error)
+ // IsActive must return true if the Process is currently running.
+ // If the type implementing Process does not support the system process manager,
+ // IsActive must return ErrNotSupported.
+ IsActive(ctx context.Context) (bool, error)
+ // IsPresent must return true if the Process is installed on the system.
+ // If the type implementing Process does not support the system process manager,
+ // IsPresent must return ErrNotSupported.
+ IsPresent(ctx context.Context) (bool, error)
+}
+
// OverrideConfig contains overrides for individual update operations.
// If validated, these overrides may be persisted to disk.
type OverrideConfig struct {
- // Proxy address, scheme and port optional.
- // Overrides existing value if specified.
- Proxy string
- // Group identifier for updates (e.g., staging)
- // Overrides existing value if specified.
- Group string
- // URLTemplate for the Teleport tgz download URL
- // Overrides existing value if specified.
- URLTemplate string
+ UpdateSpec
+
+ // The fields below override the behavior of
+ // Updater.Install for a single run.
+
// ForceVersion to the specified version.
ForceVersion string
+ // ForceFlags in installed Teleport.
+ ForceFlags autoupdate.InstallFlags
+ // AllowOverwrite of installed binaries.
+ AllowOverwrite bool
}
-// Enable enables agent updates and attempts an initial update.
-// If the initial update succeeds, auto-updates are enabled and the configuration is persisted.
-// Otherwise, the auto-updates configuration is not changed.
+func deref[T any](ptr *T) T {
+ if ptr != nil {
+ return *ptr
+ }
+ var t T
+ return t
+}
+
+func toPtr[T any](t T) *T {
+ return &t
+}
+
+// Install attempts an initial installation of Teleport.
+// If the initial installation succeeds, the override configuration is persisted.
+// Otherwise, the configuration is not changed.
// This function is idempotent.
-func (u *Updater) Enable(ctx context.Context, override OverrideConfig) error {
+func (u *Updater) Install(ctx context.Context, override OverrideConfig) error {
// Read configuration from update.yaml and override any new values passed as flags.
- cfg, err := u.readConfig(u.ConfigPath)
+ cfg, err := readConfig(u.UpdateConfigPath)
if err != nil {
- return trace.Errorf("failed to read %s: %w", updateConfigName, err)
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
+ }
+ if err := validateConfigSpec(&cfg.Spec, override); err != nil {
+ return trace.Wrap(err)
+ }
+
+ if cfg.Spec.Proxy == "" {
+ cfg.Spec.Proxy = u.DefaultProxyAddr
+ } else if u.DefaultProxyAddr != "" &&
+ !sameProxies(cfg.Spec.Proxy, u.DefaultProxyAddr) {
+ u.Log.WarnContext(ctx, "Proxy specified in update.yaml does not match teleport.yaml. Unexpected updates may occur.", "update_proxy", cfg.Spec.Proxy, "teleport_proxy", u.DefaultProxyAddr)
+ }
+ if cfg.Spec.Path == "" {
+ cfg.Spec.Path = u.DefaultPathDir
}
- if override.Proxy != "" {
- cfg.Spec.Proxy = override.Proxy
+
+ active := cfg.Status.Active
+ skip := deref(cfg.Status.Skip)
+
+ // Lookup target version from the proxy.
+
+ resp, err := u.find(ctx, cfg)
+ if err != nil {
+ return trace.Wrap(err)
}
- if override.Group != "" {
- cfg.Spec.Group = override.Group
+ targetVersion := resp.Target.Version
+ targetFlags := resp.Target.Flags
+ targetFlags |= override.ForceFlags
+ if override.ForceVersion != "" {
+ targetVersion = override.ForceVersion
}
- if override.URLTemplate != "" {
- cfg.Spec.URLTemplate = override.URLTemplate
+ target := NewRevision(targetVersion, targetFlags)
+
+ switch target.Version {
+ case "":
+ return trace.Errorf("agent version not available from Teleport cluster")
+ case skip.Version:
+ u.Log.WarnContext(ctx, "Target version was previously marked as broken. Retrying update.", targetKey, target, activeKey, active)
+ default:
+ u.Log.InfoContext(ctx, "Initiating installation.", targetKey, target, activeKey, active)
}
- cfg.Spec.Enabled = true
- if err := validateUpdatesSpec(&cfg.Spec); err != nil {
+
+ if err := u.update(ctx, cfg, target, override.AllowOverwrite, resp.AGPL); err != nil {
+ if errors.Is(err, ErrFilePresent) && !override.AllowOverwrite {
+ u.Log.WarnContext(ctx, "Use --overwrite to force removal of existing binaries installed via script.")
+ u.Log.WarnContext(ctx, "If a teleport rpm or deb package is installed, upgrade it to the latest version and retry. DO NOT USE --overwrite.")
+ }
return trace.Wrap(err)
}
+ if target.Version == skip.Version {
+ cfg.Status.Skip = nil
+ }
- // Lookup target version from the proxy.
- addr, err := libutils.ParseAddr(cfg.Spec.Proxy)
+ // Only write the configuration file if the initial update succeeds.
+ // Note: skip_version is never set on failed enable, only failed update.
+
+ if err := writeConfig(u.UpdateConfigPath, cfg); err != nil {
+ return trace.Wrap(err, "failed to write %s", updateConfigName)
+ }
+ u.Log.InfoContext(ctx, "Configuration updated.")
+ return trace.Wrap(u.notices(ctx))
+}
+
+// sameProxies returns true if both proxies addresses are the same.
+// Note that the port is defaulted to 443, which is different from teleport.yaml's default.
+func sameProxies(a, b string) bool {
+ const defaultPort = 443
+ if a == b {
+ return true
+ }
+ addrA, err := libutils.ParseAddr(a)
if err != nil {
- return trace.Errorf("failed to parse proxy server address: %w", err)
- }
-
- desiredVersion := override.ForceVersion
- if desiredVersion == "" {
- resp, err := webclient.Find(&webclient.Config{
- Context: ctx,
- ProxyAddr: addr.Addr,
- Insecure: u.InsecureSkipVerify,
- Timeout: 30 * time.Second,
- //Group: cfg.Spec.Group, // TODO(sclevine): add web API for verssion
- Pool: u.Pool,
- })
- if err != nil {
- return trace.Errorf("failed to request version from proxy: %w", err)
+ return false
+ }
+ addrB, err := libutils.ParseAddr(b)
+ if err != nil {
+ return false
+ }
+ return addrA.Host() == addrB.Host() &&
+ addrA.Port(defaultPort) == addrB.Port(defaultPort)
+}
+
+// Remove removes everything created by the updater for the given namespace.
+// Before attempting this, Remove attempts to gracefully recover the system-packaged version of Teleport (if present).
+// This function is idempotent.
+func (u *Updater) Remove(ctx context.Context, force bool) error {
+ cfg, err := readConfig(u.UpdateConfigPath)
+ if err != nil {
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
+ }
+ if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil {
+ return trace.Wrap(err)
+ }
+ active := cfg.Status.Active
+ if active.Version == "" {
+ u.Log.InfoContext(ctx, "No installation of Teleport managed by the updater. Removing updater configuration.")
+ if err := u.TeardownNamespace(ctx); err != nil {
+ return trace.Wrap(err)
}
- desiredVersion, _ = "16.3.0", resp // TODO(sclevine): add web API for version
- //desiredVersion := resp.AutoUpdate.AgentVersion
+ u.Log.InfoContext(ctx, "Automatic update configuration for Teleport successfully uninstalled.")
+ return nil
}
- if desiredVersion == "" {
- return trace.Errorf("agent version not available from Teleport cluster")
+ // Do not link system package installation if the installation we are removing
+ // is not installed into /usr/local/bin.
+ if filepath.Clean(cfg.Spec.Path) != filepath.Clean(defaultPathDir) {
+ return u.removeWithoutSystem(ctx, cfg, force)
}
- // If the active version and target don't match, kick off upgrade.
- template := cfg.Spec.URLTemplate
- if template == "" {
- template = cdnURITemplate
+ revert, err := u.Installer.LinkSystem(ctx)
+ if errors.Is(err, ErrNoBinaries) {
+ return u.removeWithoutSystem(ctx, cfg, force)
}
- err = u.Installer.Install(ctx, desiredVersion, template, 0) // TODO(sclevine): add web API for flags
if err != nil {
+ return trace.Wrap(err, "failed to link")
+ }
+
+ u.Log.InfoContext(ctx, "Updater-managed installation of Teleport detected. Restoring packaged version of Teleport before removing.")
+
+ revertConfig := func(ctx context.Context) bool {
+ if ok := revert(ctx); !ok {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.")
+ return false
+ }
+ if err := u.Process.Sync(ctx); err != nil {
+ u.Log.ErrorContext(ctx, "Failed to revert systemd configuration after failed restart.", errorKey, err)
+ return false
+ }
+ return true
+ }
+
+ // Sync systemd.
+
+ err = u.Process.Sync(ctx)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("sync canceled")
+ }
+ if errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Not syncing systemd configuration because systemd is not running.")
+ } else if err != nil {
+ // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.")
+ if ok := revertConfig(ctx); ok {
+ u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.")
+ }
+ return trace.Wrap(err, "failed to validate configuration for system package version of Teleport")
+ }
+
+ // Restart Teleport.
+
+ u.Log.InfoContext(ctx, "Teleport package successfully restored.")
+ err = u.Process.Reload(ctx)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("reload canceled")
+ }
+ if err != nil &&
+ !errors.Is(err, ErrNotNeeded) && // no output if restart not needed
+ !errors.Is(err, ErrNotSupported) { // already logged above for Sync
+
+ // If reloading Teleport at the new version fails, revert and reload.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.")
+ if ok := revertConfig(ctx); ok {
+ if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) {
+ u.Log.ErrorContext(ctx, "Failed to reload Teleport after reverting. Installation likely broken.", errorKey, err)
+ } else {
+ u.Log.WarnContext(ctx, "Teleport updater detected an error with the new installation and successfully reverted it.")
+ }
+ }
+ return trace.Wrap(err, "failed to start system package version of Teleport")
+ }
+ u.Log.InfoContext(ctx, "Auto-updating Teleport removed and replaced by Teleport package.", "version", active)
+ if err := u.TeardownNamespace(ctx); err != nil {
return trace.Wrap(err)
}
- if cfg.Status.ActiveVersion != desiredVersion {
- u.Log.InfoContext(ctx, "Target version successfully installed.", "version", desiredVersion)
+ u.Log.InfoContext(ctx, "Auto-update configuration for Teleport successfully uninstalled.")
+ return nil
+}
+
+func (u *Updater) removeWithoutSystem(ctx context.Context, cfg *UpdateConfig, force bool) error {
+ if !force {
+ u.Log.ErrorContext(ctx, "No packaged installation of Teleport was found, and --force was not passed. Refusing to remove Teleport from this system.")
+ return trace.Errorf("unable to remove Teleport completely without --force")
} else {
- u.Log.InfoContext(ctx, "Target version successfully validated.", "version", desiredVersion)
+ u.Log.WarnContext(ctx, "No packaged installation of Teleport was found, and --force was passed. Teleport will be removed from this system.")
}
- cfg.Status.ActiveVersion = desiredVersion
-
- // Always write the configuration file if enable succeeds.
- if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
- return trace.Errorf("failed to write %s: %w", updateConfigName, err)
+ u.Log.InfoContext(ctx, "Updater-managed installation of Teleport detected. Attempting to unlink and remove.")
+ ok, err := isActiveOrEnabled(ctx, u.Process)
+ if err != nil && !errors.Is(err, ErrNotSupported) {
+ return trace.Wrap(err)
}
- u.Log.InfoContext(ctx, "Configuration updated.")
+ if ok {
+ return trace.Errorf("refusing to remove active installation of Teleport, please stop and disable Teleport first")
+ }
+ if err := u.Installer.Unlink(ctx, cfg.Status.Active, cfg.Spec.Path); err != nil {
+ return trace.Wrap(err)
+ }
+ u.Log.InfoContext(ctx, "Teleport uninstalled.", "version", cfg.Status.Active)
+ if err := u.TeardownNamespace(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ u.Log.InfoContext(ctx, "Automatic update configuration for Teleport successfully uninstalled.")
return nil
}
-func validateUpdatesSpec(spec *UpdateSpec) error {
- if spec.URLTemplate != "" &&
- !strings.HasPrefix(strings.ToLower(spec.URLTemplate), "https://") {
- return trace.Errorf("Teleport download URL must use TLS (https://)")
+// isActiveOrEnabled returns true if the service is active or enabled.
+func isActiveOrEnabled(ctx context.Context, s Process) (bool, error) {
+ enabled, err := s.IsEnabled(ctx)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if enabled {
+ return true, nil
+ }
+ active, err := s.IsActive(ctx)
+ if err != nil {
+ return false, trace.Wrap(err)
}
+ if active {
+ return true, nil
+ }
+ return false, nil
+}
- if spec.Proxy == "" {
- return trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName)
+// Status returns all available local and remote fields related to agent auto-updates.
+// Status is safe to run concurrently with other Updater commands.
+func (u *Updater) Status(ctx context.Context) (Status, error) {
+ var out Status
+ // Read configuration from update.yaml.
+ cfg, err := readConfig(u.UpdateConfigPath)
+ if err != nil {
+ return out, trace.Wrap(err, "failed to read %s", updateConfigName)
}
- return nil
+ if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil {
+ return out, trace.Wrap(err)
+ }
+ out.UpdateSpec = cfg.Spec
+ out.UpdateStatus = cfg.Status
+
+ // Lookup target version from the proxy.
+ resp, err := u.find(ctx, cfg)
+ if err != nil {
+ return out, trace.Wrap(err)
+ }
+ out.FindResp = resp
+ return out, nil
}
// Disable disables agent auto-updates.
// This function is idempotent.
func (u *Updater) Disable(ctx context.Context) error {
- cfg, err := u.readConfig(u.ConfigPath)
+ cfg, err := readConfig(u.UpdateConfigPath)
if err != nil {
- return trace.Errorf("failed to read %s: %w", updateConfigName, err)
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
}
if !cfg.Spec.Enabled {
u.Log.InfoContext(ctx, "Automatic updates already disabled.")
return nil
}
cfg.Spec.Enabled = false
- if err := u.writeConfig(u.ConfigPath, cfg); err != nil {
- return trace.Errorf("failed to write %s: %w", updateConfigName, err)
+ if err := writeConfig(u.UpdateConfigPath, cfg); err != nil {
+ return trace.Wrap(err, "failed to write %s", updateConfigName)
}
return nil
}
-// readConfig reads UpdateConfig from a file.
-func (*Updater) readConfig(path string) (*UpdateConfig, error) {
- f, err := os.Open(path)
- if errors.Is(err, fs.ErrNotExist) {
- return &UpdateConfig{
- Version: updateConfigVersion,
- Kind: updateConfigKind,
- }, nil
+// Unpin allows the current version to be changed by Update.
+// This function is idempotent.
+func (u *Updater) Unpin(ctx context.Context) error {
+ cfg, err := readConfig(u.UpdateConfigPath)
+ if err != nil {
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
+ }
+ if !cfg.Spec.Pinned {
+ u.Log.InfoContext(ctx, "Current version not pinned.", activeKey, cfg.Status.Active)
+ return nil
}
+ cfg.Spec.Pinned = false
+ if err := writeConfig(u.UpdateConfigPath, cfg); err != nil {
+ return trace.Wrap(err, "failed to write %s", updateConfigName)
+ }
+ return nil
+}
+
+// Update initiates an agent update.
+// If the update succeeds, the new installed version is marked as active.
+// Otherwise, the auto-updates configuration is not changed.
+// Unlike Enable, Update will not validate or repair the current version.
+// This function is idempotent.
+func (u *Updater) Update(ctx context.Context, now bool) error {
+ // Read configuration from update.yaml and override any new values passed as flags.
+ cfg, err := readConfig(u.UpdateConfigPath)
if err != nil {
- return nil, trace.Errorf("failed to open: %w", err)
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
}
- defer f.Close()
- var cfg UpdateConfig
- if err := yaml.NewDecoder(f).Decode(&cfg); err != nil {
- return nil, trace.Errorf("failed to parse: %w", err)
+ if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil {
+ return trace.Wrap(err)
}
- if k := cfg.Kind; k != updateConfigKind {
- return nil, trace.Errorf("invalid kind %q", k)
+ if u.DefaultProxyAddr != "" &&
+ !sameProxies(cfg.Spec.Proxy, u.DefaultProxyAddr) {
+ u.Log.WarnContext(ctx, "Proxy specified in update.yaml does not match teleport.yaml. Unexpected updates may occur.", "update_proxy", cfg.Spec.Proxy, "teleport_proxy", u.DefaultProxyAddr)
}
- if v := cfg.Version; v != updateConfigVersion {
- return nil, trace.Errorf("invalid version %q", v)
+
+ active := cfg.Status.Active
+ skip := deref(cfg.Status.Skip)
+ if !cfg.Spec.Enabled {
+ u.Log.InfoContext(ctx, "Automatic updates disabled.", activeKey, active)
+ return nil
}
- return &cfg, nil
-}
-// writeConfig writes UpdateConfig to a file atomically, ensuring the file cannot be corrupted.
-func (*Updater) writeConfig(filename string, cfg *UpdateConfig) error {
- opts := []renameio.Option{
- renameio.WithPermissions(0755),
- renameio.WithExistingPermissions(),
+ if cfg.Spec.Path == "" {
+ return trace.Errorf("failed to read destination path for binary links from %s", updateConfigName)
}
- t, err := renameio.NewPendingFile(filename, opts...)
+
+ resp, err := u.find(ctx, cfg)
if err != nil {
return trace.Wrap(err)
}
- defer t.Cleanup()
- err = yaml.NewEncoder(t).Encode(cfg)
+ target := resp.Target
+
+ if cfg.Spec.Pinned {
+ switch target {
+ case active:
+ u.Log.InfoContext(ctx, "Teleport is up-to-date. Installation is pinned to prevent future updates.", activeKey, active)
+ default:
+ u.Log.InfoContext(ctx, "Teleport version is pinned. Skipping update.", targetKey, target, activeKey, active)
+ }
+ return nil
+ }
+
+ // If a version fails and is marked skip, we ignore any edition changes as well.
+ // If a cluster is broadcasting a version that failed to start, changing ent/fips is unlikely to fix the issue.
+
+ if !resp.InWindow && !now {
+ switch {
+ case target.Version == "":
+ u.Log.WarnContext(ctx, "Cannot determine target agent version. Waiting for both version and update window.")
+ case target == active:
+ u.Log.InfoContext(ctx, "Teleport is up-to-date. Update window is not active.", activeKey, active)
+ case target.Version == skip.Version:
+ u.Log.InfoContext(ctx, "Update available, but the new version is marked as broken. Update window is not active.", targetKey, target, activeKey, active)
+ default:
+ u.Log.InfoContext(ctx, "Update available, but update window is not active.", targetKey, target, activeKey, active)
+ }
+ return nil
+ }
+
+ switch {
+ case target.Version == "":
+ if resp.InWindow {
+ u.Log.ErrorContext(ctx, "Update window is active, but target version is not available.", activeKey, active)
+ }
+ return trace.Errorf("target version missing")
+ case target == active:
+ if resp.InWindow {
+ u.Log.InfoContext(ctx, "Teleport is up-to-date. Update window is active, but no action is needed.", activeKey, active)
+ } else {
+ u.Log.InfoContext(ctx, "Teleport is up-to-date. No action is needed.", activeKey, active)
+ }
+ return nil
+ case target.Version == skip.Version:
+ u.Log.InfoContext(ctx, "Update available, but the new version is marked as broken. Skipping update.", targetKey, target, activeKey, active)
+ return nil
+ default:
+ u.Log.InfoContext(ctx, "Update available. Initiating update.", targetKey, target, activeKey, active)
+ }
+ if !now {
+ time.Sleep(resp.Jitter)
+ }
+
+ updateErr := u.update(ctx, cfg, target, false, resp.AGPL)
+ writeErr := writeConfig(u.UpdateConfigPath, cfg)
+ if writeErr != nil {
+ writeErr = trace.Wrap(writeErr, "failed to write %s", updateConfigName)
+ } else {
+ u.Log.InfoContext(ctx, "Configuration updated.")
+ }
+ // Show notices last
+ if updateErr == nil && now {
+ updateErr = u.notices(ctx)
+ }
+ return trace.NewAggregate(updateErr, writeErr)
+}
+
+func (u *Updater) find(ctx context.Context, cfg *UpdateConfig) (FindResp, error) {
+ if cfg.Spec.Proxy == "" {
+ return FindResp{}, trace.Errorf("Teleport proxy URL must be specified with --proxy or present in %s", updateConfigName)
+ }
+ addr, err := libutils.ParseAddr(cfg.Spec.Proxy)
+ if err != nil {
+ return FindResp{}, trace.Wrap(err, "failed to parse proxy server address")
+ }
+ resp, err := webclient.Find(&webclient.Config{
+ Context: ctx,
+ ProxyAddr: addr.Addr,
+ Insecure: u.InsecureSkipVerify,
+ Timeout: 30 * time.Second,
+ UpdateGroup: cfg.Spec.Group,
+ Pool: u.Pool,
+ })
+ if err != nil {
+ return FindResp{}, trace.Wrap(err, "failed to request version from proxy")
+ }
+ var flags autoupdate.InstallFlags
+ var agpl bool
+ switch resp.Edition {
+ case modules.BuildEnterprise:
+ flags |= autoupdate.FlagEnterprise
+ case modules.BuildCommunity:
+ case modules.BuildOSS:
+ agpl = true
+ default:
+ agpl = true
+ u.Log.WarnContext(ctx, "Unknown edition detected, defaulting to OSS.", "edition", resp.Edition)
+ }
+ if resp.FIPS {
+ flags |= autoupdate.FlagFIPS
+ }
+ jitterSec := resp.AutoUpdate.AgentUpdateJitterSeconds
+ return FindResp{
+ Target: NewRevision(resp.AutoUpdate.AgentVersion, flags),
+ InWindow: resp.AutoUpdate.AgentAutoUpdate,
+ Jitter: time.Duration(jitterSec) * time.Second,
+ AGPL: agpl,
+ }, nil
+}
+
+func (u *Updater) removeRevision(ctx context.Context, cfg *UpdateConfig, rev Revision) error {
+ linked, err := u.Installer.IsLinked(ctx, rev, cfg.Spec.Path)
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if linked")
+ }
+ if linked {
+ return trace.Wrap(ErrLinked, "refusing to remove")
+ }
+ return trace.Wrap(u.Installer.Remove(ctx, rev))
+}
+
+func (u *Updater) update(ctx context.Context, cfg *UpdateConfig, target Revision, force, agpl bool) error {
+ baseURL := cfg.Spec.BaseURL
+ if baseURL == "" {
+ if agpl {
+ return trace.Errorf("--base-url flag must be specified for AGPL edition of Teleport")
+ }
+ baseURL = autoupdate.DefaultBaseURL
+ }
+
+ active := cfg.Status.Active
+ backup := deref(cfg.Status.Backup)
+ switch backup {
+ case Revision{}, target, active:
+ default:
+ if target == active {
+ // Keep backup version if we are only verifying active version
+ break
+ }
+ err := u.removeRevision(ctx, cfg, backup)
+ if err != nil {
+ // this could happen if it was already removed due to a failed installation
+ u.Log.WarnContext(ctx, "Failed to remove backup version of Teleport before new install.", errorKey, err, backupKey, backup)
+ }
+ }
+
+ // Install and link the desired version (or validate existing installation)
+
+ linked, err := u.Installer.IsLinked(ctx, target, cfg.Spec.Path)
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if linked")
+ }
+ err = u.Installer.Install(ctx, target, baseURL, !linked)
+ if err != nil {
+ return trace.Wrap(err, "failed to install")
+ }
+
+ // If the target version has fewer binaries, this will leave old binaries linked.
+ // This may prevent the installation from being removed.
+ // Cleanup logic at the end of this function will ensure that they are removed
+ // eventually.
+
+ revert, err := u.Installer.Link(ctx, target, cfg.Spec.Path, force)
+ if err != nil {
+ return trace.Wrap(err, "failed to link")
+ }
+
+ // If we fail to revert after this point, the next update/enable will
+ // fix the link to restore the active version.
+
+ revertConfig := func(ctx context.Context) bool {
+ if target.Version != "" {
+ cfg.Status.Skip = toPtr(target)
+ }
+ if force {
+ u.Log.ErrorContext(ctx, "Unable to revert Teleport symlinks in overwrite mode. Installation likely broken.")
+ return false
+ }
+ if ok := revert(ctx); !ok {
+ u.Log.ErrorContext(ctx, "Failed to revert Teleport symlinks. Installation likely broken.")
+ return false
+ }
+ if err := u.SetupNamespace(ctx, cfg.Spec.Path); err != nil {
+ u.Log.ErrorContext(ctx, "Failed to revert configuration after failed restart.", errorKey, err)
+ return false
+ }
+ return true
+ }
+
+ if cfg.Status.Active != target {
+ err := u.ReexecSetup(ctx, cfg.Spec.Path, true)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("check canceled")
+ }
+ if err != nil {
+ // If reloading Teleport at the new version fails, revert and reload.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to failed restart.")
+ if ok := revertConfig(ctx); ok {
+ if err := u.Process.Reload(ctx); err != nil && !errors.Is(err, ErrNotNeeded) {
+ u.Log.ErrorContext(ctx, "Failed to reload Teleport after reverting. Installation likely broken.", errorKey, err)
+ } else {
+ u.Log.WarnContext(ctx, "Teleport updater detected an error with the new installation and successfully reverted it.")
+ }
+ }
+ return trace.Wrap(err, "failed to start new version %s of Teleport", target)
+ }
+ u.Log.InfoContext(ctx, "Target version successfully installed.", targetKey, target)
+
+ if r := cfg.Status.Active; r.Version != "" {
+ cfg.Status.Backup = toPtr(r)
+ }
+ cfg.Status.Active = target
+ } else {
+ err := u.ReexecSetup(ctx, cfg.Spec.Path, false)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("check canceled")
+ }
+ if err != nil {
+ // If sync fails, we may have left the host in a bad state, so we revert linking and re-Sync.
+ u.Log.ErrorContext(ctx, "Reverting symlinks due to invalid configuration.")
+ if ok := revertConfig(ctx); ok {
+ u.Log.WarnContext(ctx, "Teleport updater encountered a configuration error and successfully reverted the installation.")
+ }
+ return trace.Wrap(err, "failed to validate new version %s of Teleport", target)
+ }
+ u.Log.InfoContext(ctx, "Target version successfully validated.", targetKey, target)
+ }
+ if r := deref(cfg.Status.Backup); r.Version != "" {
+ u.Log.InfoContext(ctx, "Backup version set.", backupKey, r)
+ }
+
+ return trace.Wrap(u.cleanup(ctx, cfg, []Revision{
+ target, active, backup,
+ }))
+}
+
+// Setup writes updater configuration and verifies the Teleport installation.
+// If restart is true, Setup also restarts Teleport.
+// Setup is safe to run concurrently with other Updater commands.
+func (u *Updater) Setup(ctx context.Context, path string, restart bool) error {
+ // Setup teleport-updater configuration and sync systemd.
+
+ err := u.SetupNamespace(ctx, path)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("sync canceled")
+ }
+ if err != nil {
+ return trace.Wrap(err, "failed to setup updater")
+ }
+
+ present, err := u.Process.IsPresent(ctx)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("config check canceled")
+ }
+ if errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Skipping all systemd setup because systemd is not running.")
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if new version of Teleport has an installed systemd service")
+ }
+ if !present {
+ return trace.Errorf("cannot find systemd service for new version of Teleport, check SELinux settings")
+ }
+
+ // Restart Teleport if necessary.
+
+ if restart {
+ err = u.Process.Reload(ctx)
+ if errors.Is(err, context.Canceled) {
+ return trace.Errorf("reload canceled")
+ }
+ if err != nil &&
+ !errors.Is(err, ErrNotNeeded) { // skip if not needed
+ return trace.Wrap(err, "failed to reload Teleport")
+ }
+ }
+ return nil
+}
+
+// notices displays final notices after install or update.
+func (u *Updater) notices(ctx context.Context) error {
+ enabled, err := u.Process.IsEnabled(ctx)
+ if errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Teleport is installed, but systemd is not present to start it.")
+ u.Log.WarnContext(ctx, "After configuring teleport.yaml, your system must also be configured to start Teleport.")
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err, "failed to query Teleport systemd enabled status")
+ }
+ active, err := u.Process.IsActive(ctx)
+ if err != nil {
+ return trace.Wrap(err, "failed to query Teleport systemd active status")
+ }
+ if !enabled && active {
+ u.Log.WarnContext(ctx, "Teleport is installed and started, but not configured to start on boot.")
+ u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can enable it with: systemctl enable teleport")
+ }
+ if !active && enabled {
+ u.Log.WarnContext(ctx, "Teleport is installed and enabled at boot, but not running.")
+ u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can start it with: systemctl start teleport")
+ }
+ if !active && !enabled {
+ u.Log.WarnContext(ctx, "Teleport is installed, but not running or enabled at boot.")
+ u.Log.WarnContext(ctx, "After configuring teleport.yaml, you can enable and start it with: systemctl enable teleport --now")
+ }
+
+ return nil
+}
+
+// cleanup orphan installations
+func (u *Updater) cleanup(ctx context.Context, cfg *UpdateConfig, keep []Revision) error {
+ revs, err := u.Installer.List(ctx)
+ if err != nil {
+ u.Log.ErrorContext(ctx, "Failed to read installed versions.", errorKey, err)
+ return nil
+ }
+ if len(revs) < 3 {
+ return nil
+ }
+ u.Log.WarnContext(ctx, "More than two versions of Teleport are installed. Removing unused versions.", "count", len(revs))
+ for _, v := range revs {
+ if v.Version == "" || slices.Contains(keep, v) {
+ continue
+ }
+ err := u.removeRevision(ctx, cfg, v)
+ if errors.Is(err, ErrLinked) {
+ u.Log.WarnContext(ctx, "Refusing to remove version with orphan links.", "version", v)
+ continue
+ }
+ if err != nil {
+ u.Log.WarnContext(ctx, "Failed to remove unused version of Teleport.", errorKey, err, "version", v)
+ continue
+ }
+ u.Log.WarnContext(ctx, "Deleted unused version of Teleport.", "version", v)
+ }
+ return nil
+}
+
+// LinkPackage creates links from the system (package) installation of Teleport, if they are needed.
+// LinkPackage returns nil and warns if an auto-updates version is already linked, but auto-updates is disabled.
+// LinkPackage returns an error only if an unknown version of Teleport is present (e.g., manually copied files).
+// This function is idempotent.
+func (u *Updater) LinkPackage(ctx context.Context) error {
+ cfg, err := readConfig(u.UpdateConfigPath)
if err != nil {
+ return trace.Wrap(err, "failed to read %s", updateConfigName)
+ }
+ if err := validateConfigSpec(&cfg.Spec, OverrideConfig{}); err != nil {
return trace.Wrap(err)
}
- return trace.Wrap(t.CloseAtomicallyReplace())
+ active := cfg.Status.Active
+ if cfg.Spec.Enabled {
+ u.Log.InfoContext(ctx, "Automatic updates is enabled. Skipping system package link.", activeKey, active)
+ return nil
+ }
+ if cfg.Spec.Pinned {
+ u.Log.InfoContext(ctx, "Automatic update version is pinned. Skipping system package link.", activeKey, active)
+ return nil
+ }
+ // If an active version is set, but auto-updates is disabled, try to link the system installation in case the config is stale.
+ // If any links are present, this will return ErrLinked and not create any system links.
+ // This state is important to log as a warning,
+ if err := u.Installer.TryLinkSystem(ctx); errors.Is(err, ErrLinked) {
+ u.Log.WarnContext(ctx, "Automatic updates is disabled, but a non-package version of Teleport is linked.", activeKey, active)
+ return nil
+ } else if err != nil {
+ return trace.Wrap(err, "failed to link system package installation")
+ }
+
+ // If syncing succeeds, ensure the installed systemd service can be found via systemctl.
+ // SELinux contexts can interfere with systemctl's ability to read service files.
+ if err := u.Process.Sync(ctx); errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Systemd is not installed. Skipping sync.")
+ } else if err != nil {
+ return trace.Wrap(err, "failed to sync systemd configuration")
+ } else {
+ present, err := u.Process.IsPresent(ctx)
+ if err != nil {
+ return trace.Wrap(err, "failed to determine if Teleport has an installed systemd service")
+ }
+ if !present {
+ return trace.Errorf("cannot find systemd service for Teleport, check SELinux settings")
+ }
+ }
+ u.Log.InfoContext(ctx, "Successfully linked system package installation.")
+ return nil
+}
+
+// UnlinkPackage removes links from the system (package) installation of Teleport, if they are present.
+// This function is idempotent.
+func (u *Updater) UnlinkPackage(ctx context.Context) error {
+ if err := u.Installer.UnlinkSystem(ctx); err != nil {
+ return trace.Wrap(err, "failed to unlink system package installation")
+ }
+ if err := u.Process.Sync(ctx); errors.Is(err, ErrNotSupported) {
+ u.Log.WarnContext(ctx, "Systemd is not installed. Skipping sync.")
+ } else if err != nil {
+ return trace.Wrap(err, "failed to sync systemd configuration")
+ }
+ u.Log.InfoContext(ctx, "Successfully unlinked system package installation.")
+ return nil
}
diff --git a/lib/autoupdate/agent/updater_test.go b/lib/autoupdate/agent/updater_test.go
index 5ec93b43be9cc..cc789c20bca08 100644
--- a/lib/autoupdate/agent/updater_test.go
+++ b/lib/autoupdate/agent/updater_test.go
@@ -20,12 +20,14 @@ package agent
import (
"context"
+ "encoding/json"
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
+ "slices"
"strings"
"testing"
@@ -33,6 +35,8 @@ import (
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
+ "github.com/gravitational/teleport/api/client/webclient"
+ "github.com/gravitational/teleport/lib/autoupdate"
"github.com/gravitational/teleport/lib/utils/testutils/golden"
)
@@ -79,10 +83,1154 @@ func TestUpdater_Disable(t *testing.T) {
}
for _, tt := range tests {
- tt := tt
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ ns := &Namespace{installDir: dir}
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+
+ err = updater.Disable(context.Background())
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+
+ data, err := os.ReadFile(cfgPath)
+
+ // If no config is present, disable should not create it
+ if tt.cfg == nil {
+ require.ErrorIs(t, err, os.ErrNotExist)
+ return
+ }
+ require.NoError(t, err)
+
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+func TestUpdater_Unpin(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ errMatch string
+ }{
+ {
+ name: "pinned",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Pinned: true,
+ },
+ },
+ },
+ {
+ name: "not pinned",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Pinned: false,
+ },
+ },
+ },
+ {
+ name: "config does not exist",
+ },
+ {
+ name: "invalid metadata",
+ cfg: &UpdateConfig{
+ Spec: UpdateSpec{
+ Enabled: true,
+ },
+ },
+ errMatch: "invalid",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+ ns := &Namespace{installDir: dir}
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+
+ err = updater.Unpin(context.Background())
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.NoError(t, err)
+
+ data, err := os.ReadFile(cfgPath)
+
+ // If no config is present, disable should not create it
+ if tt.cfg == nil {
+ require.ErrorIs(t, err, os.ErrNotExist)
+ return
+ }
+ require.NoError(t, err)
+
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+func TestUpdater_Update(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ flags autoupdate.InstallFlags
+ inWindow bool
+ now bool
+ agpl bool
+ installErr error
+ setupErr error
+ reloadErr error
+ notActive bool
+ notEnabled bool
+ linkedRevisions []Revision
+
+ removedRevisions []Revision
+ installedRevision Revision
+ installedBaseURL string
+ linkedRevision Revision
+ requestGroup string
+ reloadCalls int
+ revertCalls int
+ setupCalls int
+ restarted bool
+ errMatch string
+ }{
+ {
+ name: "updates enabled during window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ inWindow: true,
+
+ removedRevisions: []Revision{NewRevision("unknown-version", 0)},
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ requestGroup: "group",
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "updates enabled now",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ now: true,
+
+ removedRevisions: []Revision{NewRevision("unknown-version", 0)},
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ requestGroup: "group",
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "updates enabled now, not started or enabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ now: true,
+ notEnabled: true,
+ notActive: true,
+
+ removedRevisions: []Revision{NewRevision("unknown-version", 0)},
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ requestGroup: "group",
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "updates disabled during window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: false,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ inWindow: true,
+ },
+ {
+ name: "missing path during window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ inWindow: true,
+ errMatch: "destination path",
+ },
+ {
+ name: "updates enabled outside of window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ requestGroup: "group",
+ },
+ {
+ name: "updates disabled outside of window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Group: "group",
+ BaseURL: "https://example.com",
+ Enabled: false,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ },
+ {
+ name: "insecure URL",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "http://example.com",
+ Enabled: true,
+ },
+ },
+ inWindow: true,
+
+ errMatch: "must use TLS",
+ },
+ {
+ name: "install error",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Enabled: true,
+ },
+ },
+ inWindow: true,
+ installErr: errors.New("install error"),
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ errMatch: "install error",
+ },
+ {
+ name: "version already installed in window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("16.3.0", 0),
+ },
+ },
+ inWindow: true,
+ },
+ {
+ name: "version already installed outside of window",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("16.3.0", 0),
+ },
+ },
+ },
+ {
+ name: "version detects as linked",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ },
+ },
+ linkedRevisions: []Revision{NewRevision("16.3.0", 0)},
+ inWindow: true,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ removedRevisions: []Revision{
+ NewRevision("unknown-version", 0),
+ },
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "backup version removed on install",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ removedRevisions: []Revision{
+ NewRevision("backup-version", 0),
+ NewRevision("unknown-version", 0),
+ },
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "backup version is linked",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+ linkedRevisions: []Revision{NewRevision("backup-version", 0)},
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ removedRevisions: []Revision{
+ NewRevision("unknown-version", 0),
+ },
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "backup version kept when no change",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("16.3.0", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+ },
+ {
+ name: "config does not exist",
+ },
+ {
+ name: "FIPS and Enterprise flags",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ Backup: toPtr(NewRevision("backup-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS)),
+ },
+ },
+ inWindow: true,
+ flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS,
+
+ installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ removedRevisions: []Revision{
+ NewRevision("backup-version", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ NewRevision("unknown-version", 0),
+ },
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "invalid metadata",
+ cfg: &UpdateConfig{},
+ errMatch: "invalid",
+ },
+ {
+ name: "setup fails",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+ setupErr: errors.New("setup error"),
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ removedRevisions: []Revision{
+ NewRevision("backup-version", 0),
+ },
+ reloadCalls: 1,
+ revertCalls: 1,
+ setupCalls: 1,
+ restarted: true,
+ errMatch: "setup error",
+ },
+ {
+ name: "agpl requires base URL",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+ agpl: true,
+
+ reloadCalls: 0,
+ revertCalls: 0,
+ setupCalls: 0,
+ errMatch: "AGPL",
+ },
+ {
+ name: "skip version",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ Skip: toPtr(NewRevision("16.3.0", 0)),
+ },
+ },
+ inWindow: true,
+ },
+ {
+ name: "pinned version",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ BaseURL: "https://example.com",
+ Enabled: true,
+ Pinned: true,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+ inWindow: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var requestedGroup string
+ server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ requestedGroup = r.URL.Query().Get("group")
+ config := webclient.PingResponse{
+ AutoUpdate: webclient.AutoUpdateSettings{
+ AgentVersion: "16.3.0",
+ AgentAutoUpdate: tt.inWindow,
+ },
+ }
+ config.Edition = "community"
+ if tt.flags&autoupdate.FlagEnterprise != 0 {
+ config.Edition = "ent"
+ }
+ if tt.agpl {
+ config.Edition = "oss"
+ }
+ config.FIPS = tt.flags&autoupdate.FlagFIPS != 0
+ err := json.NewEncoder(w).Encode(config)
+ require.NoError(t, err)
+ }))
+ t.Cleanup(server.Close)
+
+ dir := t.TempDir()
+ ns := &Namespace{
+ installDir: dir,
+ defaultPathDir: "ignored",
+ }
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ tt.cfg.Spec.Proxy = strings.TrimPrefix(server.URL, "https://")
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+
+ var (
+ installedRevision Revision
+ installedBaseURL string
+ linkedRevision Revision
+ removedRevisions []Revision
+ revertFuncCalls int
+ setupCalls int
+ revertSetupCalls int
+ reloadCalls int
+ )
+ updater.Installer = &testInstaller{
+ FuncInstall: func(_ context.Context, rev Revision, baseURL string, force bool) error {
+ for _, r := range tt.linkedRevisions {
+ if r == rev {
+ require.False(t, force)
+ }
+ }
+ installedRevision = rev
+ installedBaseURL = baseURL
+ return tt.installErr
+ },
+ FuncLink: func(_ context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) {
+ linkedRevision = rev
+ return func(_ context.Context) bool {
+ revertFuncCalls++
+ return true
+ }, nil
+ },
+ FuncList: func(_ context.Context) (revs []Revision, err error) {
+ return slices.Compact([]Revision{
+ installedRevision,
+ tt.cfg.Status.Active,
+ NewRevision("unknown-version", 0),
+ }), nil
+ },
+ FuncRemove: func(_ context.Context, rev Revision) error {
+ removedRevisions = append(removedRevisions, rev)
+ return nil
+ },
+ FuncIsLinked: func(ctx context.Context, rev Revision, path string) (bool, error) {
+ for _, r := range tt.linkedRevisions {
+ if r == rev {
+ return true, nil
+ }
+ }
+ return false, nil
+ },
+ }
+ updater.Process = &testProcess{
+ FuncReload: func(_ context.Context) error {
+ reloadCalls++
+ return tt.reloadErr
+ },
+ FuncIsPresent: func(ctx context.Context) (bool, error) {
+ return true, nil
+ },
+ FuncIsEnabled: func(ctx context.Context) (bool, error) {
+ return !tt.notEnabled, nil
+ },
+ FuncIsActive: func(ctx context.Context) (bool, error) {
+ return !tt.notActive, nil
+ },
+ }
+ var restarted bool
+ updater.ReexecSetup = func(_ context.Context, path string, reload bool) error {
+ restarted = reload
+ setupCalls++
+ return tt.setupErr
+ }
+ updater.SetupNamespace = func(_ context.Context, path string) error {
+ revertSetupCalls++
+ return nil
+ }
+
+ ctx := context.Background()
+ err = updater.Update(ctx, tt.now)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, tt.installedRevision, installedRevision)
+ require.Equal(t, tt.installedBaseURL, installedBaseURL)
+ require.Equal(t, tt.linkedRevision, linkedRevision)
+ require.Equal(t, tt.removedRevisions, removedRevisions)
+ require.Equal(t, tt.flags, installedRevision.Flags)
+ require.Equal(t, tt.requestGroup, requestedGroup)
+ require.Equal(t, tt.reloadCalls, reloadCalls)
+ require.Equal(t, tt.revertCalls, revertSetupCalls)
+ require.Equal(t, tt.revertCalls, revertFuncCalls)
+ require.Equal(t, tt.setupCalls, setupCalls)
+ require.Equal(t, tt.restarted, restarted)
+
+ if tt.cfg == nil {
+ _, err := os.Stat(cfgPath)
+ require.Error(t, err)
+ return
+ }
+
+ data, err := os.ReadFile(cfgPath)
+ require.NoError(t, err)
+ data = blankTestAddr(data)
+
+ if golden.ShouldSet() {
+ golden.Set(t, data)
+ }
+ require.Equal(t, string(golden.Get(t)), string(data))
+ })
+ }
+}
+
+func TestUpdater_LinkPackage(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ tryLinkSystemErr error
+ syncErr error
+ notPresent bool
+
+ syncCalls int
+ tryLinkSystemCalls int
+ errMatch string
+ }{
+ {
+ name: "updates enabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: true,
+ },
+ },
+
+ tryLinkSystemCalls: 0,
+ syncCalls: 0,
+ },
+ {
+ name: "pinned",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Pinned: true,
+ },
+ },
+
+ tryLinkSystemCalls: 0,
+ syncCalls: 0,
+ },
+ {
+ name: "updates disabled",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+
+ tryLinkSystemCalls: 1,
+ syncCalls: 1,
+ },
+ {
+ name: "already linked",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+ tryLinkSystemErr: ErrLinked,
+
+ tryLinkSystemCalls: 1,
+ syncCalls: 0,
+ },
+ {
+ name: "link error",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+ tryLinkSystemErr: errors.New("bad"),
+
+ tryLinkSystemCalls: 1,
+ syncCalls: 0,
+ errMatch: "bad",
+ },
+ {
+ name: "no config",
+ tryLinkSystemCalls: 1,
+ syncCalls: 1,
+ },
+ {
+ name: "systemd is not installed",
+ tryLinkSystemCalls: 1,
+ syncCalls: 1,
+ syncErr: ErrNotSupported,
+ },
+ {
+ name: "systemd is not installed, already linked",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+ tryLinkSystemCalls: 1,
+ syncCalls: 1,
+ syncErr: ErrNotSupported,
+ },
+ {
+ name: "SELinux blocks service from being read",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Enabled: false,
+ },
+ },
+ tryLinkSystemCalls: 1,
+ syncCalls: 1,
+ notPresent: true,
+ errMatch: "cannot find systemd service",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+ ns := &Namespace{installDir: dir}
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
+
+ // Create config file only if provided in test case
+ if tt.cfg != nil {
+ b, err := yaml.Marshal(tt.cfg)
+ require.NoError(t, err)
+ err = os.WriteFile(cfgPath, b, 0600)
+ require.NoError(t, err)
+ }
+
+ var tryLinkSystemCalls int
+ updater.Installer = &testInstaller{
+ FuncTryLinkSystem: func(_ context.Context) error {
+ tryLinkSystemCalls++
+ return tt.tryLinkSystemErr
+ },
+ }
+ var syncCalls int
+ updater.Process = &testProcess{
+ FuncSync: func(_ context.Context) error {
+ syncCalls++
+ return tt.syncErr
+ },
+ FuncIsPresent: func(ctx context.Context) (bool, error) {
+ return !tt.notPresent, nil
+ },
+ }
+
+ ctx := context.Background()
+ err = updater.LinkPackage(ctx)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, tt.tryLinkSystemCalls, tryLinkSystemCalls)
+ require.Equal(t, tt.syncCalls, syncCalls)
+ })
+ }
+}
+
+func TestUpdater_Remove(t *testing.T) {
+ t.Parallel()
+
+ const version = "active-version"
+
+ tests := []struct {
+ name string
+ cfg *UpdateConfig // nil -> file not present
+ linkSystemErr error
+ isEnabledErr error
+ syncErr error
+ reloadErr error
+ processEnabled bool
+ force bool
+
+ unlinkedVersion string
+ teardownCalls int
+ syncCalls int
+ revertFuncCalls int
+ linkSystemCalls int
+ reloadCalls int
+ errMatch string
+ }{
+ {
+ name: "no config",
+ teardownCalls: 1,
+ },
+ {
+ name: "no active version",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ },
+ teardownCalls: 1,
+ },
+ {
+ name: "no conflicting system links, process disabled, force",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ unlinkedVersion: version,
+ teardownCalls: 1,
+ force: true,
+ },
+ {
+ name: "no system links, process enabled, force",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemErr: ErrNoBinaries,
+ linkSystemCalls: 1,
+ processEnabled: true,
+ force: true,
+ errMatch: "refusing to remove",
+ },
+ {
+ name: "no system links, process disabled, force",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemErr: ErrNoBinaries,
+ linkSystemCalls: 1,
+ unlinkedVersion: version,
+ teardownCalls: 1,
+ force: true,
+ },
+ {
+ name: "no system links, process disabled, no force",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemErr: ErrNoBinaries,
+ linkSystemCalls: 1,
+ errMatch: "unable to remove",
+ },
+ {
+ name: "no system links, process disabled, no systemd, force",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemErr: ErrNoBinaries,
+ linkSystemCalls: 1,
+ isEnabledErr: ErrNotSupported,
+ unlinkedVersion: version,
+ teardownCalls: 1,
+ force: true,
+ },
+ {
+ name: "active version",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemCalls: 1,
+ syncCalls: 1,
+ reloadCalls: 1,
+ teardownCalls: 1,
+ },
+ {
+ name: "active version, no systemd",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemCalls: 1,
+ syncCalls: 1,
+ reloadCalls: 1,
+ teardownCalls: 1,
+ syncErr: ErrNotSupported,
+ reloadErr: ErrNotSupported,
+ },
+ {
+ name: "active version, no reload",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemCalls: 1,
+ syncCalls: 1,
+ reloadCalls: 1,
+ teardownCalls: 1,
+ reloadErr: ErrNotNeeded,
+ },
+ {
+ name: "active version, sync error",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemCalls: 1,
+ syncCalls: 2,
+ revertFuncCalls: 1,
+ syncErr: errors.New("sync error"),
+ errMatch: "configuration",
+ },
+ {
+ name: "active version, reload error",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Spec: UpdateSpec{
+ Path: defaultPathDir,
+ },
+ Status: UpdateStatus{
+ Active: NewRevision(version, 0),
+ },
+ },
+ linkSystemCalls: 1,
+ syncCalls: 2,
+ reloadCalls: 2,
+ revertFuncCalls: 1,
+ reloadErr: errors.New("reload error"),
+ errMatch: "start",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+ ns := &Namespace{installDir: dir}
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -91,47 +1239,92 @@ func TestUpdater_Disable(t *testing.T) {
err = os.WriteFile(cfgPath, b, 0600)
require.NoError(t, err)
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
- err = updater.Disable(context.Background())
- if tt.errMatch != "" {
- require.Error(t, err)
- assert.Contains(t, err.Error(), tt.errMatch)
- return
- }
- require.NoError(t, err)
-
- data, err := os.ReadFile(cfgPath)
- // If no config is present, disable should not create it
- if tt.cfg == nil {
- require.ErrorIs(t, err, os.ErrNotExist)
- return
+ var (
+ linkSystemCalls int
+ revertFuncCalls int
+ syncCalls int
+ reloadCalls int
+ teardownCalls int
+ unlinkedVersion string
+ )
+ updater.Installer = &testInstaller{
+ FuncLinkSystem: func(_ context.Context) (revert func(context.Context) bool, err error) {
+ linkSystemCalls++
+ return func(_ context.Context) bool {
+ revertFuncCalls++
+ return true
+ }, tt.linkSystemErr
+ },
+ FuncUnlink: func(_ context.Context, rev Revision, path string) error {
+ unlinkedVersion = rev.Version
+ return nil
+ },
+ }
+ updater.Process = &testProcess{
+ FuncSync: func(_ context.Context) error {
+ syncCalls++
+ return tt.syncErr
+ },
+ FuncReload: func(_ context.Context) error {
+ reloadCalls++
+ return tt.reloadErr
+ },
+ FuncIsEnabled: func(_ context.Context) (bool, error) {
+ return tt.processEnabled, tt.isEnabledErr
+ },
+ FuncIsActive: func(_ context.Context) (bool, error) {
+ return false, nil
+ },
+ }
+ updater.TeardownNamespace = func(_ context.Context) error {
+ teardownCalls++
+ return nil
}
- require.NoError(t, err)
- if golden.ShouldSet() {
- golden.Set(t, data)
+ ctx := context.Background()
+ err = updater.Remove(ctx, tt.force)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
}
- require.Equal(t, string(golden.Get(t)), string(data))
+ require.Equal(t, tt.syncCalls, syncCalls)
+ require.Equal(t, tt.reloadCalls, reloadCalls)
+ require.Equal(t, tt.linkSystemCalls, linkSystemCalls)
+ require.Equal(t, tt.revertFuncCalls, revertFuncCalls)
+ require.Equal(t, tt.unlinkedVersion, unlinkedVersion)
+ require.Equal(t, tt.teardownCalls, teardownCalls)
})
}
}
-func TestUpdater_Enable(t *testing.T) {
+func TestUpdater_Install(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cfg *UpdateConfig // nil -> file not present
userCfg OverrideConfig
+ flags autoupdate.InstallFlags
+ agpl bool
installErr error
+ setupErr error
+ reloadErr error
+ notPresent bool
+ notEnabled bool
+ notActive bool
- installedVersion string
- installedTemplate string
+ removedRevision Revision
+ installedRevision Revision
+ installedBaseURL string
+ linkedRevision Revision
+ requestGroup string
+ reloadCalls int
+ revertCalls int
+ setupCalls int
+ restarted bool
errMatch string
}{
{
@@ -140,15 +1333,22 @@ func TestUpdater_Enable(t *testing.T) {
Version: updateConfigVersion,
Kind: updateConfigKind,
Spec: UpdateSpec{
- Group: "group",
- URLTemplate: "https://example.com",
+ Enabled: true,
+ Group: "group",
+ Path: "/path",
+ BaseURL: "https://example.com",
},
Status: UpdateStatus{
- ActiveVersion: "old-version",
+ Active: NewRevision("old-version", 0),
},
},
- installedVersion: "16.3.0",
- installedTemplate: "https://example.com",
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: "https://example.com",
+ linkedRevision: NewRevision("16.3.0", 0),
+ requestGroup: "group",
+ setupCalls: 1,
+ restarted: true,
},
{
name: "config from user",
@@ -156,35 +1356,62 @@ func TestUpdater_Enable(t *testing.T) {
Version: updateConfigVersion,
Kind: updateConfigKind,
Spec: UpdateSpec{
- Group: "old-group",
- URLTemplate: "https://example.com/old",
+ Group: "old-group",
+ BaseURL: "https://example.com/old",
},
Status: UpdateStatus{
- ActiveVersion: "old-version",
+ Active: NewRevision("old-version", 0),
},
},
userCfg: OverrideConfig{
- Group: "new-group",
- URLTemplate: "https://example.com/new",
+ UpdateSpec: UpdateSpec{
+ Enabled: true,
+ Path: "/path",
+ Group: "new-group",
+ BaseURL: "https://example.com/new",
+ },
ForceVersion: "new-version",
},
- installedVersion: "new-version",
- installedTemplate: "https://example.com/new",
+
+ installedRevision: NewRevision("new-version", 0),
+ installedBaseURL: "https://example.com/new",
+ linkedRevision: NewRevision("new-version", 0),
+ requestGroup: "new-group",
+ setupCalls: 1,
+ restarted: true,
},
{
- name: "already enabled",
+ name: "defaults",
cfg: &UpdateConfig{
Version: updateConfigVersion,
Kind: updateConfigKind,
- Spec: UpdateSpec{
- Enabled: true,
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
},
+ },
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "override skip",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
Status: UpdateStatus{
- ActiveVersion: "old-version",
+ Active: NewRevision("old-version", 0),
+ Skip: toPtr(NewRevision("16.3.0", 0)),
},
},
- installedVersion: "16.3.0",
- installedTemplate: cdnURITemplate,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: true,
},
{
name: "insecure URL",
@@ -192,22 +1419,32 @@ func TestUpdater_Enable(t *testing.T) {
Version: updateConfigVersion,
Kind: updateConfigKind,
Spec: UpdateSpec{
- URLTemplate: "http://example.com",
+ BaseURL: "http://example.com",
},
},
- errMatch: "URL must use TLS",
+
+ errMatch: "must use TLS",
},
{
name: "install error",
cfg: &UpdateConfig{
Version: updateConfigVersion,
Kind: updateConfigKind,
- Spec: UpdateSpec{
- URLTemplate: "https://example.com",
- },
},
installErr: errors.New("install error"),
- errMatch: "install error",
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ errMatch: "install error",
+ },
+ {
+ name: "agpl requires base URL",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ },
+ agpl: true,
+ errMatch: "AGPL",
},
{
name: "version already installed",
@@ -215,29 +1452,143 @@ func TestUpdater_Enable(t *testing.T) {
Version: updateConfigVersion,
Kind: updateConfigKind,
Status: UpdateStatus{
- ActiveVersion: "16.3.0",
+ Active: NewRevision("16.3.0", 0),
+ },
+ },
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: false,
+ },
+ {
+ name: "backup version removed on install",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Status: UpdateStatus{
+ Active: NewRevision("old-version", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
},
},
- installedVersion: "16.3.0",
- installedTemplate: cdnURITemplate,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ removedRevision: NewRevision("backup-version", 0),
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "backup version kept for validation",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Status: UpdateStatus{
+ Active: NewRevision("16.3.0", 0),
+ Backup: toPtr(NewRevision("backup-version", 0)),
+ },
+ },
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ },
+ {
+ name: "config does not exist",
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: true,
},
{
- name: "config does not exist",
- installedVersion: "16.3.0",
- installedTemplate: cdnURITemplate,
+ name: "FIPS and Enterprise flags",
+ flags: autoupdate.FlagEnterprise | autoupdate.FlagFIPS,
+ installedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", autoupdate.FlagEnterprise|autoupdate.FlagFIPS),
+ setupCalls: 1,
+ restarted: true,
},
{
name: "invalid metadata",
cfg: &UpdateConfig{},
errMatch: "invalid",
},
+ {
+ name: "setup fails",
+ setupErr: errors.New("setup error"),
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ revertCalls: 1,
+ setupCalls: 1,
+ reloadCalls: 1,
+ restarted: true,
+ errMatch: "setup error",
+ },
+ {
+ name: "setup fails already installed",
+ cfg: &UpdateConfig{
+ Version: updateConfigVersion,
+ Kind: updateConfigKind,
+ Status: UpdateStatus{
+ Active: NewRevision("16.3.0", 0),
+ },
+ },
+ setupErr: errors.New("setup error"),
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ revertCalls: 1,
+ setupCalls: 1,
+ errMatch: "setup error",
+ },
+ {
+ name: "no need to reload",
+ reloadErr: ErrNotNeeded,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: true,
+ },
+ {
+ name: "not started or enabled",
+ notEnabled: true,
+ notActive: true,
+
+ installedRevision: NewRevision("16.3.0", 0),
+ installedBaseURL: autoupdate.DefaultBaseURL,
+ linkedRevision: NewRevision("16.3.0", 0),
+ setupCalls: 1,
+ restarted: true,
+ },
}
for _, tt := range tests {
- tt := tt
t.Run(tt.name, func(t *testing.T) {
dir := t.TempDir()
- cfgPath := filepath.Join(dir, "update.yaml")
+ ns := &Namespace{
+ installDir: dir,
+ defaultPathDir: defaultPathDir,
+ defaultProxyAddr: "default-proxy",
+ }
+ _, err := ns.Init()
+ require.NoError(t, err)
+ cfgPath := filepath.Join(ns.Dir(), updateConfigName)
+
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{
+ InsecureSkipVerify: true,
+ }, ns)
+ require.NoError(t, err)
// Create config file only if provided in test case
if tt.cfg != nil {
@@ -247,9 +1598,24 @@ func TestUpdater_Enable(t *testing.T) {
require.NoError(t, err)
}
+ var requestedGroup string
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // TODO(sclevine): add web API test including group verification
- w.Write([]byte(`{}`))
+ requestedGroup = r.URL.Query().Get("group")
+ config := webclient.PingResponse{
+ AutoUpdate: webclient.AutoUpdateSettings{
+ AgentVersion: "16.3.0",
+ },
+ }
+ config.Edition = "community"
+ if tt.flags&autoupdate.FlagEnterprise != 0 {
+ config.Edition = "ent"
+ }
+ if tt.agpl {
+ config.Edition = "oss"
+ }
+ config.FIPS = tt.flags&autoupdate.FlagFIPS != 0
+ err := json.NewEncoder(w).Encode(config)
+ require.NoError(t, err)
}))
t.Cleanup(server.Close)
@@ -257,31 +1623,91 @@ func TestUpdater_Enable(t *testing.T) {
tt.userCfg.Proxy = strings.TrimPrefix(server.URL, "https://")
}
- updater, err := NewLocalUpdater(LocalUpdaterConfig{
- InsecureSkipVerify: true,
- VersionsDir: dir,
- })
- require.NoError(t, err)
-
- var installedVersion, installedTemplate string
+ var (
+ installedRevision Revision
+ installedBaseURL string
+ linkedRevision Revision
+ removedRevision Revision
+ revertFuncCalls int
+ reloadCalls int
+ setupCalls int
+ revertSetupCalls int
+ )
updater.Installer = &testInstaller{
- FuncInstall: func(_ context.Context, version, template string, _ InstallFlags) error {
- installedVersion = version
- installedTemplate = template
+ FuncInstall: func(_ context.Context, rev Revision, baseURL string, force bool) error {
+ installedRevision = rev
+ installedBaseURL = baseURL
return tt.installErr
},
+ FuncLink: func(_ context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) {
+ linkedRevision = rev
+ return func(_ context.Context) bool {
+ revertFuncCalls++
+ return true
+ }, nil
+ },
+ FuncList: func(_ context.Context) (revs []Revision, err error) {
+ return []Revision{}, nil
+ },
+ FuncRemove: func(_ context.Context, rev Revision) error {
+ removedRevision = rev
+ return nil
+ },
+ FuncIsLinked: func(ctx context.Context, rev Revision, path string) (bool, error) {
+ return false, nil
+ },
+ }
+ updater.Process = &testProcess{
+ FuncReload: func(_ context.Context) error {
+ reloadCalls++
+ return tt.reloadErr
+ },
+ FuncIsPresent: func(ctx context.Context) (bool, error) {
+ return !tt.notPresent, nil
+ },
+ FuncIsEnabled: func(ctx context.Context) (bool, error) {
+ return !tt.notEnabled, nil
+ },
+ FuncIsActive: func(ctx context.Context) (bool, error) {
+ return !tt.notActive, nil
+ },
+ }
+ var restarted bool
+ updater.ReexecSetup = func(_ context.Context, path string, reload bool) error {
+ setupCalls++
+ restarted = reload
+ return tt.setupErr
+ }
+ updater.SetupNamespace = func(_ context.Context, path string) error {
+ revertSetupCalls++
+ return nil
}
ctx := context.Background()
- err = updater.Enable(ctx, tt.userCfg)
+ err = updater.Install(ctx, tt.userCfg)
if tt.errMatch != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, tt.installedRevision, installedRevision)
+ require.Equal(t, tt.installedBaseURL, installedBaseURL)
+ require.Equal(t, tt.linkedRevision, linkedRevision)
+ require.Equal(t, tt.removedRevision, removedRevision)
+ require.Equal(t, tt.flags, installedRevision.Flags)
+ require.Equal(t, tt.requestGroup, requestedGroup)
+ require.Equal(t, tt.reloadCalls, reloadCalls)
+ require.Equal(t, tt.revertCalls, revertSetupCalls)
+ require.Equal(t, tt.revertCalls, revertFuncCalls)
+ require.Equal(t, tt.setupCalls, setupCalls)
+ require.Equal(t, tt.restarted, restarted)
+
+ if tt.cfg == nil && err != nil {
+ _, err := os.Stat(cfgPath)
+ require.Error(t, err)
return
}
- require.NoError(t, err)
- require.Equal(t, tt.installedVersion, installedVersion)
- require.Equal(t, tt.installedTemplate, installedTemplate)
data, err := os.ReadFile(cfgPath)
require.NoError(t, err)
@@ -295,6 +1721,170 @@ func TestUpdater_Enable(t *testing.T) {
}
}
+func TestSameProxies(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ a, b string
+ match bool
+ }{
+ {
+ name: "protocol missing with port",
+ a: "https://example.com:8080",
+ b: "example.com:8080",
+ match: true,
+ },
+ {
+ name: "protocol missing without port",
+ a: "https://example.com",
+ b: "example.com",
+ match: true,
+ },
+ {
+ name: "same with port",
+ a: "example.com:443",
+ b: "example.com:443",
+ match: true,
+ },
+ {
+ name: "does not set default teleport port",
+ a: "example.com",
+ b: "example.com:3080",
+ match: false,
+ },
+ {
+ name: "does set default standard port",
+ a: "example.com",
+ b: "example.com:443",
+ match: true,
+ },
+ {
+ name: "other formats if equal",
+ a: "@123",
+ b: "@123",
+ match: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s := sameProxies(tt.a, tt.b)
+ require.Equal(t, tt.match, s)
+ })
+ }
+}
+
+func TestUpdater_Setup(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ restart bool
+ present bool
+ setupErr error
+ presentErr error
+ reloadErr error
+
+ errMatch string
+ }{
+ {
+ name: "no restart",
+ restart: false,
+ present: true,
+ },
+ {
+ name: "restart",
+ restart: true,
+ present: true,
+ },
+ {
+ name: "reload not needed",
+ restart: true,
+ present: true,
+ reloadErr: ErrNotNeeded,
+ },
+ {
+ name: "not present",
+ restart: true,
+ present: false,
+ errMatch: "cannot find systemd",
+ },
+ {
+ name: "setup error",
+ restart: false,
+ setupErr: errors.New("some error"),
+ errMatch: "some error",
+ },
+ {
+ name: "setup error canceled",
+ restart: false,
+ setupErr: context.Canceled,
+ errMatch: "canceled",
+ },
+ {
+ name: "present error",
+ restart: false,
+ presentErr: errors.New("some error"),
+ errMatch: "some error",
+ },
+ {
+ name: "present error canceled",
+ restart: false,
+ presentErr: context.Canceled,
+ errMatch: "canceled",
+ },
+ {
+ name: "preset error not supported",
+ restart: false,
+ presentErr: ErrNotSupported,
+ },
+ {
+ name: "reload error canceled",
+ restart: true,
+ present: true,
+ reloadErr: context.Canceled,
+ errMatch: "canceled",
+ },
+ {
+ name: "reload error",
+ restart: true,
+ present: true,
+ reloadErr: errors.New("some error"),
+ errMatch: "some error",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ns := &Namespace{}
+ updater, err := NewLocalUpdater(LocalUpdaterConfig{}, ns)
+ require.NoError(t, err)
+
+ updater.Process = &testProcess{
+ FuncReload: func(_ context.Context) error {
+ return tt.reloadErr
+ },
+ FuncIsPresent: func(_ context.Context) (bool, error) {
+ return tt.present, tt.presentErr
+ },
+ }
+ updater.SetupNamespace = func(_ context.Context, path string) error {
+ require.Equal(t, "test", path)
+ return tt.setupErr
+ }
+
+ ctx := context.Background()
+ err = updater.Setup(ctx, "test", tt.restart)
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.errMatch)
+ } else {
+ require.NoError(t, err)
+ }
+ })
+ }
+}
+
var serverRegexp = regexp.MustCompile("127.0.0.1:[0-9]+")
func blankTestAddr(s []byte) []byte {
@@ -302,14 +1892,82 @@ func blankTestAddr(s []byte) []byte {
}
type testInstaller struct {
- FuncInstall func(ctx context.Context, version, template string, flags InstallFlags) error
- FuncRemove func(ctx context.Context, version string) error
+ FuncInstall func(ctx context.Context, rev Revision, baseURL string, force bool) error
+ FuncRemove func(ctx context.Context, rev Revision) error
+ FuncLink func(ctx context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error)
+ FuncLinkSystem func(ctx context.Context) (revert func(context.Context) bool, err error)
+ FuncTryLink func(ctx context.Context, rev Revision, path string) error
+ FuncTryLinkSystem func(ctx context.Context) error
+ FuncUnlink func(ctx context.Context, rev Revision, path string) error
+ FuncUnlinkSystem func(ctx context.Context) error
+ FuncList func(ctx context.Context) (revs []Revision, err error)
+ FuncIsLinked func(ctx context.Context, rev Revision, path string) (bool, error)
+}
+
+func (ti *testInstaller) Install(ctx context.Context, rev Revision, baseURL string, force bool) error {
+ return ti.FuncInstall(ctx, rev, baseURL, force)
+}
+
+func (ti *testInstaller) Remove(ctx context.Context, rev Revision) error {
+ return ti.FuncRemove(ctx, rev)
+}
+
+func (ti *testInstaller) Link(ctx context.Context, rev Revision, path string, force bool) (revert func(context.Context) bool, err error) {
+ return ti.FuncLink(ctx, rev, path, force)
+}
+
+func (ti *testInstaller) LinkSystem(ctx context.Context) (revert func(context.Context) bool, err error) {
+ return ti.FuncLinkSystem(ctx)
+}
+
+func (ti *testInstaller) TryLink(ctx context.Context, rev Revision, path string) error {
+ return ti.FuncTryLink(ctx, rev, path)
+}
+
+func (ti *testInstaller) TryLinkSystem(ctx context.Context) error {
+ return ti.FuncTryLinkSystem(ctx)
+}
+
+func (ti *testInstaller) Unlink(ctx context.Context, rev Revision, path string) error {
+ return ti.FuncUnlink(ctx, rev, path)
+}
+
+func (ti *testInstaller) UnlinkSystem(ctx context.Context) error {
+ return ti.FuncUnlinkSystem(ctx)
+}
+
+func (ti *testInstaller) List(ctx context.Context) (revs []Revision, err error) {
+ return ti.FuncList(ctx)
+}
+
+func (ti *testInstaller) IsLinked(ctx context.Context, rev Revision, path string) (bool, error) {
+ return ti.FuncIsLinked(ctx, rev, path)
+}
+
+type testProcess struct {
+ FuncReload func(ctx context.Context) error
+ FuncSync func(ctx context.Context) error
+ FuncIsEnabled func(ctx context.Context) (bool, error)
+ FuncIsActive func(ctx context.Context) (bool, error)
+ FuncIsPresent func(ctx context.Context) (bool, error)
+}
+
+func (tp *testProcess) Reload(ctx context.Context) error {
+ return tp.FuncReload(ctx)
+}
+
+func (tp *testProcess) Sync(ctx context.Context) error {
+ return tp.FuncSync(ctx)
+}
+
+func (tp *testProcess) IsEnabled(ctx context.Context) (bool, error) {
+ return tp.FuncIsEnabled(ctx)
}
-func (ti *testInstaller) Install(ctx context.Context, version, template string, flags InstallFlags) error {
- return ti.FuncInstall(ctx, version, template, flags)
+func (tp *testProcess) IsActive(ctx context.Context) (bool, error) {
+ return tp.FuncIsActive(ctx)
}
-func (ti *testInstaller) Remove(ctx context.Context, version string) error {
- return ti.FuncRemove(ctx, version)
+func (tp *testProcess) IsPresent(ctx context.Context) (bool, error) {
+ return tp.FuncIsPresent(ctx)
}
diff --git a/lib/autoupdate/agent/validate.go b/lib/autoupdate/agent/validate.go
new file mode 100644
index 0000000000000..8ce1732d3d5d1
--- /dev/null
+++ b/lib/autoupdate/agent/validate.go
@@ -0,0 +1,130 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "io"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "unicode"
+
+ "github.com/gravitational/trace"
+)
+
+const (
+ // fileHeaderSniffBytes is the max size to read to determine a file's MIME type
+ fileHeaderSniffBytes = 512 // MIME standard size
+ // execModeMask is the minimum required set of bits to consider a file executable.
+ execModeMask = 0111
+)
+
+// Validator validates filesystem paths.
+type Validator struct {
+ Log *slog.Logger
+}
+
+// IsBinary returns true for working binaries that are executable by all users.
+// If the file is irregular, non-executable, or a shell script, IsBinary returns false and logs a warning.
+// IsBinary errors if lstat fails, a regular file is unreadable, or an executable file does not execute.
+func (v *Validator) IsBinary(ctx context.Context, path string) (bool, error) {
+ // The behavior of this method is intended to protect against executable files
+ // being adding to the Teleport tgz that should not be made available on PATH,
+ // and additionally, should not cause installation to fail.
+ // While known copies of these files (e.g., "install") are excluded during extraction,
+ // it is safer to assume others could be present in past or future tgzs.
+
+ if exec, err := v.IsExecutable(ctx, path); err != nil || !exec {
+ return exec, trace.Wrap(err)
+ }
+ name := filepath.Base(path)
+ d, err := readFileLimit(path, fileHeaderSniffBytes)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ // Refuse to test or link shell scripts
+ if isTextScript(d) {
+ v.Log.WarnContext(ctx, "Found unexpected shell script", "name", name)
+ return false, nil
+ }
+ v.Log.InfoContext(ctx, "Validating binary", "name", name)
+ r := localExec{
+ Log: v.Log,
+ ErrLevel: slog.LevelDebug,
+ OutLevel: slog.LevelInfo, // always show version
+ }
+ code, err := r.Run(ctx, path, "version")
+ if code < 0 {
+ return false, trace.Wrap(err, "error validating binary %s", name)
+ }
+ if code > 0 {
+ v.Log.InfoContext(ctx, "Binary does not support version command", "name", name)
+ }
+ return true, nil
+}
+
+// IsExecutable returns true for regular, executable files.
+func (v *Validator) IsExecutable(ctx context.Context, path string) (bool, error) {
+ name := filepath.Base(path)
+ fi, err := os.Lstat(path)
+ if err != nil {
+ return false, trace.Wrap(err)
+ }
+ if !fi.Mode().IsRegular() {
+ v.Log.WarnContext(ctx, "Found unexpected irregular file", "name", name)
+ return false, nil
+ }
+ if fi.Mode()&execModeMask != execModeMask {
+ v.Log.WarnContext(ctx, "Found unexpected non-executable file", "name", name)
+ return false, nil
+ }
+ return true, nil
+}
+
+func isTextScript(data []byte) bool {
+ data = bytes.TrimLeftFunc(data, unicode.IsSpace)
+ if !bytes.HasPrefix(data, []byte("#!")) {
+ return false
+ }
+ // Assume presence of MIME binary data bytes means binary:
+ // https://mimesniff.spec.whatwg.org/#terminology
+ for _, b := range data {
+ switch {
+ case b <= 0x08, b == 0x0B,
+ 0x0E <= b && b <= 0x1A,
+ 0x1C <= b && b <= 0x1F:
+ return false
+ }
+ }
+ return true
+}
+
+// readFileLimit the first n bytes of a file.
+func readFileLimit(name string, n int64) ([]byte, error) {
+ f, err := os.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ var buf bytes.Buffer
+ _, err = io.Copy(&buf, io.LimitReader(f, n))
+ return buf.Bytes(), trace.Wrap(err)
+}
diff --git a/lib/autoupdate/agent/validate_test.go b/lib/autoupdate/agent/validate_test.go
new file mode 100644
index 0000000000000..71b979016f8fc
--- /dev/null
+++ b/lib/autoupdate/agent/validate_test.go
@@ -0,0 +1,102 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package agent
+
+import (
+ "bytes"
+ "context"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidator_IsBinary(t *testing.T) {
+ for _, tt := range []struct {
+ name string
+ mode os.FileMode
+ contents string
+
+ valid bool
+ errMatch string
+ logMatch string
+ }{
+ {
+ name: "missing",
+ errMatch: "no such",
+ },
+ {
+ name: "non-executable",
+ contents: "test",
+ mode: 0666,
+ logMatch: "non-executable",
+ },
+ {
+ name: "shell script",
+ contents: " #!bash ",
+ mode: 0777,
+ logMatch: "unexpected shell",
+ },
+ {
+ name: "unqualified shell script",
+ contents: " #!bash" + string([]byte{0x0B}),
+ mode: 0777,
+ errMatch: "validating binary",
+ },
+ {
+ name: "exit 0",
+ contents: "#!/bin/sh\nexit 0\n" + string([]byte{0x0B}),
+ mode: 0777,
+ valid: true,
+ },
+ {
+ name: "exit 1",
+ contents: "#!/bin/sh\nexit 1\n" + string([]byte{0x0B}),
+ mode: 0777,
+ valid: true,
+ logMatch: "version command",
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ var buf bytes.Buffer
+ opts := &slog.HandlerOptions{AddSource: true}
+ log := slog.New(slog.NewTextHandler(&buf, opts))
+ v := Validator{Log: log}
+ ctx := context.Background()
+ path := filepath.Join(t.TempDir(), "file")
+ if tt.contents != "" {
+ os.WriteFile(path, []byte(tt.contents), tt.mode)
+ }
+ val, err := v.IsBinary(ctx, path)
+ if tt.logMatch != "" {
+ require.Contains(t, buf.String(), tt.logMatch)
+ }
+ if tt.errMatch != "" {
+ require.Error(t, err)
+ require.False(t, val)
+ require.Contains(t, err.Error(), tt.errMatch)
+ return
+ }
+ require.Equal(t, tt.valid, val)
+ require.NoError(t, err)
+ })
+ }
+}
diff --git a/lib/autoupdate/package_url.go b/lib/autoupdate/package_url.go
index 9b283c3da59c2..b00eb59fea5c9 100644
--- a/lib/autoupdate/package_url.go
+++ b/lib/autoupdate/package_url.go
@@ -20,10 +20,12 @@ package autoupdate
import (
"bytes"
+ "encoding/json"
"runtime"
"text/template"
"github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
)
// InstallFlags sets flags for the Teleport installation.
@@ -54,6 +56,82 @@ const (
BaseURLEnvVar = "TELEPORT_CDN_BASE_URL"
)
+// NewInstallFlagsFromStrings returns InstallFlags given a slice of human-readable strings.
+func NewInstallFlagsFromStrings(s []string) InstallFlags {
+ var out InstallFlags
+ for _, f := range s {
+ for _, flag := range []InstallFlags{
+ FlagEnterprise,
+ FlagFIPS,
+ } {
+ if f == flag.String() {
+ out |= flag
+ }
+ }
+ }
+ return out
+}
+
+// Strings converts InstallFlags to a slice of human-readable strings.
+func (i InstallFlags) Strings() []string {
+ var out []string
+ for _, flag := range []InstallFlags{
+ FlagEnterprise,
+ FlagFIPS,
+ } {
+ if i&flag != 0 {
+ out = append(out, flag.String())
+ }
+ }
+ return out
+}
+
+// String returns the string representation of a single InstallFlag flag, or "Unknown".
+func (i InstallFlags) String() string {
+ switch i {
+ case 0:
+ return ""
+ case FlagEnterprise:
+ return "Enterprise"
+ case FlagFIPS:
+ return "FIPS"
+ }
+ return "Unknown"
+}
+
+// DirFlag returns the directory path representation of a single InstallFlag flag, or "unknown".
+func (i InstallFlags) DirFlag() string {
+ switch i {
+ case 0:
+ return ""
+ case FlagEnterprise:
+ return "ent"
+ case FlagFIPS:
+ return "fips"
+ }
+ return "unknown"
+}
+
+func (i InstallFlags) MarshalYAML() (any, error) {
+ return i.Strings(), nil
+}
+
+func (i InstallFlags) MarshalJSON() ([]byte, error) {
+ return json.Marshal(i.Strings())
+}
+
+func (i *InstallFlags) UnmarshalYAML(n *yaml.Node) error {
+ var s []string
+ if err := n.Decode(&s); err != nil {
+ return trace.Wrap(err)
+ }
+ if i == nil {
+ return trace.BadParameter("nil install flags while parsing YAML")
+ }
+ *i = NewInstallFlagsFromStrings(s)
+ return nil
+}
+
// MakeURL constructs the package download URL from template, base URL and revision.
func MakeURL(uriTmpl string, baseURL string, pkg string, version string, flags InstallFlags) (string, error) {
tmpl, err := template.New("uri").Parse(uriTmpl)
diff --git a/lib/autoupdate/package_url_test.go b/lib/autoupdate/package_url_test.go
new file mode 100644
index 0000000000000..b3eca4be38d7e
--- /dev/null
+++ b/lib/autoupdate/package_url_test.go
@@ -0,0 +1,95 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package autoupdate
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "gopkg.in/yaml.v3"
+)
+
+func TestInstallFlagsYAML(t *testing.T) {
+ t.Parallel()
+
+ for _, tt := range []struct {
+ name string
+ yaml string
+ flags InstallFlags
+ skipYAML bool
+ }{
+ {
+ name: "both",
+ yaml: `["Enterprise", "FIPS"]`,
+ flags: FlagEnterprise | FlagFIPS,
+ },
+ {
+ name: "order",
+ yaml: `["FIPS", "Enterprise"]`,
+ flags: FlagEnterprise | FlagFIPS,
+ skipYAML: true,
+ },
+ {
+ name: "extra",
+ yaml: `["FIPS", "Enterprise", "bad"]`,
+ flags: FlagEnterprise | FlagFIPS,
+ skipYAML: true,
+ },
+ {
+ name: "enterprise",
+ yaml: `["Enterprise"]`,
+ flags: FlagEnterprise,
+ },
+ {
+ name: "fips",
+ yaml: `["FIPS"]`,
+ flags: FlagFIPS,
+ },
+ {
+ name: "empty",
+ yaml: `[]`,
+ },
+ {
+ name: "nil",
+ skipYAML: true,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ var flags InstallFlags
+ err := yaml.Unmarshal([]byte(tt.yaml), &flags)
+ require.NoError(t, err)
+ require.Equal(t, tt.flags, flags)
+
+ // verify test YAML
+ var v any
+ err = yaml.Unmarshal([]byte(tt.yaml), &v)
+ require.NoError(t, err)
+ res, err := yaml.Marshal(v)
+ require.NoError(t, err)
+
+ // compare verified YAML to flag output
+ out, err := yaml.Marshal(flags)
+ require.NoError(t, err)
+
+ if !tt.skipYAML {
+ require.Equal(t, string(res), string(out))
+ }
+ })
+ }
+}
diff --git a/lib/client/debug/debug.go b/lib/client/debug/debug.go
index 5553af7b8187c..368a4eadc3ab1 100644
--- a/lib/client/debug/debug.go
+++ b/lib/client/debug/debug.go
@@ -19,6 +19,7 @@ package debug
import (
"bytes"
"context"
+ "encoding/json"
"io"
"net"
"net/http"
@@ -55,9 +56,14 @@ func NewClient(socketPath string) *Client {
clt: &http.Client{
Timeout: apidefaults.DefaultIOTimeout,
Transport: &http.Transport{
- DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
- return net.Dial("unix", socketPath)
+ DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
+ var d net.Dialer
+ return d.DialContext(ctx, "unix", socketPath)
},
+ DisableKeepAlives: true,
+ },
+ CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
+ return trace.Errorf("redirect via socket not allowed")
},
},
}
@@ -137,6 +143,36 @@ func (c *Client) CollectProfile(ctx context.Context, profileName string, seconds
return result, nil
}
+// Readiness describes the readiness of the Teleport instance.
+type Readiness struct {
+ // Ready is true if the instance is ready.
+ // This field is only set by clients, based on status.
+ Ready bool `json:"-"`
+ // Status provides more detail about the readiness status.
+ Status string `json:"status"`
+ // PID is the process PID
+ PID int `json:"pid"`
+}
+
+// GetReadiness returns true if the Teleport service is ready.
+func (c *Client) GetReadiness(ctx context.Context) (Readiness, error) {
+ var ready Readiness
+ resp, err := c.do(ctx, http.MethodGet, url.URL{Path: "/readyz"}, nil)
+ if err != nil {
+ return ready, trace.Wrap(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode == http.StatusNotFound {
+ return ready, trace.NotFound("readiness endpoint not found")
+ }
+ ready.Ready = resp.StatusCode == http.StatusOK
+ err = json.NewDecoder(resp.Body).Decode(&ready)
+ if err != nil {
+ return ready, trace.Wrap(err)
+ }
+ return ready, nil
+}
+
func (c *Client) do(ctx context.Context, method string, u url.URL, body []byte) (*http.Response, error) {
u.Scheme = "http"
u.Host = "debug"
diff --git a/lib/client/debug/debug_test.go b/lib/client/debug/debug_test.go
index 0c9872baa6724..fd1dbe663b31a 100644
--- a/lib/client/debug/debug_test.go
+++ b/lib/client/debug/debug_test.go
@@ -56,6 +56,69 @@ func TestSetLogLevel(t *testing.T) {
})
}
+func TestGetReadiness(t *testing.T) {
+ ctx := context.Background()
+
+ t.Run("Success", func(t *testing.T) {
+ socketPath, _ := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`))
+ clt := NewClient(socketPath)
+
+ out, err := clt.GetReadiness(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "OK", out.Status)
+ require.True(t, out.Ready)
+ require.Equal(t, 1234, out.PID)
+ })
+
+ t.Run("Failure", func(t *testing.T) {
+ socketPath, _ := newSocketMockService(t, http.StatusBadRequest, []byte(`{"status": "BAD", "pid": 1234}`))
+ clt := NewClient(socketPath)
+
+ out, err := clt.GetReadiness(ctx)
+ require.NoError(t, err)
+ require.Equal(t, "BAD", out.Status)
+ require.False(t, out.Ready)
+ require.Equal(t, 1234, out.PID)
+ })
+
+ t.Run("Not found", func(t *testing.T) {
+ socketPath, _ := newSocketMockService(t, http.StatusNotFound, []byte(`404`))
+ clt := NewClient(socketPath)
+
+ out, err := clt.GetReadiness(ctx)
+ require.True(t, trace.IsNotFound(err))
+ require.Equal(t, "", out.Status)
+ require.False(t, out.Ready)
+ require.Equal(t, 0, out.PID)
+ })
+
+ t.Run("Closed", func(t *testing.T) {
+ socketPath, closeFn := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`))
+ closeFn()
+ clt := NewClient(socketPath)
+
+ out, err := clt.GetReadiness(ctx)
+ var netError net.Error
+ require.ErrorAs(t, err, &netError)
+ require.Equal(t, "", out.Status)
+ require.False(t, out.Ready)
+ require.Equal(t, 0, out.PID)
+ })
+
+ t.Run("Missing", func(t *testing.T) {
+ socketPath, _ := newSocketMockService(t, http.StatusOK, []byte(`{"status": "OK", "pid": 1234}`))
+ err := os.RemoveAll(socketPath)
+ require.NoError(t, err)
+ clt := NewClient(socketPath)
+
+ out, err := clt.GetReadiness(ctx)
+ require.ErrorIs(t, err, os.ErrNotExist)
+ require.Equal(t, "", out.Status)
+ require.False(t, out.Ready)
+ require.Equal(t, 0, out.PID)
+ })
+}
+
func TestCollectProfile(t *testing.T) {
ctx := context.Background()
@@ -144,6 +207,7 @@ func newSocketMockService(t *testing.T, status int, contents []byte) (string, fu
t.Cleanup(func() { srv.Shutdown(context.Background()) })
return socketPath, func() []string {
+ srv.Shutdown(context.Background())
return requests
}
}
diff --git a/lib/service/service.go b/lib/service/service.go
index a7b76ae337c90..a1de3ec15b0dd 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -96,6 +96,7 @@ import (
"github.com/gravitational/teleport/lib/auth/storage"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/automaticupgrades"
+ autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent"
"github.com/gravitational/teleport/lib/autoupdate/rollout"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/dynamo"
@@ -1246,18 +1247,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
return nil, trace.Wrap(err)
}
- upgraderKind := os.Getenv(automaticupgrades.EnvUpgrader)
- upgraderVersion := automaticupgrades.GetUpgraderVersion(process.GracefulExitContext())
- if upgraderVersion == "" {
- upgraderKind = ""
- }
-
- // Instances deployed using the AWS OIDC integration are automatically updated
- // by the proxy. The instance heartbeat should properly reflect that.
- externalUpgrader := upgraderKind
- if externalUpgrader == "" && os.Getenv(types.InstallMethodAWSOIDCDeployServiceEnvVar) == "true" {
- externalUpgrader = types.OriginIntegrationAWSOIDC
- }
+ upgraderKind, externalUpgrader, upgraderVersion := process.detectUpgrader()
// note: we must create the inventory handle *after* registerExpectedServices because that function determines
// the list of services (instance roles) to be included in the heartbeat.
@@ -1286,7 +1276,26 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
process.logger.WarnContext(process.ExitContext(), "Use of external upgraders on control-plane instances is not recommended.")
}
- if upgraderKind == "unit" {
+ switch upgraderKind {
+ case types.UpgraderKindTeleportUpdate:
+ isDefault, err := autoupdate.IsManagedAndDefault()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if !isDefault {
+ // Only write the nop schedule for the default updater.
+ // Suffixed installations of Teleport can coexist with the old upgrader system.
+ break
+ }
+ driver, err := uw.NewSystemdUnitDriver(uw.SystemdUnitDriverConfig{})
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if err := driver.ForceNop(process.ExitContext()); err != nil {
+ process.logger.WarnContext(process.ExitContext(), "Unable to disable the teleport-upgrade command provided by the deprecated teleport-ent-updater package.", "error", err)
+ process.logger.WarnContext(process.ExitContext(), "If the deprecated teleport-ent-updater package is installed, please ensure /etc/teleport-upgrade.d/schedule contains 'nop'.")
+ }
+ case types.UpgraderKindSystemdUnit:
process.RegisterFunc("autoupdates.endpoint.export", func() error {
conn, err := waitForInstanceConnector(process, process.logger)
if err != nil {
@@ -1314,28 +1323,14 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
process.logger.InfoContext(process.ExitContext(), "Exported autoupdates endpoint.", "addr", resolverAddr.String())
return nil
})
+ if err := process.configureUpgraderExporter(upgraderKind); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ default:
+ if err := process.configureUpgraderExporter(upgraderKind); err != nil {
+ return nil, trace.Wrap(err)
+ }
}
-
- driver, err := uw.NewDriver(upgraderKind)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- exporter, err := uw.NewExporter(uw.ExporterConfig[inventory.DownstreamSender]{
- Driver: driver,
- ExportFunc: process.exportUpgradeWindows,
- AuthConnectivitySentinel: process.inventoryHandle.Sender(),
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- process.RegisterCriticalFunc("upgradeewindow.export", exporter.Run)
- process.OnExit("upgradewindow.export.stop", func(_ interface{}) {
- exporter.Close()
- })
-
- process.logger.InfoContext(process.ExitContext(), "Configured upgrade window exporter for external upgrader.", "kind", upgraderKind)
}
serviceStarted := false
@@ -1549,6 +1544,63 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) {
return process, nil
}
+// detectUpgrader returns metadata about auto-upgraders that may be active.
+// Note that kind and externalName are usually the same.
+// However, some unregistered upgraders like the AWS ODIC upgrader are not valid kinds.
+// For these upgraders, kind is empty and externalName is set to a non-kind value.
+func (process *TeleportProcess) detectUpgrader() (kind, externalName, version string) {
+ // Check if the deprecated teleport-upgrader script is being used.
+ kind = os.Getenv(automaticupgrades.EnvUpgrader)
+ version = automaticupgrades.GetUpgraderVersion(process.GracefulExitContext())
+ if version == "" {
+ kind = ""
+ }
+
+ // If the installation is managed by teleport-update, it supersedes the teleport-upgrader script.
+ ok, err := autoupdate.IsManagedByUpdater()
+ if err != nil {
+ process.logger.WarnContext(process.ExitContext(), "Failed to determine if auto-updates are enabled.", "error", err)
+ } else if ok {
+ // If this is a teleport-update managed installation, the version
+ // managed by the timer will always match the installed version of teleport.
+ kind = types.UpgraderKindTeleportUpdate
+ version = "v" + teleport.Version
+ }
+
+ // Instances deployed using the AWS OIDC integration are automatically updated
+ // by the proxy. The instance heartbeat should properly reflect that.
+ externalName = kind
+ if externalName == "" && os.Getenv(types.InstallMethodAWSOIDCDeployServiceEnvVar) == "true" {
+ externalName = types.OriginIntegrationAWSOIDC
+ }
+ return kind, externalName, version
+}
+
+// configureUpgraderExporter configures the window exporter for upgraders that export windows.
+func (process *TeleportProcess) configureUpgraderExporter(kind string) error {
+ driver, err := uw.NewDriver(kind)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ exporter, err := uw.NewExporter(uw.ExporterConfig[inventory.DownstreamSender]{
+ Driver: driver,
+ ExportFunc: process.exportUpgradeWindows,
+ AuthConnectivitySentinel: process.inventoryHandle.Sender(),
+ })
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ process.RegisterCriticalFunc("upgradeewindow.export", exporter.Run)
+ process.OnExit("upgradewindow.export.stop", func(_ interface{}) {
+ exporter.Close()
+ })
+
+ process.logger.InfoContext(process.ExitContext(), "Configured upgrade window exporter for external upgrader.", "kind", kind)
+ return nil
+}
+
// enterpriseServicesEnabled will return true if any enterprise services are enabled.
func (process *TeleportProcess) enterpriseServicesEnabled() bool {
return modules.GetModules().BuildType() == modules.BuildEnterprise &&
diff --git a/lib/service/state.go b/lib/service/state.go
index 3c057148683c6..fab3b29a23bc4 100644
--- a/lib/service/state.go
+++ b/lib/service/state.go
@@ -21,6 +21,7 @@ package service
import (
"fmt"
"net/http"
+ "os"
"sync"
"time"
@@ -29,6 +30,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/lib/client/debug"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/metrics"
)
@@ -179,22 +181,26 @@ func (f *processState) readinessHandler() http.HandlerFunc {
switch f.getState() {
// 503
case stateDegraded:
- roundtrip.ReplyJSON(w, http.StatusServiceUnavailable, map[string]any{
- "status": "teleport is in a degraded state, check logs for details",
+ roundtrip.ReplyJSON(w, http.StatusServiceUnavailable, debug.Readiness{
+ Status: "teleport is in a degraded state, check logs for details",
+ PID: os.Getpid(),
})
// 400
case stateRecovering:
- roundtrip.ReplyJSON(w, http.StatusBadRequest, map[string]any{
- "status": "teleport is recovering from a degraded state, check logs for details",
+ roundtrip.ReplyJSON(w, http.StatusBadRequest, debug.Readiness{
+ Status: "teleport is recovering from a degraded state, check logs for details",
+ PID: os.Getpid(),
})
case stateStarting:
- roundtrip.ReplyJSON(w, http.StatusBadRequest, map[string]any{
- "status": "teleport is starting and hasn't joined the cluster yet",
+ roundtrip.ReplyJSON(w, http.StatusBadRequest, debug.Readiness{
+ Status: "teleport is starting and hasn't joined the cluster yet",
+ PID: os.Getpid(),
})
// 200
case stateOK:
- roundtrip.ReplyJSON(w, http.StatusOK, map[string]any{
- "status": "ok",
+ roundtrip.ReplyJSON(w, http.StatusOK, debug.Readiness{
+ Status: "ok",
+ PID: os.Getpid(),
})
}
}
diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go
index a32cff6ef76b9..14b213f08a173 100644
--- a/lib/utils/unpack.go
+++ b/lib/utils/unpack.go
@@ -23,6 +23,7 @@ import (
"errors"
"io"
"os"
+ "path"
"path/filepath"
"strings"
@@ -36,7 +37,10 @@ import (
// resulting files and directories are created using the current user context.
// Extract will only unarchive files into dir, and will fail if the tarball
// tries to write files outside of dir.
-func Extract(r io.Reader, dir string) error {
+//
+// If any paths are specified, only the specified paths are extracted.
+// The destination specified in the first matching path is selected.
+func Extract(r io.Reader, dir string, paths ...ExtractPath) error {
tarball := tar.NewReader(r)
for {
@@ -46,32 +50,95 @@ func Extract(r io.Reader, dir string) error {
} else if err != nil {
return trace.Wrap(err)
}
-
+ dirMode, ok := filterHeader(header, paths)
+ if !ok {
+ continue
+ }
err = sanitizeTarPath(header, dir)
if err != nil {
return trace.Wrap(err)
}
- if err := extractFile(tarball, header, dir); err != nil {
+ if err := extractFile(tarball, header, dir, dirMode); err != nil {
return trace.Wrap(err)
}
}
return nil
}
+// ExtractPath specifies a path to be extracted.
+type ExtractPath struct {
+ // Src path and Dst path within the archive to extract files to.
+ // Directories in the Src path are not included in the extraction dir.
+ // For example, given foo/bar/file.txt with Src=foo/bar Dst=baz, baz/file.txt results.
+ // Trailing slashes are always ignored.
+ Src, Dst string
+ // Skip extracting the Src path and ignore Dst.
+ Skip bool
+ // DirMode is the file mode for implicit parent directories in Dst.
+ DirMode os.FileMode
+}
+
+// filterHeader modifies the tar header by filtering it through the ExtractPaths.
+// filterHeader returns false if the tar header should be skipped.
+// If no paths are provided, filterHeader assumes the header should be included, and sets
+// the mode for implicit parent directories to teleport.DirMaskSharedGroup.
+func filterHeader(hdr *tar.Header, paths []ExtractPath) (dirMode os.FileMode, include bool) {
+ name := path.Clean(hdr.Name)
+ for _, p := range paths {
+ src := path.Clean(p.Src)
+ switch hdr.Typeflag {
+ case tar.TypeDir:
+ // If name is a directory, then
+ // assume src is a directory prefix, or the directory itself,
+ // and replace that prefix with dst.
+ if src != "/" {
+ src += "/" // ensure HasPrefix does not match partial names
+ }
+ if !strings.HasPrefix(name, src) {
+ continue
+ }
+ dst := path.Join(p.Dst, strings.TrimPrefix(name, src))
+ if dst != "/" {
+ dst += "/" // tar directory headers end in /
+ }
+ hdr.Name = dst
+ return p.DirMode, !p.Skip
+ default:
+ // If name is a file, then
+ // if src is an exact match to the file name, assume src is a file and write directly to dst,
+ // otherwise, assume src is a directory prefix, and replace that prefix with dst.
+ if src == name {
+ hdr.Name = path.Clean(p.Dst)
+ return p.DirMode, !p.Skip
+ }
+ if src != "/" {
+ src += "/" // ensure HasPrefix does not match partial names
+ }
+ if !strings.HasPrefix(name, src) {
+ continue
+ }
+ hdr.Name = path.Join(p.Dst, strings.TrimPrefix(name, src))
+ return p.DirMode, !p.Skip
+
+ }
+ }
+ return teleport.DirMaskSharedGroup, len(paths) == 0
+}
+
// extractFile extracts a single file or directory from tarball into dir.
// Uses header to determine the type of item to create
// Based on https://github.com/mholt/archiver
-func extractFile(tarball *tar.Reader, header *tar.Header, dir string) error {
+func extractFile(tarball *tar.Reader, header *tar.Header, dir string, dirMode os.FileMode) error {
switch header.Typeflag {
case tar.TypeDir:
- return withDir(filepath.Join(dir, header.Name), nil)
+ return withDir(filepath.Join(dir, header.Name), dirMode, nil)
case tar.TypeBlock, tar.TypeChar, tar.TypeReg, tar.TypeFifo:
- return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode())
+ return writeFile(filepath.Join(dir, header.Name), tarball, header.FileInfo().Mode(), dirMode)
case tar.TypeLink:
- return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname))
+ return writeHardLink(filepath.Join(dir, header.Name), filepath.Join(dir, header.Linkname), dirMode)
case tar.TypeSymlink:
- return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname)
+ return writeSymbolicLink(filepath.Join(dir, header.Name), header.Linkname, dirMode)
default:
log.Warnf("Unsupported type flag %v for %v.", header.Typeflag, header.Name)
}
@@ -106,8 +173,8 @@ func sanitizeTarPath(header *tar.Header, dir string) error {
return nil
}
-func writeFile(path string, r io.Reader, mode os.FileMode) error {
- err := withDir(path, func() error {
+func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
// Create file only if it does not exist to prevent overwriting existing
// files (like session recordings).
out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
@@ -120,24 +187,24 @@ func writeFile(path string, r io.Reader, mode os.FileMode) error {
return trace.Wrap(err)
}
-func writeSymbolicLink(path string, target string) error {
- err := withDir(path, func() error {
+func writeSymbolicLink(path, target string, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
err := os.Symlink(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}
-func writeHardLink(path string, target string) error {
- err := withDir(path, func() error {
+func writeHardLink(path, target string, dirMode os.FileMode) error {
+ err := withDir(path, dirMode, func() error {
err := os.Link(target, path)
return trace.ConvertSystemError(err)
})
return trace.Wrap(err)
}
-func withDir(path string, fn func() error) error {
- err := os.MkdirAll(filepath.Dir(path), teleport.DirMaskSharedGroup)
+func withDir(path string, mode os.FileMode, fn func() error) error {
+ err := os.MkdirAll(filepath.Dir(path), mode)
if err != nil {
return trace.ConvertSystemError(err)
}
diff --git a/lib/versioncontrol/upgradewindow/upgradewindow.go b/lib/versioncontrol/upgradewindow/upgradewindow.go
index 7f90a9d70d41c..2a82b5f14fc14 100644
--- a/lib/versioncontrol/upgradewindow/upgradewindow.go
+++ b/lib/versioncontrol/upgradewindow/upgradewindow.go
@@ -44,6 +44,9 @@ const (
// unitScheduleFile is the name of the file to which the unit schedule is exported.
unitScheduleFile = "schedule"
+
+ // scheduleNop is the name of the no-op schedule.
+ scheduleNop = "nop"
)
// ExportFunc represents the ExportUpgradeWindows rpc exposed by auth servers.
@@ -313,6 +316,12 @@ type Driver interface {
// called if teleport experiences prolonged loss of auth connectivity, which may be an indicator
// that the control plane has been upgraded s.t. this agent is no longer compatible.
Reset(ctx context.Context) error
+
+ // ForceNop sets the NOP schedule, ensuring that updates do not happen.
+ // This schedule was originally only used for testing, but now it is also used by the
+ // teleport-update binary to protect against package updates that could interfere with
+ // the new update system.
+ ForceNop(ctx context.Context) error
}
// NewDriver sets up a new export driver corresponding to the specified upgrader kind.
@@ -361,7 +370,15 @@ func (e *kubeDriver) Kind() string {
}
func (e *kubeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsResponse) error {
- if rsp.KubeControllerSchedule == "" {
+ return trace.Wrap(e.setSchedule(ctx, rsp.KubeControllerSchedule))
+}
+
+func (e *kubeDriver) ForceNop(ctx context.Context) error {
+ return trace.Wrap(e.setSchedule(ctx, scheduleNop))
+}
+
+func (e *kubeDriver) setSchedule(ctx context.Context, schedule string) error {
+ if schedule == "" {
return e.Reset(ctx)
}
@@ -369,7 +386,7 @@ func (e *kubeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsRes
// backend.KeyFromString is intentionally used here instead of backend.NewKey
// because existing backend items were persisted without the leading /.
Key: backend.KeyFromString(kubeSchedKey),
- Value: []byte(rsp.KubeControllerSchedule),
+ Value: []byte(schedule),
})
return trace.Wrap(err)
@@ -411,7 +428,15 @@ func (e *systemdDriver) Kind() string {
}
func (e *systemdDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsResponse) error {
- if len(rsp.SystemdUnitSchedule) == 0 {
+ return trace.Wrap(e.setSchedule(ctx, rsp.SystemdUnitSchedule))
+}
+
+func (e *systemdDriver) ForceNop(ctx context.Context) error {
+ return trace.Wrap(e.setSchedule(ctx, scheduleNop))
+}
+
+func (e *systemdDriver) setSchedule(ctx context.Context, schedule string) error {
+ if len(schedule) == 0 {
// treat an empty schedule value as equivalent to a reset
return e.Reset(ctx)
}
@@ -423,7 +448,7 @@ func (e *systemdDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindows
}
// export schedule file. if created it is set to 644, which is reasonable for a sensitive but non-secret config value.
- if err := os.WriteFile(e.scheduleFile(), []byte(rsp.SystemdUnitSchedule), defaults.FilePermissions); err != nil {
+ if err := os.WriteFile(e.scheduleFile(), []byte(schedule), defaults.FilePermissions); err != nil {
return trace.Errorf("failed to write schedule file: %v", err)
}
diff --git a/lib/versioncontrol/upgradewindow/upgradewindow_test.go b/lib/versioncontrol/upgradewindow/upgradewindow_test.go
index 7b724708652f4..c5c2236673ea7 100644
--- a/lib/versioncontrol/upgradewindow/upgradewindow_test.go
+++ b/lib/versioncontrol/upgradewindow/upgradewindow_test.go
@@ -27,6 +27,7 @@ import (
"testing"
"time"
+ "github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/client/proto"
@@ -182,6 +183,37 @@ func TestSystemdUnitDriver(t *testing.T) {
require.Equal(t, "", string(sb))
}
+// TestSystemdUnitDriverNop verifies the nop schedule behavior of the systemd unit export driver.
+func TestSystemdUnitDriverNop(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ // use a sub-directory of a temp dir in order to verify that
+ // driver creates dir when needed.
+ dir := filepath.Join(t.TempDir(), "config")
+
+ driver, err := NewSystemdUnitDriver(SystemdUnitDriverConfig{
+ ConfigDir: dir,
+ })
+ require.NoError(t, err)
+
+ err = driver.Sync(ctx, proto.ExportUpgradeWindowsResponse{
+ SystemdUnitSchedule: "fake-schedule",
+ })
+ require.NoError(t, err)
+
+ err = driver.ForceNop(ctx)
+ require.NoError(t, err)
+
+ schedPath := filepath.Join(dir, "schedule")
+ sb, err := os.ReadFile(schedPath)
+ require.NoError(t, err)
+
+ require.Equal(t, scheduleNop, string(sb))
+}
+
// fakeDriver is used to inject custom behavior into a dummy Driver instance.
type fakeDriver struct {
mu sync.Mutex
@@ -209,6 +241,10 @@ func (d *fakeDriver) Sync(ctx context.Context, rsp proto.ExportUpgradeWindowsRes
return nil
}
+func (d *fakeDriver) ForceNop(ctx context.Context) error {
+ return trace.NotImplemented("force-nop not used by exporter")
+}
+
func (d *fakeDriver) Reset(ctx context.Context) error {
d.mu.Lock()
defer d.mu.Unlock()
diff --git a/tool/teleport-update/main.go b/tool/teleport-update/main.go
index 11aee2aae3906..801b62a3d8171 100644
--- a/tool/teleport-update/main.go
+++ b/tool/teleport-update/main.go
@@ -20,17 +20,19 @@ package main
import (
"context"
+ "errors"
+ "fmt"
"log/slog"
"os"
"os/signal"
- "path/filepath"
"syscall"
"github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
"github.com/gravitational/teleport"
+ common "github.com/gravitational/teleport/lib/autoupdate"
autoupdate "github.com/gravitational/teleport/lib/autoupdate/agent"
- libdefaults "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/modules"
libutils "github.com/gravitational/teleport/lib/utils"
logutils "github.com/gravitational/teleport/lib/utils/log"
@@ -38,16 +40,13 @@ import (
const appHelp = `Teleport Updater
-The Teleport Updater updates the version a Teleport agent on a Linux server
-that is being used as agent to provide connectivity to Teleport resources.
+The Teleport Updater applies Managed Updates to a Teleport agent installation.
-The Teleport Updater supports upgrade schedules and automated rollbacks.
+The Teleport Updater supports update scheduling and automated rollbacks.
-Find out more at https://goteleport.com/docs/updater`
+Find out more at https://goteleport.com/docs/upgrading/agent-managed-updates`
const (
- // templateEnvVar allows the template for the Teleport tgz to be specified via env var.
- templateEnvVar = "TELEPORT_URL_TEMPLATE"
// proxyServerEnvVar allows the proxy server address to be specified via env var.
proxyServerEnvVar = "TELEPORT_PROXY"
// updateGroupEnvVar allows the update group to be specified via env var.
@@ -56,107 +55,166 @@ const (
updateVersionEnvVar = "TELEPORT_UPDATE_VERSION"
)
-const (
- // versionsDirName specifies the name of the subdirectory inside of the Teleport data dir for storing Teleport versions.
- versionsDirName = "versions"
- // lockFileName specifies the name of the file inside versionsDirName containing the flock lock preventing concurrent updater execution.
- lockFileName = ".lock"
-)
-
var plog = logutils.NewPackageLogger(teleport.ComponentKey, teleport.ComponentUpdater)
func main() {
- if err := Run(os.Args[1:]); err != nil {
- libutils.FatalError(err)
+ if code := Run(os.Args[1:]); code != 0 {
+ os.Exit(code)
}
}
type cliConfig struct {
autoupdate.OverrideConfig
-
// Debug logs enabled
Debug bool
// LogFormat controls the format of logging. Can be either `json` or `text`.
// By default, this is `text`.
LogFormat string
- // DataDir for Teleport (usually /var/lib/teleport)
- DataDir string
+ // InstallDir for Teleport (usually /opt/teleport)
+ InstallDir string
+ // InstallSuffix is the isolated suffix for the installation.
+ InstallSuffix string
+ // SelfSetup mode for using the current version of the teleport-update to setup the update service.
+ SelfSetup bool
+ // UpdateNow forces an immediate update.
+ UpdateNow bool
+ // Reload reloads Teleport.
+ Reload bool
+ // ForceUninstall allows Teleport to be completely removed.
+ ForceUninstall bool
+ // Insecure skips TLS certificate verification.
+ Insecure bool
}
-func (c *cliConfig) CheckAndSetDefaults() error {
- if c.DataDir == "" {
- c.DataDir = libdefaults.DataDir
- }
- if c.LogFormat == "" {
- c.LogFormat = libutils.LogFormatText
- }
- return nil
-}
-
-func Run(args []string) error {
+func Run(args []string) int {
var ccfg cliConfig
+
ctx := context.Background()
ctx, _ = signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
- app := libutils.InitCLIParser("teleport-updater", appHelp).Interspersed(false)
+ app := libutils.InitCLIParser(autoupdate.BinaryName, appHelp).Interspersed(false)
app.Flag("debug", "Verbose logging to stdout.").
Short('d').BoolVar(&ccfg.Debug)
- app.Flag("data-dir", "Teleport data directory. Access to this directory should be limited.").
- Default(libdefaults.DataDir).StringVar(&ccfg.DataDir)
app.Flag("log-format", "Controls the format of output logs. Can be `json` or `text`. Defaults to `text`.").
Default(libutils.LogFormatText).EnumVar(&ccfg.LogFormat, libutils.LogFormatJSON, libutils.LogFormatText)
+ app.Flag("install-suffix", "Suffix for creating an agent installation outside of the default $PATH. Note: this changes the default data directory.").
+ Short('i').StringVar(&ccfg.InstallSuffix)
+ app.Flag("install-dir", "Directory containing Teleport installations.").
+ Hidden().StringVar(&ccfg.InstallDir)
+ app.Flag("insecure", "Insecure mode disables certificate verification. Do not use in production.").
+ BoolVar(&ccfg.Insecure)
app.HelpFlag.Short('h')
- versionCmd := app.Command("version", "Print the version of your teleport-updater binary.")
+ versionCmd := app.Command("version", fmt.Sprintf("Print the version of your %s binary.", autoupdate.BinaryName))
- enableCmd := app.Command("enable", "Enable agent auto-updates and perform initial update.")
+ enableCmd := app.Command("enable", "Enable agent managed updates and perform initial installation or update. This creates a systemd timer that periodically runs the update subcommand.")
enableCmd.Flag("proxy", "Address of the Teleport Proxy.").
Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy)
enableCmd.Flag("group", "Update group for this agent installation.").
Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group)
- enableCmd.Flag("template", "Go template used to override Teleport download URL.").
- Short('t').Envar(templateEnvVar).StringVar(&ccfg.URLTemplate)
- enableCmd.Flag("force-version", "Force the provided version instead of querying it from the Teleport cluster.").
- Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion)
-
- disableCmd := app.Command("disable", "Disable agent auto-updates.")
-
- updateCmd := app.Command("update", "Update agent to the latest version, if a new version is available.")
- updateCmd.Flag("force-version", "Use the provided version instead of querying it from the Teleport cluster.").
- Short('f').Envar(updateVersionEnvVar).Hidden().StringVar(&ccfg.ForceVersion)
+ enableCmd.Flag("base-url", "Base URL used to override the Teleport download URL.").
+ Short('b').Envar(common.BaseURLEnvVar).StringVar(&ccfg.BaseURL)
+ enableCmd.Flag("overwrite", "Allow existing installed Teleport binaries to be overwritten.").
+ Short('o').BoolVar(&ccfg.AllowOverwrite)
+ enableCmd.Flag("force-version", "Force the provided version instead of using the version provided by the Teleport cluster.").
+ Hidden().Short('f').Envar(updateVersionEnvVar).StringVar(&ccfg.ForceVersion)
+ enableCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates.").
+ Hidden().BoolVar(&ccfg.SelfSetup)
+ enableCmd.Flag("path", "Directory to link the active Teleport installation's binaries into.").
+ Hidden().StringVar(&ccfg.Path)
+ // TODO(sclevine): add force-fips and force-enterprise as hidden flags
+
+ disableCmd := app.Command("disable", "Disable agent managed updates. Does not affect the active installation of Teleport.")
+
+ pinCmd := app.Command("pin", "Install Teleport and lock the updater to the installed version.")
+ pinCmd.Flag("proxy", "Address of the Teleport Proxy.").
+ Short('p').Envar(proxyServerEnvVar).StringVar(&ccfg.Proxy)
+ pinCmd.Flag("group", "Update group for this agent installation.").
+ Short('g').Envar(updateGroupEnvVar).StringVar(&ccfg.Group)
+ pinCmd.Flag("base-url", "Base URL used to override the Teleport download URL.").
+ Short('b').Envar(common.BaseURLEnvVar).StringVar(&ccfg.BaseURL)
+ pinCmd.Flag("overwrite", "Allow existing installed Teleport binaries to be overwritten.").
+ Short('o').BoolVar(&ccfg.AllowOverwrite)
+ pinCmd.Flag("force-version", "Force the provided version instead of using the version provided by the Teleport cluster.").
+ Short('f').Envar(updateVersionEnvVar).StringVar(&ccfg.ForceVersion)
+ pinCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates.").
+ Hidden().BoolVar(&ccfg.SelfSetup)
+ pinCmd.Flag("path", "Directory to link the active Teleport installation's binaries into.").
+ Hidden().StringVar(&ccfg.Path)
+
+ unpinCmd := app.Command("unpin", "Unpin the current version, allowing it to be updated.")
+
+ updateCmd := app.Command("update", "Update the agent to the latest version, if a new version is available.")
+ updateCmd.Flag("now", "Force immediate update even if update window is not active.").
+ Short('n').BoolVar(&ccfg.UpdateNow)
+ updateCmd.Flag("self-setup", "Use the current teleport-update binary to create systemd service config for managed updates and verify the Teleport installation.").
+ Hidden().BoolVar(&ccfg.SelfSetup)
+
+ linkCmd := app.Command("link-package", "Link the system installation of Teleport from the Teleport package, if managed updates is disabled.")
+ unlinkCmd := app.Command("unlink-package", "Unlink the system installation of Teleport from the Teleport package.")
+
+ setupCmd := app.Command("setup", "Write configuration files that run the update subcommand on a timer and verify the Teleport installation.").
+ Hidden()
+ setupCmd.Flag("reload", "Reload the Teleport agent. If not set, Teleport is not reloaded or restarted.").
+ BoolVar(&ccfg.Reload)
+ setupCmd.Flag("path", "Directory that the active Teleport installation's binaries are linked into.").
+ Required().StringVar(&ccfg.Path)
+
+ statusCmd := app.Command("status", "Show Teleport agent auto-update status.")
+
+ uninstallCmd := app.Command("uninstall", "Uninstall the updater-managed installation of Teleport. If the Teleport package is installed, it is restored as the primary installation.")
+ uninstallCmd.Flag("force", "Force complete uninstallation of Teleport, even if there is no packaged version of Teleport to revert to.").
+ Short('f').BoolVar(&ccfg.ForceUninstall)
libutils.UpdateAppUsageTemplate(app, args)
command, err := app.Parse(args)
if err != nil {
app.Usage(args)
- return trace.Wrap(err)
+ libutils.FatalError(err)
}
+
// Logging must be configured as early as possible to ensure all log
// message are formatted correctly.
if err := setupLogger(ccfg.Debug, ccfg.LogFormat); err != nil {
- return trace.Errorf("failed to set up logger")
- }
-
- if err := ccfg.CheckAndSetDefaults(); err != nil {
- return trace.Wrap(err)
+ plog.ErrorContext(ctx, "Failed to set up logger.", "error", err)
+ return 1
}
switch command {
case enableCmd.FullCommand():
- err = cmdEnable(ctx, &ccfg)
+ ccfg.Enabled = true
+ err = cmdInstall(ctx, &ccfg)
+ case pinCmd.FullCommand():
+ ccfg.Pinned = true
+ err = cmdInstall(ctx, &ccfg)
case disableCmd.FullCommand():
err = cmdDisable(ctx, &ccfg)
+ case unpinCmd.FullCommand():
+ err = cmdUnpin(ctx, &ccfg)
case updateCmd.FullCommand():
err = cmdUpdate(ctx, &ccfg)
+ case linkCmd.FullCommand():
+ err = cmdLinkPackage(ctx, &ccfg)
+ case unlinkCmd.FullCommand():
+ err = cmdUnlinkPackage(ctx, &ccfg)
+ case setupCmd.FullCommand():
+ err = cmdSetup(ctx, &ccfg)
+ case statusCmd.FullCommand():
+ err = cmdStatus(ctx, &ccfg)
+ case uninstallCmd.FullCommand():
+ err = cmdUninstall(ctx, &ccfg)
case versionCmd.FullCommand():
modules.GetModules().PrintVersion()
default:
// This should only happen when there's a missing switch case above.
- err = trace.Errorf("command %q not configured", command)
+ err = trace.Errorf("command %s not configured", command)
}
-
- return err
+ if err != nil {
+ plog.ErrorContext(ctx, "Command failed.", "error", err)
+ return 1
+ }
+ return 0
}
func setupLogger(debug bool, format string) error {
@@ -176,46 +234,178 @@ func setupLogger(debug bool, format string) error {
return nil
}
+func initConfig(ctx context.Context, ccfg *cliConfig) (updater *autoupdate.Updater, lockFile string, err error) {
+ ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir)
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ lockFile, err = ns.Init()
+ if err != nil {
+ return nil, "", trace.Wrap(err)
+ }
+ updater, err = autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ LogFormat: ccfg.LogFormat,
+ Debug: ccfg.Debug,
+ InsecureSkipVerify: ccfg.Insecure,
+ }, ns)
+ return updater, lockFile, trace.Wrap(err)
+}
+
+func statusConfig(ctx context.Context, ccfg *cliConfig) (*autoupdate.Updater, error) {
+ ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ LogFormat: ccfg.LogFormat,
+ Debug: ccfg.Debug,
+ InsecureSkipVerify: ccfg.Insecure,
+ }, ns)
+ return updater, trace.Wrap(err)
+}
+
// cmdDisable disables updates.
func cmdDisable(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
- if err := os.MkdirAll(versionsDir, 0755); err != nil {
- return trace.Errorf("failed to create versions directory: %w", err)
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
}
-
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSWriteLock(lockFile)
if err != nil {
- return trace.Errorf("failed to grab concurrent execution lock: %w", err)
+ return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile)
}
defer func() {
if err := unlock(); err != nil {
plog.DebugContext(ctx, "Failed to close lock file", "error", err)
}
}()
- updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- Log: plog,
- })
+ if err := updater.Disable(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// cmdUnpin unpins the current version.
+func cmdUnpin(ctx context.Context, ccfg *cliConfig) error {
+ updater, lockFile, err := initConfig(ctx, ccfg)
if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
+ return trace.Wrap(err, "failed to setup updater")
}
- if err := updater.Disable(ctx); err != nil {
+ unlock, err := libutils.FSWriteLock(lockFile)
+ if err != nil {
+ return trace.Wrap(err, "failed to grab concurrent execution lock %n", lockFile)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
+ if err := updater.Unpin(ctx); err != nil {
return trace.Wrap(err)
}
return nil
}
-// cmdEnable enables updates and triggers an initial update.
-func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
- versionsDir := filepath.Join(ccfg.DataDir, versionsDirName)
- if err := os.MkdirAll(versionsDir, 0755); err != nil {
- return trace.Errorf("failed to create versions directory: %w", err)
+// cmdInstall installs Teleport and sets configuration.
+func cmdInstall(ctx context.Context, ccfg *cliConfig) error {
+ if ccfg.InstallSuffix != "" {
+ ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ ns.LogWarning(ctx)
+ }
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
}
// Ensure enable can't run concurrently.
- unlock, err := libutils.FSWriteLock(filepath.Join(versionsDir, lockFileName))
+ unlock, err := libutils.FSWriteLock(lockFile)
+ if err != nil {
+ return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
+ if err := updater.Install(ctx, ccfg.OverrideConfig); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address.
+func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
+ }
+ // Ensure update can't run concurrently.
+ unlock, err := libutils.FSWriteLock(lockFile)
+ if err != nil {
+ return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
+
+ if err := updater.Update(ctx, ccfg.UpdateNow); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// cmdLinkPackage creates system package links if no version is linked and managed updates is disabled.
+func cmdLinkPackage(ctx context.Context, ccfg *cliConfig) error {
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
+ }
+
+ // Skip operation and warn if the updater is currently running.
+ unlock, err := libutils.FSTryReadLock(lockFile)
+ if errors.Is(err, libutils.ErrUnsuccessfulLockTry) {
+ plog.WarnContext(ctx, "Updater is currently running. Skipping package linking.")
+ return nil
+ }
+ if err != nil {
+ return trace.Wrap(err, "failed to grab concurrent execution lock %q", lockFile)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
+
+ if err := updater.LinkPackage(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// cmdUnlinkPackage remove system package links.
+func cmdUnlinkPackage(ctx context.Context, ccfg *cliConfig) error {
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to setup updater")
+ }
+
+ // Error if the updater is running. We could remove its links by accident.
+ unlock, err := libutils.FSTryWriteLock(lockFile)
+ if errors.Is(err, libutils.ErrUnsuccessfulLockTry) {
+ plog.WarnContext(ctx, "Updater is currently running. Skipping package unlinking.")
+ return nil
+ }
if err != nil {
- return trace.Errorf("failed to grab concurrent execution lock: %w", err)
+ return trace.Wrap(err, "failed to grab concurrent execution lock %q", lockFile)
}
defer func() {
if err := unlock(); err != nil {
@@ -223,20 +413,68 @@ func cmdEnable(ctx context.Context, ccfg *cliConfig) error {
}
}()
+ if err := updater.UnlinkPackage(ctx); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+// cmdSetup writes configuration files that are needed to run teleport-update update.
+func cmdSetup(ctx context.Context, ccfg *cliConfig) error {
+ ns, err := autoupdate.NewNamespace(ctx, plog, ccfg.InstallSuffix, ccfg.InstallDir)
+ if err != nil {
+ return trace.Wrap(err)
+ }
updater, err := autoupdate.NewLocalUpdater(autoupdate.LocalUpdaterConfig{
- VersionsDir: versionsDir,
- Log: plog,
- })
+ SelfSetup: ccfg.SelfSetup,
+ Log: plog,
+ LogFormat: ccfg.LogFormat,
+ Debug: ccfg.Debug,
+ InsecureSkipVerify: ccfg.Insecure,
+ }, ns)
if err != nil {
- return trace.Errorf("failed to setup updater: %w", err)
+ return trace.Wrap(err)
}
- if err := updater.Enable(ctx, ccfg.OverrideConfig); err != nil {
+ err = updater.Setup(ctx, ccfg.Path, ccfg.Reload)
+ if err != nil {
return trace.Wrap(err)
}
return nil
}
-// cmdUpdate updates Teleport to the version specified by cluster reachable at the proxy address.
-func cmdUpdate(ctx context.Context, ccfg *cliConfig) error {
- return trace.NotImplemented("TODO")
+// cmdStatus displays auto-update status.
+func cmdStatus(ctx context.Context, ccfg *cliConfig) error {
+ updater, err := statusConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
+ }
+ status, err := updater.Status(ctx)
+ if err != nil {
+ return trace.Wrap(err, "failed to get status")
+ }
+ enc := yaml.NewEncoder(os.Stdout)
+ return trace.Wrap(enc.Encode(status))
+}
+
+// cmdUninstall removes the updater-managed install of Teleport and gracefully reverts back to the Teleport package.
+func cmdUninstall(ctx context.Context, ccfg *cliConfig) error {
+ updater, lockFile, err := initConfig(ctx, ccfg)
+ if err != nil {
+ return trace.Wrap(err, "failed to initialize updater")
+ }
+ // Ensure update can't run concurrently.
+ unlock, err := libutils.FSWriteLock(lockFile)
+ if err != nil {
+ return trace.Wrap(err, "failed to grab concurrent execution lock %s", lockFile)
+ }
+ defer func() {
+ if err := unlock(); err != nil {
+ plog.DebugContext(ctx, "Failed to close lock file", "error", err)
+ }
+ }()
+
+ if err := updater.Remove(ctx, ccfg.ForceUninstall); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
}