diff --git a/utils/utils.go b/utils/utils.go index 60583b0..8061675 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -190,13 +190,42 @@ func DownloadFile(ctx context.Context, rawURL string, logger logging.Logger) (st } case "http", "https": // note: we shrink the hash to avoid system path length limits - partialDest := CreatePartialPath(rawURL) + partialDest, etagPath := CreatePartialPath(rawURL) + + // Do a HEAD request to get the ETag value + remoteETag, err := getRemoteETag(ctx, parsedURL.String(), logger) + if err != nil { + logger.Warnw("failed to get remote ETag, proceeding with download", "err", err) + } g := getter.HttpGetter{Client: socksClient(parsedURL.String(), logger)} g.SetClient(getterClient) if stat, err := os.Stat(partialDest); err == nil { - logger.Infow("download to existing", "dest", partialDest, "size", stat.Size()) + // File exists, load the etag file and check it against the one we just got + storedETag, err := readETag(etagPath) + if err == nil && remoteETag != "" { + if storedETag == remoteETag { + // ETag matches - allow resume (getter will handle range requests) + logger.Infow("resuming download with matching ETag", "dest", partialDest, "size", stat.Size(), "etag", remoteETag) + } else { + // If it's a mismatch, delete the old .part file and save the new .etag file + logger.Infow("ETag mismatch, deleting old file", "dest", partialDest, "stored_etag", storedETag, "remote_etag", remoteETag) + if err := os.Remove(partialDest); err != nil && !errors.Is(err, fs.ErrNotExist) { + logger.Warnw("failed to remove old partial file", "err", err) + } + if err := writeETag(etagPath, remoteETag); err != nil { + logger.Warnw("failed to save ETag", "err", err) + } + } + } else { + logger.Infow("download to existing", "dest", partialDest, "size", stat.Size()) + } + } else if remoteETag != "" { + // If the file doesn't exist, save a new etag file + if err := writeETag(etagPath, remoteETag); err != nil { + logger.Warnw("failed to save ETag", "err", err) + } } done := make(chan struct{}) @@ -258,7 +287,8 @@ func rewriteGCPDownload(orig *url.URL) (*url.URL, bool) { } // CreatePartialPath makes a path under cachedir/part. These get cleaned up by CleanPartials. -func CreatePartialPath(rawURL string) string { +// Returns both the .part path and the .etag path. +func CreatePartialPath(rawURL string) (partPath, etagPath string) { var urlPath string if parsed, err := url.Parse(rawURL); err != nil { urlPath = "UNPARSED" @@ -266,7 +296,10 @@ func CreatePartialPath(rawURL string) string { urlPath = parsed.Path } - return path.Join(ViamDirs.Partials, hashString(rawURL, 7), last(strings.Split(urlPath, "/"), "")+".part") + basePath := path.Join(ViamDirs.Partials, hashString(rawURL, 7), last(strings.Split(urlPath, "/"), "")) + partPath = basePath + ".part" + etagPath = basePath + ".etag" + return partPath, etagPath } // helper: return last item of `items` slice, or `default_` if items is empty. @@ -289,6 +322,49 @@ func hashString(input string, n int) string { return ret } +// getRemoteETag performs a HEAD request to get the ETag from the remote server. +// ETags are returned with quotes removed for consistent comparison. +func getRemoteETag(ctx context.Context, url string, logger logging.Logger) (string, error) { + etag, _, err := getRemoteETagAndSize(ctx, url, logger) + return etag, err +} + +// getRemoteETagAndSize performs a HEAD request to get the ETag and content length from the remote server. +// ETags are returned with quotes removed for consistent comparison. +func getRemoteETagAndSize(ctx context.Context, url string, logger logging.Logger) (string, int64, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err != nil { + return "", 0, errw.Wrap(err, "creating HEAD request") + } + res, err := socksClient(url, logger).Do(req) + if err != nil { + return "", 0, errw.Wrap(err, "executing HEAD request") + } + defer res.Body.Close() //nolint:errcheck + etag := res.Header.Get("ETag") + // Remove surrounding quotes if present (ETags are often returned as "value") + etag = strings.Trim(etag, `"`) + return etag, res.ContentLength, nil +} + +// readETag reads the ETag from a file. +func readETag(etagPath string) (string, error) { + data, err := os.ReadFile(etagPath) //nolint:gosec + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil +} + +// writeETag writes the ETag to a file. +func writeETag(etagPath, etag string) error { + // Ensure the directory exists + if err := os.MkdirAll(path.Dir(etagPath), 0o750); err != nil { + return errw.Wrapf(err, "creating directory for %s", etagPath) + } + return errw.Wrapf(os.WriteFile(etagPath, []byte(etag), 0o600), "writing ETag to %s", etagPath) +} + // on windows only, create a firewall exception for the newly-downloaded file. func allowFirewall(logger logging.Logger, outPath string) error { // todo: confirm this is right; this isn't the final destination. Does the rule move when the file is renamed? Link to docs. diff --git a/utils/utils_test.go b/utils/utils_test.go index f2f1221..49979cf 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -386,13 +386,14 @@ func TestInitPaths(t *testing.T) { } func TestPartialPath(t *testing.T) { - path := CreatePartialPath("https://storage.googleapis.com/packages.viam.com/apps/viam-server/viam-server-latest-x86_64") + partPath, etagPath := CreatePartialPath("https://storage.googleapis.com/packages.viam.com/apps/viam-server/viam-server-latest-x86_64") maxPathLengths := map[string]int{ "linux": 4096, "windows": 260, } for _, maxPath := range maxPathLengths { - test.That(t, len(path), test.ShouldBeLessThanOrEqualTo, maxPath) + test.That(t, len(partPath), test.ShouldBeLessThanOrEqualTo, maxPath) + test.That(t, len(etagPath), test.ShouldBeLessThanOrEqualTo, maxPath) } } diff --git a/version_control_test.go b/version_control_test.go index 0585cc7..51c3962 100644 --- a/version_control_test.go +++ b/version_control_test.go @@ -206,7 +206,7 @@ func TestCleanPartials(t *testing.T) { vc := VersionCache{logger: logging.NewTestLogger(t)} // make a part file to clean up - oldPath := utils.CreatePartialPath("https://viam.com/old.part") + oldPath, _ := utils.CreatePartialPath("https://viam.com/old.part") err := os.Mkdir(filepath.Dir(oldPath), 0o755) test.That(t, err, test.ShouldBeNil) err = os.WriteFile(oldPath, []byte("hello"), 0o600) @@ -214,7 +214,7 @@ func TestCleanPartials(t *testing.T) { os.Chtimes(oldPath, time.Now(), time.Now().Add(-time.Hour*24*4)) // make another one too new to clean up - newPath := utils.CreatePartialPath("https://viam.com/subpath/new.part") + newPath, _ := utils.CreatePartialPath("https://viam.com/subpath/new.part") err = os.Mkdir(filepath.Dir(newPath), 0o755) test.That(t, err, test.ShouldBeNil) err = os.WriteFile(newPath, []byte("hello"), 0o600)